Modify AvroDataReader to perform partition count check in all circumstances

- Fix bug: previously, AvroDataReader would only repartition when IndexMapLoaders were provided
- Remove integration test assertions that depended on strict ordering of data in the DataFrame
- Add integration test to check that an explicit repartition is performed
Alex Shelkovnykov committed Nov 15, 2018
1 parent a0f1d82 commit 05432d5
Showing 2 changed files with 115 additions and 97 deletions.
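
The core of the change is a guard that forces a repartition whenever the loaded Avro RDD has fewer partitions than requested, on every read path. A minimal sketch of that guard (the helper name below is hypothetical; the actual change is the new readRecords method in the diff):

import org.apache.spark.rdd.RDD

// Hypothetical helper (not from this commit): force a repartition when the loaded RDD
// has fewer partitions than the requested minimum, mirroring the new readRecords guard.
def ensureMinPartitions[T](records: RDD[T], numPartitions: Int): RDD[T] =
  if (records.getNumPartitions < numPartitions) records.repartition(numPartitions) else records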
@@ -14,10 +14,9 @@
*/
package com.linkedin.photon.ml.data.avro

import breeze.stats._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.{SparseVector, Vectors}
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.testng.Assert._
@@ -26,7 +25,7 @@ import org.testng.annotations.Test
import com.linkedin.photon.ml.Constants
import com.linkedin.photon.ml.index.{IndexMap, PalDBIndexMapLoader}
import com.linkedin.photon.ml.io.FeatureShardConfiguration
import com.linkedin.photon.ml.test.{CommonTestUtils, SparkTestUtils}
import com.linkedin.photon.ml.test.SparkTestUtils

/**
* Unit tests for AvroDataReader
@@ -46,6 +45,24 @@ class AvroDataReaderIntegTest extends SparkTestUtils {
verifyDataFrame(df, expectedRows = 34810)
}

/**
* Test that an explicitly requested minimum partition count is respected when reading a [[DataFrame]].
*/
@Test(dependsOnMethods = Array("testRead"))
def testRepartition(): Unit = sparkTest("testRepartition") {

val dr = new AvroDataReader()

val (df1, indexMaps) = dr.readMerged(TRAIN_INPUT_PATH.toString, FEATURE_SHARD_CONFIGS_MAP, 1)
val numPartitions = df1.rdd.getNumPartitions
val expectedNumPartitions = numPartitions * 2
val (df2, _) = dr.readMerged(TRAIN_INPUT_PATH.toString, FEATURE_SHARD_CONFIGS_MAP, expectedNumPartitions)
val df3 = dr.readMerged(TRAIN_INPUT_PATH.toString, indexMaps, FEATURE_SHARD_CONFIGS_MAP, expectedNumPartitions)

assertEquals(df2.rdd.getNumPartitions, expectedNumPartitions)
assertEquals(df3.rdd.getNumPartitions, expectedNumPartitions)
}

/**
* Test reading a [[DataFrame]], using an existing [[IndexMap]].
*/
@@ -177,28 +194,8 @@ object AvroDataReaderIntegTest {

// Columns have the expected number of features and summary stats look right
assertTrue(df.columns.contains("shard1"))
val vector1 = df.select(col("shard1")).take(1)(0).getAs[SparseVector](0)
assertEquals(vector1.numActives, 61)
assertEquals(Vectors.norm(vector1, 2), 3.2298996752519407, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
val (mu1: Double, _, var1: Double) = DescriptiveStats.meanAndCov(vector1.values, vector1.values)
assertEquals(mu1, 0.044020727910406766, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
assertEquals(var1, 0.17190074364268512, CommonTestUtils.HIGH_PRECISION_TOLERANCE)

assertTrue(df.columns.contains("shard2"))
val vector2 = df.select(col("shard2")).take(1)(0).getAs[SparseVector](0)
assertEquals(vector2.numActives, 31)
assertEquals(Vectors.norm(vector2, 2), 2.509607963949448, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
val (mu2: Double, _, var2: Double) = DescriptiveStats.meanAndCov(vector2.values, vector2.values)
assertEquals(mu2, 0.05196838235602745, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
assertEquals(var2, 0.20714700123375754, CommonTestUtils.HIGH_PRECISION_TOLERANCE)

assertTrue(df.columns.contains("shard3"))
val vector3 = df.select(col("shard3")).take(1)(0).getAs[SparseVector](0)
assertEquals(vector3.numActives, 31)
assertEquals(Vectors.norm(vector3, 2), 2.265859611598675, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
val (mu3: Double, _, var3: Double) = DescriptiveStats.meanAndCov(vector3.values, vector3.values)
assertEquals(mu3, 0.06691111449993427, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
assertEquals(var3, 0.16651099216405915, CommonTestUtils.HIGH_PRECISION_TOLERANCE)

// Relationship between columns is the same across the entire dataframe
df.foreach { row =>
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.DataTypes._
import org.apache.spark.sql.types.{MapType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.storage.StorageLevel

import com.linkedin.photon.ml.Constants
import com.linkedin.photon.ml.data.{DataReader, InputColumnsNames}
@@ -59,13 +60,40 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
private val sparkSession = SparkSession.builder.getOrCreate()
private val sc = sparkSession.sparkContext

/**
* Read Avro records into an [[RDD]] with the given minimum number of partitions.
*
* @param paths The paths to the files or folders
* @param numPartitions The minimum number of partitions. Spark is generally moving away from manually specifying
partition counts like this, in favor of inferring them. However, Photon currently still exposes
* partition counts as a means for tuning job performance. The auto-inferred counts are usually
* much lower than the necessary counts for Photon (especially GAME), so this caused a lot of
* shuffling when repartitioning from the auto-partitioned data to the GAME data. We expose this
* setting here to avoid the shuffling.
* @return An [[RDD]] of Avro records loaded from the given paths
*/
private def readRecords(paths: Seq[String], numPartitions: Int): RDD[GenericRecord] = {

val records = AvroUtils.readAvroFiles(sc, paths, numPartitions)

// Check partitions and force repartition if there are too few (sometimes AvroUtils does not respect min partitions
// request)
val partitionedRecords = if (records.getNumPartitions < numPartitions) {
records.repartition(numPartitions)
} else {
records
}

partitionedRecords.persist(StorageLevel.MEMORY_ONLY)
}

/**
* Reads the avro files at the given paths into a DataFrame, generating a default index map for feature names. Merges
* source columns into combined feature vectors as specified by the featureColumnMap argument. Often features are
* joined from different sources, and it can be more scalable to combine them into problem-specific feature vectors
* that can be independently distributed.
*
* @param paths The path to the files or folders
* @param paths The paths to the files or folders
* @param featureColumnConfigsMap A map that specifies how the feature columns should be merged. The keys specify the
* name of the merged destination column, and the values are configs containing sets of
* source columns to merge, e.g.:
@@ -90,7 +118,7 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
require(paths.nonEmpty, "No paths specified. You must specify at least one input path.")
require(numPartitions >= 0, "Partition count cannot be negative.")

val records = AvroUtils.readAvroFiles(sc, paths, numPartitions)
val records = readRecords(paths, numPartitions)
val featureColumnMap = featureColumnConfigsMap.mapValues(_.featureBags).map(identity)
val interceptColumnMap = featureColumnConfigsMap.mapValues(_.hasIntercept).map(identity)
val indexMapLoaders = generateIndexMapLoaders(records, featureColumnMap, interceptColumnMap)
@@ -104,7 +132,7 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
* different sources, and it can be more scalable to combine them into problem-specific feature vectors that can be
* independently distributed.
*
* @param paths The path to the files or folders
* @param paths The paths to the files or folders
* @param indexMapLoaders A map of index map loaders, containing one loader for each merged feature column
* @param featureColumnConfigsMap A map that specifies how the feature columns should be merged. The keys specify the
* name of the merged destination column, and the values are configs containing sets of
@@ -131,17 +159,10 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
require(paths.nonEmpty, "No paths specified. You must specify at least one input path.")
require(numPartitions >= 0, "Partition count cannot be negative.")

val records = readRecords(paths, numPartitions)
val featureColumnMap = featureColumnConfigsMap.mapValues(_.featureBags).map(identity)
val records = AvroUtils.readAvroFiles(sc, paths, numPartitions)
// Check partitions and force repartition if there are too few - sometimes AvroUtils does not respect min partitions
// request
val partitionedRecords = if (records.getNumPartitions < numPartitions) {
records.repartition(numPartitions)
} else {
records
}

readMerged(partitionedRecords, indexMapLoaders, featureColumnMap)
readMerged(records, indexMapLoaders, featureColumnMap)
}

/**
@@ -205,6 +226,9 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
}
val sqlSchema = new StructType((schemaFields ++ featureFields).toArray)

// To avoid re-computing the RDD of GenericRecords (which is used 2-3 times), we persist it. However, this call is
// lazily evaluated (it is a Spark transformation), so the RDD cannot be unpersisted until this line is actually
// evaluated further down in the code. Thus, the 'records' RDD is persisted to memory but never unpersisted.
sparkSession.createDataFrame(rows, sqlSchema)
}
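
As context for the comment above, a minimal sketch of the lazy-evaluation constraint (the local session, schema, and data here are assumptions for illustration, not code from this commit):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder.master("local[*]").appName("persist-sketch").getOrCreate()
val schema = StructType(Seq(StructField("x", IntegerType, nullable = false)))
val rows = spark.sparkContext.parallelize(1 to 3).map(Row(_)).persist(StorageLevel.MEMORY_ONLY)

val df = spark.createDataFrame(rows, schema) // transformation only: nothing is computed or cached yet
// Unpersisting 'rows' at this point would discard the cache before it is ever populated, so the
// source would be recomputed on every later action; hence the records RDD above stays persisted.
df.count()                                   // action: the rows are computed (and cached) here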

@@ -252,8 +276,7 @@ object AvroDataReader {
FLOAT -> FloatType,
DOUBLE -> DoubleType,
STRING -> StringType,
BOOLEAN -> BooleanType
)
BOOLEAN -> BooleanType)

/**
* Establishes precedence among numeric types, for resolving unions where multiple types are specified. Appearing
@@ -281,7 +304,8 @@
.toSeq
.flatMap { fieldName =>
Some(record.get(fieldName)) match {
// Must have conversion to Seq at the end (labelled redundant by IDEA) or else typing compiler errors
// Must have the conversion to Seq at the end (labelled redundant by IDEA) or else there will be compile-time
// type errors
case Some(recordList: JList[_]) => recordList.asScala.toSeq
case other => throw new IllegalArgumentException(
s"Expected feature list $fieldName to be a Java List, found instead: ${other.getClass.getName}.")
@@ -294,8 +318,10 @@

featureKey -> Utils.getDoubleAvro(record, TrainingExampleFieldNames.VALUE)

case other => throw new IllegalArgumentException(s"$other in features list is not a GenericRecord")
}.toArray
case other =>
throw new IllegalArgumentException(s"$other in features list is not a GenericRecord")
}
.toArray
}

/**
@@ -383,54 +409,53 @@
* @param avroSchema The avro schema for the field
* @return Spark sql schema for the field
*/
protected[data] def avroTypeToSql(name: String, avroSchema: Schema): Option[StructField] =
avroSchema.getType match {
case avroType @ (INT | LONG | FLOAT | DOUBLE | STRING | BOOLEAN) =>
Some(StructField(name, primitiveTypeMap(avroType), nullable = false))

case MAP =>
avroTypeToSql(name, avroSchema.getValueType).map { valueSchema =>
StructField(
name,
MapType(StringType, valueSchema.dataType, valueContainsNull = valueSchema.nullable),
nullable = false)
}
protected[data] def avroTypeToSql(name: String, avroSchema: Schema): Option[StructField] = avroSchema.getType match {
case avroType @ (INT | LONG | FLOAT | DOUBLE | STRING | BOOLEAN) =>
Some(StructField(name, primitiveTypeMap(avroType), nullable = false))

case MAP =>
avroTypeToSql(name, avroSchema.getValueType).map { valueSchema =>
StructField(
name,
MapType(StringType, valueSchema.dataType, valueContainsNull = valueSchema.nullable),
nullable = false)
}

case UNION =>
if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
// In case of a union with null, take the first non-null type for the value type
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
avroTypeToSql(name, remainingUnionTypes.head).map(_.copy(nullable = true))
} else {
avroTypeToSql(name, Schema.createUnion(remainingUnionTypes.asJava)).map(_.copy(nullable = true))
}
case UNION =>
if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
// In case of a union with null, take the first non-null type for the value type
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
avroTypeToSql(name, remainingUnionTypes.head).map(_.copy(nullable = true))
} else {
avroTypeToSql(name, Schema.createUnion(remainingUnionTypes.asJava)).map(_.copy(nullable = true))
}

} else avroSchema.getTypes.asScala.map(_.getType) match {
} else avroSchema.getTypes.asScala.map(_.getType) match {

case numericTypes if allNumericTypes(numericTypes) =>
Some(StructField(name, primitiveTypeMap(getDominantNumericType(numericTypes)), nullable = false))
case numericTypes if allNumericTypes(numericTypes) =>
Some(StructField(name, primitiveTypeMap(getDominantNumericType(numericTypes)), nullable = false))

// When there are cases of multiple non-null types, resolve to a single sql type
case types: Seq[Schema.Type] =>
// If String is in the union, choose String
if (types.contains(STRING)) {
Some(StructField(name, primitiveTypeMap(STRING), nullable = false))
// When there are cases of multiple non-null types, resolve to a single sql type
case types: Seq[Schema.Type] =>
// If String is in the union, choose String
if (types.contains(STRING)) {
Some(StructField(name, primitiveTypeMap(STRING), nullable = false))

// Otherwise, choose first type in list
} else {
avroTypeToSql(name, avroSchema.getTypes.get(0))
}
// Otherwise, choose first type in list
} else {
avroTypeToSql(name, avroSchema.getTypes.get(0))
}

case _ =>
// Unsupported union type. Drop this for now.
None
}
case _ =>
// Unsupported union type. Drop this for now.
None
}

case _ =>
// Unsupported avro field type. Drop this for now.
None
}
case _ =>
// Unsupported avro field type. Drop this for now.
None
}
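
A hedged sketch of how the union handling above is expected to behave, using Avro's SchemaBuilder. The dominant-numeric-type outcome is an assumption (DOUBLE assumed to dominate INT), and avroTypeToSql is protected[data], so this would only compile from within the com.linkedin.photon.ml.data package:

import org.apache.avro.SchemaBuilder
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField}

// Union with null: the non-null type is used and the field becomes nullable
val nullableInt = SchemaBuilder.unionOf().nullType().and().intType().endUnion()
assert(AvroDataReader.avroTypeToSql("a", nullableInt).contains(StructField("a", IntegerType, nullable = true)))

// All-numeric union: resolved to the dominant numeric type (assumed DOUBLE here)
val intOrDouble = SchemaBuilder.unionOf().intType().and().doubleType().endUnion()
assert(AvroDataReader.avroTypeToSql("b", intOrDouble).contains(StructField("b", DoubleType, nullable = false)))

// Mixed non-numeric union containing STRING: resolved to StringType
val intOrString = SchemaBuilder.unionOf().intType().and().stringType().endUnion()
assert(AvroDataReader.avroTypeToSql("c", intOrString).contains(StructField("c", StringType, nullable = false)))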

/**
* Read the fields from the avro record into column values according to the supplied spark sql schema.
@@ -440,22 +465,19 @@
* @return Column values
*/
protected[data] def readColumnValuesFromRecord(record: GenericRecord, schemaFields: Seq[StructField]): Seq[Any] =

schemaFields
.flatMap { field: StructField =>
field.dataType match {
case IntegerType => checkNull(record, field).orElse(Some(Utils.getIntAvro(record, field.name)))
case StringType => Some(Utils.getStringAvro(record, field.name, field.nullable))
case BooleanType => checkNull(record, field).orElse(Some(Utils.getBooleanAvro(record, field.name)))
case DoubleType => checkNull(record, field).orElse(Some(Utils.getDoubleAvro(record, field.name)))
case FloatType => checkNull(record, field).orElse(Some(Utils.getFloatAvro(record, field.name)))
case LongType => checkNull(record, field).orElse(Some(Utils.getLongAvro(record, field.name)))
case MapType(_, _, _) => Some(Utils.getMapAvro(record, field.name, field.nullable))
case _ =>
// Unsupported field type. Drop this for now.
None
}
schemaFields.flatMap { field: StructField =>
field.dataType match {
case IntegerType => checkNull(record, field).orElse(Some(Utils.getIntAvro(record, field.name)))
case StringType => Some(Utils.getStringAvro(record, field.name, field.nullable))
case BooleanType => checkNull(record, field).orElse(Some(Utils.getBooleanAvro(record, field.name)))
case DoubleType => checkNull(record, field).orElse(Some(Utils.getDoubleAvro(record, field.name)))
case FloatType => checkNull(record, field).orElse(Some(Utils.getFloatAvro(record, field.name)))
case LongType => checkNull(record, field).orElse(Some(Utils.getLongAvro(record, field.name)))
case MapType(_, _, _) => Some(Utils.getMapAvro(record, field.name, field.nullable))
// Unsupported field type. Drop this for now.
case _ => None
}
}

/**
* Checks whether null values are allowed for the record, and if so, passes along the null value. Otherwise, returns
@@ -466,7 +488,6 @@
* @return Some(null) if the field is null and nullable. None otherwise.
*/
protected[data] def checkNull(record: GenericRecord, field: StructField): Option[_] =

if (record.get(field.name) == null && field.nullable) {
Some(null)
} else {
