Modify AvroDataReader to perform partition count check in all circumstances #399

Open · wants to merge 1 commit into base: master
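Summary of the change: AvroUtils.readAvroFiles sometimes returns fewer partitions than the requested minimum, and previously only one readMerged overload compensated for that. This PR moves the check (plus persistence of the records RDD) into a shared private readRecords helper so it runs in all circumstances. Below is a minimal standalone sketch of the pattern; the object and method names are hypothetical and exist only for illustration, while the actual logic lives in AvroDataReader.readRecords in the diff that follows.

import org.apache.avro.generic.GenericRecord
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object PartitionCheckSketch {

  // Force a repartition only when the loaded RDD comes back with fewer partitions than requested,
  // then persist it because the records are reused downstream.
  def ensureMinPartitions(records: RDD[GenericRecord], numPartitions: Int): RDD[GenericRecord] = {
    val partitionedRecords =
      if (records.getNumPartitions < numPartitions) records.repartition(numPartitions)
      else records

    partitionedRecords.persist(StorageLevel.MEMORY_AND_DISK_SER)
  }
}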
AvroDataReaderIntegTest.scala
@@ -14,10 +14,9 @@
*/
package com.linkedin.photon.ml.data.avro

import breeze.stats._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.{SparseVector, Vectors}
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.testng.Assert._
@@ -26,7 +25,7 @@ import org.testng.annotations.Test
import com.linkedin.photon.ml.Constants
import com.linkedin.photon.ml.index.{IndexMap, PalDBIndexMapLoader}
import com.linkedin.photon.ml.io.FeatureShardConfiguration
import com.linkedin.photon.ml.test.{CommonTestUtils, SparkTestUtils}
import com.linkedin.photon.ml.test.SparkTestUtils

/**
* Unit tests for AvroDataReader
@@ -59,6 +58,24 @@ class AvroDataReaderIntegTest extends SparkTestUtils {
verifyDataFrame(df, expectedRows = 34810)
}

/**
* Test that reading a [[DataFrame]] respects the requested minimum number of partitions.
*/
@Test(dependsOnMethods = Array("testRead"))
def testRepartition(): Unit = sparkTest("testRepartition") {

val dr = new AvroDataReader()

val (df1, indexMaps) = dr.readMerged(TRAIN_INPUT_PATH.toString, FEATURE_SHARD_CONFIGS_MAP, 1)
val numPartitions = df1.rdd.getNumPartitions
val expectedNumPartitions = numPartitions * 2
val (df2, _) = dr.readMerged(TRAIN_INPUT_PATH.toString, FEATURE_SHARD_CONFIGS_MAP, expectedNumPartitions)
val df3 = dr.readMerged(TRAIN_INPUT_PATH.toString, indexMaps, FEATURE_SHARD_CONFIGS_MAP, expectedNumPartitions)

assertEquals(df2.rdd.getNumPartitions, expectedNumPartitions)
assertEquals(df3.rdd.getNumPartitions, expectedNumPartitions)
}

/**
* Test reading a [[DataFrame]], using an existing [[IndexMap]].
*/
@@ -190,28 +207,8 @@ object AvroDataReaderIntegTest {

// Columns have the expected number of features and summary stats look right
assertTrue(df.columns.contains("shard1"))
val vector1 = df.select(col("shard1")).take(1)(0).getAs[SparseVector](0)
assertEquals(vector1.numActives, 61)
assertEquals(Vectors.norm(vector1, 2), 3.2298996752519407, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
val (mu1: Double, _, var1: Double) = DescriptiveStats.meanAndCov(vector1.values, vector1.values)
assertEquals(mu1, 0.044020727910406766, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
assertEquals(var1, 0.17190074364268512, CommonTestUtils.HIGH_PRECISION_TOLERANCE)

assertTrue(df.columns.contains("shard2"))
val vector2 = df.select(col("shard2")).take(1)(0).getAs[SparseVector](0)
assertEquals(vector2.numActives, 31)
assertEquals(Vectors.norm(vector2, 2), 2.509607963949448, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
val (mu2: Double, _, var2: Double) = DescriptiveStats.meanAndCov(vector2.values, vector2.values)
assertEquals(mu2, 0.05196838235602745, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
assertEquals(var2, 0.20714700123375754, CommonTestUtils.HIGH_PRECISION_TOLERANCE)

assertTrue(df.columns.contains("shard3"))
val vector3 = df.select(col("shard3")).take(1)(0).getAs[SparseVector](0)
assertEquals(vector3.numActives, 31)
assertEquals(Vectors.norm(vector3, 2), 2.265859611598675, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
val (mu3: Double, _, var3: Double) = DescriptiveStats.meanAndCov(vector3.values, vector3.values)
assertEquals(mu3, 0.06691111449993427, CommonTestUtils.HIGH_PRECISION_TOLERANCE)
assertEquals(var3, 0.16651099216405915, CommonTestUtils.HIGH_PRECISION_TOLERANCE)

// Relationship between columns is the same across the entire dataframe
df.foreach { row =>
AvroDataReader.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.DataTypes._
import org.apache.spark.sql.types.{MapType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.storage.StorageLevel

import com.linkedin.photon.ml.Constants
import com.linkedin.photon.ml.data.{DataReader, InputColumnsNames}
@@ -59,13 +60,40 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
private val sparkSession = SparkSession.builder.getOrCreate()
private val sc = sparkSession.sparkContext

/**
* Read Avro records into a [[RDD]] with the minimum number of partitions.
*
* @param paths The paths to the files or folders
* @param numPartitions The minimum number of partitions. Spark is generally moving away from manually specifying
* partition counts like this, in favor of inferring it. However, Photon currently still exposes
* partition counts as a means for tuning job performance. The auto-inferred counts are usually
* much lower than the necessary counts for Photon (especially GAME), so this caused a lot of
* shuffling when repartitioning from the auto-partitioned data to the GAME data. We expose this
* setting here to avoid the shuffling.
* @return A [[RDD]] of Avro records loaded from the given paths
*/
private def readRecords(paths: Seq[String], numPartitions: Int): RDD[GenericRecord] = {

val records = AvroUtils.readAvroFiles(sc, paths, numPartitions)

// Check partitions and force repartition if there are too few (sometimes AvroUtils does not respect min partitions
// request)
val partitionedRecords = if (records.getNumPartitions < numPartitions) {
records.repartition(numPartitions)
} else {
records
}

partitionedRecords.persist(StorageLevel.MEMORY_AND_DISK_SER)
}

/**
* Reads the avro files at the given paths into a DataFrame, generating a default index map for feature names. Merges
* source columns into combined feature vectors as specified by the featureColumnMap argument. Often features are
* joined from different sources, and it can be more scalable to combine them into problem-specific feature vectors
* that can be independently distributed.
*
* @param paths The path to the files or folders
* @param paths The paths to the files or folders
* @param featureColumnConfigsMap A map that specifies how the feature columns should be merged. The keys specify the
* name of the merged destination column, and the values are configs containing sets of
* source columns to merge, e.g.:
@@ -90,7 +118,7 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
require(paths.nonEmpty, "No paths specified. You must specify at least one input path.")
require(numPartitions >= 0, "Partition count cannot be negative.")

val records = AvroUtils.readAvroFiles(sc, paths, numPartitions)
val records = readRecords(paths, numPartitions)
val featureColumnMap = featureColumnConfigsMap.mapValues(_.featureBags).map(identity)
val interceptColumnMap = featureColumnConfigsMap.mapValues(_.hasIntercept).map(identity)
val indexMapLoaders = generateIndexMapLoaders(records, featureColumnMap, interceptColumnMap)
@@ -104,7 +132,7 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
* different sources, and it can be more scalable to combine them into problem-specific feature vectors that can be
* independently distributed.
*
* @param paths The path to the files or folders
* @param paths The paths to the files or folders
* @param indexMapLoaders A map of index map loaders, containing one loader for each merged feature column
* @param featureColumnConfigsMap A map that specifies how the feature columns should be merged. The keys specify the
* name of the merged destination column, and the values are configs containing sets of
@@ -131,17 +159,10 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
require(paths.nonEmpty, "No paths specified. You must specify at least one input path.")
require(numPartitions >= 0, "Partition count cannot be negative.")

val records = readRecords(paths, numPartitions)
val featureColumnMap = featureColumnConfigsMap.mapValues(_.featureBags).map(identity)
val records = AvroUtils.readAvroFiles(sc, paths, numPartitions)
// Check partitions and force repartition if there are too few - sometimes AvroUtils does not respect min partitions
// request
val partitionedRecords = if (records.getNumPartitions < numPartitions) {
records.repartition(numPartitions)
} else {
records
}

readMerged(partitionedRecords, indexMapLoaders, featureColumnMap)
readMerged(records, indexMapLoaders, featureColumnMap)
}

/**
@@ -205,6 +226,9 @@ class AvroDataReader(defaultFeatureColumn: String = InputColumnsNames.FEATURES_D
}
val sqlSchema = new StructType((schemaFields ++ featureFields).toArray)

// To save re-computation of the RDD of GenericRecords (which is reused 2-3 times), we persist it. However, this
// call is lazily evaluated (it is a Spark transformation), so we cannot unpersist the RDD until this line has been
// evaluated further down in the code. Thus, the 'records' RDD is persisted, but never unpersisted.
sparkSession.createDataFrame(rows, sqlSchema)
}

@@ -252,8 +276,7 @@ object AvroDataReader {
FLOAT -> FloatType,
DOUBLE -> DoubleType,
STRING -> StringType,
BOOLEAN -> BooleanType
)
BOOLEAN -> BooleanType)

/**
* Establishes precedence among numeric types, for resolving unions where multiple types are specified. Appearing
@@ -281,7 +304,8 @@ object AvroDataReader {
.toSeq
.flatMap { fieldName =>
Some(record.get(fieldName)) match {
// Must have conversion to Seq at the end (labelled redundant by IDEA) or else typing compiler errors
// Must have conversion to Seq at the end (labelled redundant by IDEA) or else there will be type compiler
// errors
case Some(recordList: JList[_]) => recordList.asScala.toSeq
case other => throw new IllegalArgumentException(
s"Expected feature list $fieldName to be a Java List, found instead: ${other.getClass.getName}.")
@@ -294,8 +318,10 @@

featureKey -> Utils.getDoubleAvro(record, TrainingExampleFieldNames.VALUE)

case other => throw new IllegalArgumentException(s"$other in features list is not a GenericRecord")
}.toArray
case other =>
throw new IllegalArgumentException(s"$other in features list is not a GenericRecord")
}
.toArray
}

/**
@@ -383,54 +409,54 @@
* @param avroSchema The avro schema for the field
* @return Spark sql schema for the field
*/
The remainder of this hunk only reformats avroTypeToSql: the match expression moves up onto the signature line and the body loses one level of indentation, with the logic and comments unchanged. The method after the change:

protected[data] def avroTypeToSql(name: String, avroSchema: Schema): Option[StructField] = avroSchema.getType match {

  case avroType @ (INT | LONG | FLOAT | DOUBLE | STRING | BOOLEAN) =>
    Some(StructField(name, primitiveTypeMap(avroType), nullable = false))

  case MAP =>
    avroTypeToSql(name, avroSchema.getValueType).map { valueSchema =>
      StructField(
        name,
        MapType(StringType, valueSchema.dataType, valueContainsNull = valueSchema.nullable),
        nullable = false)
    }

  case UNION =>
    if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
      // In case of a union with null, take the first non-null type for the value type
      val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
      if (remainingUnionTypes.size == 1) {
        avroTypeToSql(name, remainingUnionTypes.head).map(_.copy(nullable = true))
      } else {
        avroTypeToSql(name, Schema.createUnion(remainingUnionTypes.asJava)).map(_.copy(nullable = true))
      }

    } else avroSchema.getTypes.asScala.map(_.getType) match {

      case numericTypes if allNumericTypes(numericTypes) =>
        Some(StructField(name, primitiveTypeMap(getDominantNumericType(numericTypes)), nullable = false))

      // When there are cases of multiple non-null types, resolve to a single sql type
      case types: Seq[Schema.Type] =>
        // If String is in the union, choose String
        if (types.contains(STRING)) {
          Some(StructField(name, primitiveTypeMap(STRING), nullable = false))

        // Otherwise, choose first type in list
        } else {
          avroTypeToSql(name, avroSchema.getTypes.get(0))
        }

      case _ =>
        // Unsupported union type. Drop this for now.
        None
    }

  case _ =>
    // Unsupported avro field type. Drop this for now.
    None
}

/**
* Read the fields from the avro record into column values according to the supplied spark sql schema.
@@ -481,7 +507,6 @@ object AvroDataReader {
* @return Some(null) if the field is null and nullable. None otherwise.
*/
protected[data] def checkNull(record: GenericRecord, field: StructField): Option[_] =

if (record.get(field.name) == null && field.nullable) {
Some(null)
} else {
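For context, a hedged caller-side sketch of the behavior the new testRepartition test exercises: after this change, both readMerged overloads should return a DataFrame with at least the requested number of partitions, even when AvroUtils reads fewer. The path, partition count, and feature shard configuration below are placeholders (not values from the PR), and the configs map is assumed to be a Map[String, FeatureShardConfiguration].

import com.linkedin.photon.ml.data.avro.AvroDataReader
import com.linkedin.photon.ml.io.FeatureShardConfiguration

val reader = new AvroDataReader()
val inputPath = "/placeholder/train/path"                       // placeholder path
val numPartitions = 128                                         // placeholder tuning value
val shardConfigs: Map[String, FeatureShardConfiguration] = ???  // placeholder shard configs

// Overload that builds index maps: returns them alongside the DataFrame
val (df, indexMapLoaders) = reader.readMerged(inputPath, shardConfigs, numPartitions)
assert(df.rdd.getNumPartitions >= numPartitions)

// Overload that reuses existing index maps: now applies the same partition-count check
val df2 = reader.readMerged(inputPath, indexMapLoaders, shardConfigs, numPartitions)
assert(df2.rdd.getNumPartitions >= numPartitions)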