Improve Avro GenericRecord and SpecificRecord based row-level extractor performance #723

Merged · 4 commits · Oct 5, 2022
File: AnchorExtractor.scala
@@ -1,7 +1,5 @@
 package com.linkedin.feathr.common

-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-
 /**
  * Provides feature values based on some "raw" data element
  *
@@ -39,12 +37,14 @@ trait AnchorExtractor[T] extends AnchorExtractorBase[T] with SparkRowExtractor {
    * @param datum input row
    * @return list of feature keys
    */
-  def getKeyFromRow(datum: GenericRowWithSchema): Seq[String] = getKey(datum.asInstanceOf[T])
+  def getKeyFromRow(datum: Any): Seq[String] = getKey(datum.asInstanceOf[T])

   /**
    * Get the feature value from the row
    * @param datum input row
    * @return A map of feature name to feature value
    */
-  def getFeaturesFromRow(datum: GenericRowWithSchema): Map[String, FeatureValue] = getFeatures(datum.asInstanceOf[T])
+  def getFeaturesFromRow(datum: Any): Map[String, FeatureValue] = getFeatures(datum.asInstanceOf[T])

   override def toString: String = getClass.getSimpleName
 }
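The practical effect of widening these signatures: a typed extractor keeps working exactly as before, because the default methods still cast the datum back to `T`. A minimal sketch, assuming a hypothetical `PageViewExtractor` with invented field names and an assumed `FeatureValue` constructor (check which factory your Feathr version actually exposes):

```scala
import com.linkedin.feathr.common.{AnchorExtractor, FeatureValue}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

// Hypothetical row-based extractor. It compiles unchanged after this PR:
// the inherited getKeyFromRow(Any)/getFeaturesFromRow(Any) defaults cast
// the datum back to T (here GenericRowWithSchema) before delegating.
// Other abstract members of the trait, if any in your version, are omitted.
class PageViewExtractor extends AnchorExtractor[GenericRowWithSchema] {

  override def getKey(datum: GenericRowWithSchema): Seq[String] =
    Seq(datum.getAs[String]("memberId"))

  override def getFeatures(datum: GenericRowWithSchema): Map[String, FeatureValue] =
    // Assumed FeatureValue construction; substitute your version's factory.
    Map("pageViewCount" -> new FeatureValue(datum.getAs[Long]("pageViewCount")))
}
```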
File: CanConvertToAvroRDD.scala (new file)
@@ -0,0 +1,20 @@
+package com.linkedin.feathr.common
+
+import org.apache.avro.generic.IndexedRecord
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+/**
+ * If an AnchorExtractor only works on an Avro record, it should extend
+ * this trait and use convertToAvroRdd to do a one-time batch conversion of the DataFrame to an RDD of its choice.
+ * convertToAvroRdd will be called by the Feathr engine before calling getKeyFromRow() and getFeaturesFromRow() in AnchorExtractor.
+ */
+trait CanConvertToAvroRDD {
+
+  /**
+   * One-time batch conversion of the input data source into an RDD[IndexedRecord] for later feature extraction.
+   * @param df input data source
+   * @return the batch-preprocessed data, as RDD[IndexedRecord]
+   */
+  def convertToAvroRdd(df: DataFrame): RDD[IndexedRecord]
+}
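For illustration, here is roughly what an extractor opting into this trait might look like; the class name, schema, fields, and `FeatureValue` construction are invented for the sketch, not taken from this PR. The point is that the Row-to-Avro conversion happens once per source in a batch pass, instead of inside every per-row `getKey`/`getFeatures` call:

```scala
import com.linkedin.feathr.common.{AnchorExtractor, CanConvertToAvroRDD, FeatureValue}
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, IndexedRecord}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

// Hypothetical extractor that consumes Avro records instead of Spark rows.
class UserClickExtractor extends AnchorExtractor[IndexedRecord] with CanConvertToAvroRDD {

  // Keep the schema as a JSON string so the mapPartitions closure stays
  // serializable; the Schema itself is parsed once per partition.
  private val schemaJson =
    """{"type":"record","name":"UserClick","fields":[
      |{"name":"userId","type":"string"},
      |{"name":"clickCount","type":"long"}]}""".stripMargin

  // Called once by the engine, before any row-level extraction.
  override def convertToAvroRdd(df: DataFrame): RDD[IndexedRecord] =
    df.rdd.mapPartitions { rows =>
      val schema = new Schema.Parser().parse(schemaJson)
      rows.map { row =>
        val record = new GenericData.Record(schema)
        record.put("userId", row.getAs[String]("userId"))
        record.put("clickCount", row.getAs[Long]("clickCount"))
        record: IndexedRecord
      }
    }

  // The Any-typed defaults in AnchorExtractor cast back to IndexedRecord.
  override def getKey(datum: IndexedRecord): Seq[String] =
    Seq(datum.get(0).toString) // field 0 = userId

  override def getFeatures(datum: IndexedRecord): Map[String, FeatureValue] =
    // Assumed FeatureValue construction; substitute your version's factory.
    Map("clickCount" -> new FeatureValue(datum.get(1)))
}
```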
File: SparkRowExtractor.scala
@@ -1,7 +1,5 @@
 package com.linkedin.feathr.common

-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-
 /**
  * An extractor trait that provides APIs to transform an input row into feature values
  */
@@ -12,12 +10,12 @@ trait SparkRowExtractor {
    * @param datum input row
    * @return list of feature keys
    */
-  def getKeyFromRow(datum: GenericRowWithSchema): Seq[String]
+  def getKeyFromRow(datum: Any): Seq[String]

   /**
    * Get the feature value from the row
    * @param datum input row
    * @return A map of feature name to feature value
    */
-  def getFeaturesFromRow(datum: GenericRowWithSchema): Map[String, FeatureValue]
+  def getFeaturesFromRow(datum: Any): Map[String, FeatureValue]
 }
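To make the dispatch concrete, a simplified sketch of engine-side logic (hypothetical, not the actual Feathr internals): extractors implementing CanConvertToAvroRDD get their one-time batch conversion, and both paths then flow through the same Any-typed methods above:

```scala
import com.linkedin.feathr.common.{CanConvertToAvroRDD, FeatureValue, SparkRowExtractor}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

// Hypothetical driver: choose the datum representation once per source,
// then run the same Any-typed extraction API over every datum.
def extractAll(df: DataFrame, extractor: SparkRowExtractor): RDD[(Seq[String], Map[String, FeatureValue])] = {
  val datums: RDD[Any] = extractor match {
    case avro: CanConvertToAvroRDD => avro.convertToAvroRdd(df).map(r => r: Any) // one-time batch conversion
    case _                         => df.rdd.map(r => r: Any)                    // plain Spark rows pass through
  }
  datums.map(d => (extractor.getKeyFromRow(d), extractor.getFeaturesFromRow(d)))
}
```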
File: SimpleConfigurableAnchorExtractor.scala
@@ -10,7 +10,6 @@ import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
 import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils}
 import com.linkedin.feathr.offline.util.FeatureValueTypeValidator
 import org.apache.log4j.Logger
-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types._
 import org.mvel2.MVEL
@@ -66,7 +65,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
    * @param datum input row
    * @return list of feature keys
    */
-  override def getKeyFromRow(datum: GenericRowWithSchema): Seq[String] = {
+  override def getKeyFromRow(datum: Any): Seq[String] = {
     getKey(datum.asInstanceOf[Any])
   }
@@ -107,7 +106,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
    * @param row input row
    * @return A map of feature name to feature value
    */
-  override def getFeaturesFromRow(row: GenericRowWithSchema) = {
+  override def getFeaturesFromRow(row: Any) = {
     getFeatures(row.asInstanceOf[Any])
   }
@@ -147,7 +146,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
       featureTypeConfigs(featureRefStr)
     }
     val featureValue = offline.FeatureValue.fromTypeConfig(value, featureTypeConfig)
-    FeatureValueTypeValidator.validate(featureValue, featureTypeConfigs(featureRefStr))
+    FeatureValueTypeValidator.validate(featureRefStr, featureValue, featureTypeConfigs(featureRefStr))
     (featureRefStr, featureValue)
   }
File: MVELSourceKeyExtractor.scala
@@ -43,7 +43,7 @@ private[feathr] class MVELSourceKeyExtractor(val anchorExtractorV1: AnchorExtrac
       .toDF()
   }

-  def getKey(datum: GenericRowWithSchema): Seq[String] = {
+  def getKey(datum: Any): Seq[String] = {
     anchorExtractorV1.getKeyFromRow(datum)
   }
@@ -55,7 +55,7 @@ private[feathr] class MVELSourceKeyExtractor(val anchorExtractorV1: AnchorExtrac
    */
   override def getKeyColumnNames(datum: Option[Any]): Seq[String] = {
     if (datum.isDefined) {
-      val size = getKey(datum.get.asInstanceOf[GenericRowWithSchema]).size
+      val size = getKey(datum.get).size
       (1 to size).map(JOIN_KEY_PREFIX + _)
     } else {
       // return empty join key to signal empty dataset
@@ -86,5 +86,6 @@ private[feathr] class MVELSourceKeyExtractor(val anchorExtractorV1: AnchorExtrac
   // this helps to reduce the number of joins
   // to the observation data
   // The default toString does not work, because the toString of each object has different values
-  override def toString: String = getClass.getSimpleName + " with keyExprs:" + keyExprs.mkString(" key:")
+  override def toString: String = getClass.getSimpleName + " with keyExprs:" + keyExprs.mkString(" key:") +
+    "anchorExtractor:" + anchorExtractorV1.toString
 }
File: AnchorLoader.scala
@@ -327,7 +327,7 @@ private[offline] class AnchorLoader extends JsonDeserializer[FeatureAnchor] {
         case Some(tType) => offline.FeatureValue.fromTypeConfig(rawValue, tType)
         case None => offline.FeatureValue(rawValue, featureType, key)
       }
-      FeatureValueTypeValidator.validate(featureValue, featureTypeConfig)
+      FeatureValueTypeValidator.validate(featureValue, featureTypeConfig, key)
       (key, featureValue)
     }
     .toMap
File: StreamingFeatureGenerator.scala
@@ -6,7 +6,7 @@ import com.linkedin.feathr.common.JoiningFeatureParams
 import com.linkedin.feathr.offline.config.location.KafkaEndpoint
 import com.linkedin.feathr.offline.generation.outputProcessor.PushToRedisOutputProcessor.TABLE_PARAM_CONFIG_NAME
 import com.linkedin.feathr.offline.generation.outputProcessor.RedisOutputUtils
-import com.linkedin.feathr.offline.job.FeatureTransformation.getFeatureJoinKey
+import com.linkedin.feathr.offline.job.FeatureTransformation.getFeatureKeyColumnNames
 import com.linkedin.feathr.offline.job.{FeatureGenSpec, FeatureTransformation}
 import com.linkedin.feathr.offline.logical.FeatureGroups
 import com.linkedin.feathr.offline.source.accessor.DataPathHandler
@@ -111,7 +111,7 @@ class StreamingFeatureGenerator(dataPathHandlers: List[DataPathHandler]) {
     // Apply feature transformation
     val transformedResult = DataFrameBasedSqlEvaluator.transform(anchor.featureAnchor.extractor.asInstanceOf[SimpleAnchorExtractorSpark],
       withKeyColumnDF, featureNamePrefixPairs, anchor.featureAnchor.featureTypeConfigs)
-    val outputJoinKeyColumnNames = getFeatureJoinKey(keyExtractor, withKeyColumnDF)
+    val outputJoinKeyColumnNames = getFeatureKeyColumnNames(keyExtractor, withKeyColumnDF)
     val selectedColumns = outputJoinKeyColumnNames ++ anchor.selectedFeatures.filter(keyTaggedFeatures.map(_.featureName).contains(_))
     val cleanedDF = transformedResult.df.select(selectedColumns.head, selectedColumns.tail:_*)
     val keyColumnNames = FeatureTransformation.getStandardizedKeyNames(outputJoinKeyColumnNames.size)