From f88b12537ee81d914ef7c51a08f80cb28d93c8ed Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 9 Jul 2015 08:16:26 -0700 Subject: [PATCH 01/23] [SPARK-6266] [MLLIB] PySpark SparseVector missing doc for size, indices, values Write missing pydocs in `SparseVector` attributes. Author: lewuathe Closes #7290 from Lewuathe/SPARK-6266 and squashes the following commits: 51d9895 [lewuathe] Update docs 0480d35 [lewuathe] Merge branch 'master' into SPARK-6266 ba42cf3 [lewuathe] [SPARK-6266] PySpark SparseVector missing doc for size, indices, values --- python/pyspark/mllib/linalg.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51ac198305711..040886f71775b 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -445,8 +445,10 @@ def __init__(self, size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, - or two sorted lists containing indices and values. + :param args: Active entries, as a dictionary {index: value, ...}, + a list of tuples [(index, value), ...], or a list of strictly i + ncreasing indices and a list of corresponding values [index, ...], + [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) SparseVector(4, {1: 1.0, 3: 5.5}) @@ -456,6 +458,7 @@ def __init__(self, size, *args): SparseVector(4, {1: 1.0, 3: 5.5}) """ self.size = int(size) + """ Size of the vector. """ assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" if len(args) == 1: pairs = args[0] @@ -463,7 +466,9 @@ def __init__(self, size, *args): pairs = pairs.items() pairs = sorted(pairs) self.indices = np.array([p[0] for p in pairs], dtype=np.int32) + """ A list of indices corresponding to active entries. """ self.values = np.array([p[1] for p in pairs], dtype=np.float64) + """ A list of values corresponding to active entries. """ else: if isinstance(args[0], bytes): assert isinstance(args[1], bytes), "values should be string too" From 23448a9e988a1b92bd05ee8c6c1a096c83375a12 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 09:20:16 -0700 Subject: [PATCH 02/23] [SPARK-8931] [SQL] Fallback to interpreted evaluation if failed to compile in codegen Exception will not be catched during tests. 
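The fallback policy that the diff below repeats in newProjection, newMutableProjection, newPredicate and newOrdering can be sketched as follows. This is a minimal sketch, not the patch itself; `orInterpreted` is a hypothetical helper, and it assumes the build starts test JVMs with the spark.testing system property set:

    // Generic over the generated artifact type T: prefer codegen, but only
    // fall back to the interpreted implementation outside of tests.
    def orInterpreted[T](codegen: => T)(interpreted: => T): T =
      try codegen catch {
        case e: Exception =>
          if (sys.props.contains("spark.testing")) throw e // surface codegen bugs in tests
          else interpreted                                  // fall back at runtime
      }

Each generator call in the patch follows this shape, logging the exception before falling back.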
cc marmbrus rxin Author: Davies Liu Closes #7309 from davies/fallback and squashes the following commits: 969a612 [Davies Liu] throw exception during tests f844f77 [Davies Liu] fallback a3091bc [Davies Liu] Merge branch 'master' of github.com:apache/spark into fallback 364a0d6 [Davies Liu] fallback to interpret mode if failed to compile --- .../spark/sql/execution/SparkPlan.scala | 51 +++++++++++++++++-- .../apache/spark/sql/sources/commands.scala | 13 ++++- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ca53186383237..4d7d8626a0ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -153,12 +153,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ buf.toArray.map(converter(_).asInstanceOf[Row]) } + private[this] def isTesting: Boolean = sys.props.contains("spark.testing") + protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection.generate(expressions, inputSchema) + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } + } } else { new InterpretedProjection(expressions, inputSchema) } @@ -170,17 +182,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled) { - GenerateMutableProjection.generate(expressions, inputSchema) + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } } else { () => new InterpretedMutableProjection(expressions, inputSchema) } } - protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { if (codegenEnabled) { - GeneratePredicate.generate(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } + } } else { InterpretedPredicate.create(expression, inputSchema) } @@ -190,7 +221,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { if (codegenEnabled) { - GenerateOrdering.generate(order, inputSchema) + try { + GenerateOrdering.generate(order, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate ordering, fallback to interpreted", e) + new RowOrdering(order, inputSchema) + } + } } else { new RowOrdering(order, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index ecbc889770625..9189d176111d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -276,7 +276,18 @@ private[sql] case class InsertIntoHadoopFsRelation( log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection.generate(expressions, inputSchema) + + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (sys.props.contains("spark.testing")) { + throw e + } else { + log.error("failed to generate projection, fallback to interpreted", e) + new InterpretedProjection(expressions, inputSchema) + } + } } else { new InterpretedProjection(expressions, inputSchema) } From a1964e9d902bb31f001893da8bc81f6dce08c908 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 9 Jul 2015 09:22:24 -0700 Subject: [PATCH 03/23] [SPARK-8830] [SQL] native levenshtein distance Jira: https://issues.apache.org/jira/browse/SPARK-8830 rxin and HuJiayin can you have a look on it. Author: Tarek Auel Closes #7236 from tarekauel/native-levenshtein-distance and squashes the following commits: ee4c4de [Tarek Auel] [SPARK-8830] implemented improvement proposals c252e71 [Tarek Auel] [SPARK-8830] removed chartAt; use unsafe method for byte array comparison ddf2222 [Tarek Auel] Merge branch 'master' into native-levenshtein-distance 179920a [Tarek Auel] [SPARK-8830] added description 5e9ed54 [Tarek Auel] [SPARK-8830] removed StringUtils import dce4308 [Tarek Auel] [SPARK-8830] native levenshtein distance --- .../expressions/stringOperations.scala | 9 ++- .../expressions/StringFunctionsSuite.scala | 5 ++ .../apache/spark/unsafe/types/UTF8String.java | 66 ++++++++++++++++++- .../spark/unsafe/types/UTF8StringSuite.java | 24 +++++++ 4 files changed, 97 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 47fc7cdaa826c..57f436485becf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -284,13 +284,12 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = IntegerType - protected override def nullSafeEval(input1: Any, input2: Any): Any = - StringUtils.getLevenshteinDistance(input1.toString, input2.toString) + protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = + leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val stringUtils = classOf[StringUtils].getName - defineCodeGen(ctx, ev, (left, right) => - s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())") + nullSafeCodeGen(ctx, ev, (left, right) => + s"${ev.primitive} = $left.levenshteinDistance($right);") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 1efbe1a245e83..69bef1c63e9dc 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -282,5 +282,10 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0) checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3) checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3) + checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) + // scalastyle:on } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d2a25096a5e7a..847d80ad583f6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -99,8 +99,6 @@ public int numBytes() { /** * Returns the number of code points in it. - * - * This is only used by Substring() when `start` is negative. */ public int numChars() { int len = 0; @@ -254,6 +252,70 @@ public boolean equals(final Object other) { } } + /** + * Levenshtein distance is a metric for measuring the distance of two strings. The distance is + * defined by the minimum number of single-character edits (i.e. insertions, deletions or + * substitutions) that are required to change one of the strings into the other. + */ + public int levenshteinDistance(UTF8String other) { + // Implementation adopted from org.apache.common.lang3.StringUtils.getLevenshteinDistance + + int n = numChars(); + int m = other.numChars(); + + if (n == 0) { + return m; + } else if (m == 0) { + return n; + } + + UTF8String s, t; + + if (n <= m) { + s = this; + t = other; + } else { + s = other; + t = this; + int swap; + swap = n; + n = m; + m = swap; + } + + int p[] = new int[n + 1]; + int d[] = new int[n + 1]; + int swap[]; + + int i, i_bytes, j, j_bytes, num_bytes_j, cost; + + for (i = 0; i <= n; i++) { + p[i] = i; + } + + for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) { + num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes)); + d[0] = j + 1; + + for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) { + if (s.getByte(i_bytes) != t.getByte(j_bytes) || + num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { + cost = 1; + } else { + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 
0 : 1; + } + d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); + } + + swap = p; + p = d; + d = swap; + } + + return p[n]; + } + @Override public int hashCode() { int result = 1; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 8ec69ebac8b37..fb463ba17f50b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -128,4 +128,28 @@ public void substring() { assertEquals(fromString("数据砖头").substring(3, 5), fromString("头")); assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷")); } + + @Test + public void levenshteinDistance() { + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")), 0); + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")), 1); + assertEquals( + UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")), 7); + assertEquals( + UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")), 1); + assertEquals( + UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")),3); + assertEquals( + UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")), 8); + assertEquals( + UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")),1); + assertEquals( + UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4); + } } From 59cc38944fe5c1dffc6551775bd939e2ac66c65e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Jul 2015 09:57:12 -0700 Subject: [PATCH 04/23] [SPARK-8940] [SPARKR] Don't overwrite given schema in createDataFrame JIRA: https://issues.apache.org/jira/browse/SPARK-8940 The given `schema` parameter will be overwritten in `createDataFrame` now. If it is not null, we shouldn't overwrite it. Author: Liang-Chi Hsieh Closes #7311 from viirya/df_not_overwrite_schema and squashes the following commits: 2385139 [Liang-Chi Hsieh] Don't overwrite given schema if it is not null. --- R/pkg/R/SQLContext.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9a743a3411533..30978bb50d339 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -86,7 +86,9 @@ infer_type <- function(x) { createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD - schema <- names(data) + if (is.null(schema)) { + schema <- names(data) + } n <- nrow(data) m <- ncol(data) # get rid of factor type From e204d22bb70f28b1cc090ab60f12078479be4ae0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 9 Jul 2015 10:01:01 -0700 Subject: [PATCH 05/23] [SPARK-8948][SQL] Remove ExtractValueWithOrdinal abstract class Also added more documentation for the file. Author: Reynold Xin Closes #7316 from rxin/extract-value and squashes the following commits: 069cb7e [Reynold Xin] Removed ExtractValueWithOrdinal. 621b705 [Reynold Xin] Reverted a line. 11ebd6c [Reynold Xin] [Minor][SQL] Improve documentation for complex type extractors. 
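For orientation (not part of the patch), the extractors kept in the renamed file back the usual complex-type access syntax. Assuming a hypothetical table t with columns s: struct<a:int>, arr: array<int>, m: map<string,int> and an existing sqlContext:

    // each query below is resolved by ExtractValue to the named expression
    sqlContext.sql("SELECT s.a FROM t")      // struct field  -> GetStructField
    sqlContext.sql("SELECT arr[0] FROM t")   // array ordinal -> GetArrayItem
    sqlContext.sql("SELECT m['key'] FROM t") // map key       -> GetMapValue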
--- ...alue.scala => complexTypeExtractors.scala} | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{ExtractValue.scala => complexTypeExtractors.scala} (86%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala similarity index 86% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 2b25ba03579ec..73cc930c45832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines all the expressions to extract values out of complex types. +// For example, getting a field out of an array, map, or struct. +//////////////////////////////////////////////////////////////////////////////////////////////////// + object ExtractValue { /** @@ -73,11 +78,10 @@ object ExtractValue { } } - def unapply(g: ExtractValue): Option[(Expression, Expression)] = { - g match { - case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case s: ExtractValueWithStruct => Some((s.child, null)) - } + def unapply(g: ExtractValue): Option[(Expression, Expression)] = g match { + case o: GetArrayItem => Some((o.child, o.ordinal)) + case o: GetMapValue => Some((o.child, o.key)) + case s: ExtractValueWithStruct => Some((s.child, null)) } /** @@ -117,6 +121,8 @@ abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue /** * Returns the value of fields in the Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends ExtractValueWithStruct { @@ -142,6 +148,8 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) /** * Returns the array of value of fields in the Array of Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetArrayStructFields( child: Expression, @@ -178,25 +186,21 @@ case class GetArrayStructFields( } } -abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { - self: Product => +/** + * Returns the field at `ordinal` in the Array `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. + */ +case class GetArrayItem(child: Expression, ordinal: Expression) + extends BinaryExpression with ExtractValue { - def ordinal: Expression - def child: Expression + override def toString: String = s"$child[$ordinal]" override def left: Expression = child override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. 
*/ override def nullable: Boolean = true - override def toString: String = s"$child[$ordinal]" -} - -/** - * Returns the field at `ordinal` in the Array `child` - */ -case class GetArrayItem(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType @@ -227,10 +231,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `ordinal` in Map `child` + * Returns the value of key `ordinal` in Map `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ -case class GetMapValue(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { +case class GetMapValue(child: Expression, key: Expression) + extends BinaryExpression with ExtractValue { + + override def toString: String = s"$child[$key]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType From a870a82fb6f57bb63bd6f1e95da944a30f67519a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 9 Jul 2015 10:01:33 -0700 Subject: [PATCH 06/23] [SPARK-8926][SQL] Code review followup. I merged https://github.com/apache/spark/pull/7303 so it unblocks another PR. This addresses my own code review comment for that PR. Author: Reynold Xin Closes #7313 from rxin/adt and squashes the following commits: 7ade82b [Reynold Xin] Fixed unit tests. f8d5533 [Reynold Xin] [SPARK-8926][SQL] Code review followup. --- .../catalyst/expressions/ExpectsInputTypes.scala | 4 ++-- .../spark/sql/types/AbstractDataType.scala | 16 ++++++++++++++++ .../catalyst/analysis/AnalysisErrorSuite.scala | 8 ++++---- .../analysis/HiveTypeCoercionSuite.scala | 1 + 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 986cc09499d1f..3eb0eb195c80d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -39,8 +39,8 @@ trait ExpectsInputTypes { self: Expression => override def checkInputDataTypes(): TypeCheckResult = { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => - s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, ${child.prettyString} is of type ${child.dataType.simpleString}." + s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index ad75fa2e31d90..32f87440b4e37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -36,12 +36,28 @@ private[sql] abstract class AbstractDataType { /** * Returns true if this data type is the same type as `other`. 
This is different that equality * as equality will also consider data type parametrization, such as decimal precision. + * + * {{{ + * // this should return true + * DecimalType.isSameType(DecimalType(10, 2)) + * + * // this should return false + * NumericType.isSameType(DecimalType(10, 2)) + * }}} */ private[sql] def isSameType(other: DataType): Boolean /** * Returns true if `other` is an acceptable input type for a function that expectes this, * possibly abstract, DataType. + * + * {{{ + * // this should return true + * DecimalType.isSameType(DecimalType(10, 2)) + * + * // this should return true as well + * NumericType.acceptsType(DecimalType(10, 2)) + * }}} */ private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 73236c3acbca2..9d0c69a2451d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -58,7 +58,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { } } - errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) + errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase))) } } @@ -68,21 +68,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" :: - "null is of type date" ::Nil) + "'null' is of type date" ::Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" :: - "null is of type date" ::Nil) + "'null' is of type date" ::Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "expected to be of type int" :: "null is of type date" ::Nil) + "expected to be of type int" :: "'null' is of type date" ::Nil) errorTest( "unresolved window function", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 6e3aa0eebeb15..acb9a433de903 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -79,6 +79,7 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) shouldCast(StringType, NumericType, DoubleType) + shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType) // NumericType should not be changed when function accepts any of them. 
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, From f6c0bd5c3755b2f9bab633a5d478240fdaf1c593 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Jul 2015 10:04:42 -0700 Subject: [PATCH 07/23] [SPARK-8938][SQL] Implement toString for Interval data type Author: Wenchen Fan Closes #7315 from cloud-fan/toString and squashes the following commits: 4fc8d80 [Wenchen Fan] Implement toString for Interval data type --- .../apache/spark/sql/catalyst/SqlParser.scala | 24 ++++++-- .../apache/spark/unsafe/types/Interval.java | 42 +++++++++++++ .../spark/unsafe/types/IntervalSuite.java | 59 +++++++++++++++++++ 3 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index dedd8c8fa3620..d4ef04c2294a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -353,22 +353,34 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong } protected lazy val millisecond: Parser[Long] = - integral <~ intervalUnit("millisecond") ^^ { case num => num.toLong * 1000 } + integral <~ intervalUnit("millisecond") ^^ { + case num => num.toLong * Interval.MICROS_PER_MILLI + } protected lazy val second: Parser[Long] = - integral <~ intervalUnit("second") ^^ { case num => num.toLong * 1000 * 1000 } + integral <~ intervalUnit("second") ^^ { + case num => num.toLong * Interval.MICROS_PER_SECOND + } protected lazy val minute: Parser[Long] = - integral <~ intervalUnit("minute") ^^ { case num => num.toLong * 1000 * 1000 * 60 } + integral <~ intervalUnit("minute") ^^ { + case num => num.toLong * Interval.MICROS_PER_MINUTE + } protected lazy val hour: Parser[Long] = - integral <~ intervalUnit("hour") ^^ { case num => num.toLong * 1000 * 1000 * 3600 } + integral <~ intervalUnit("hour") ^^ { + case num => num.toLong * Interval.MICROS_PER_HOUR + } protected lazy val day: Parser[Long] = - integral <~ intervalUnit("day") ^^ { case num => num.toLong * 1000 * 1000 * 3600 * 24 } + integral <~ intervalUnit("day") ^^ { + case num => num.toLong * Interval.MICROS_PER_DAY + } protected lazy val week: Parser[Long] = - integral <~ intervalUnit("week") ^^ { case num => num.toLong * 1000 * 1000 * 3600 * 24 * 7 } + integral <~ intervalUnit("week") ^^ { + case num => num.toLong * Interval.MICROS_PER_WEEK + } protected lazy val intervalLiteral: Parser[Literal] = INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index 3eb67ede062d9..0af982d4844c2 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -23,6 +23,13 @@ * The internal representation of interval type. 
*/ public final class Interval implements Serializable { + public static final long MICROS_PER_MILLI = 1000L; + public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; + public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; + public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60; + public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; + public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; + public final int months; public final long microseconds; @@ -44,4 +51,39 @@ public boolean equals(Object other) { public int hashCode() { return 31 * months + (int) microseconds; } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("interval"); + + if (months != 0) { + appendUnit(sb, months / 12, "year"); + appendUnit(sb, months % 12, "month"); + } + + if (microseconds != 0) { + long rest = microseconds; + appendUnit(sb, rest / MICROS_PER_WEEK, "week"); + rest %= MICROS_PER_WEEK; + appendUnit(sb, rest / MICROS_PER_DAY, "day"); + rest %= MICROS_PER_DAY; + appendUnit(sb, rest / MICROS_PER_HOUR, "hour"); + rest %= MICROS_PER_HOUR; + appendUnit(sb, rest / MICROS_PER_MINUTE, "minute"); + rest %= MICROS_PER_MINUTE; + appendUnit(sb, rest / MICROS_PER_SECOND, "second"); + rest %= MICROS_PER_SECOND; + appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); + rest %= MICROS_PER_MILLI; + appendUnit(sb, rest, "microsecond"); + } + + return sb.toString(); + } + + private void appendUnit(StringBuilder sb, long value, String unit) { + if (value != 0) { + sb.append(" " + value + " " + unit + "s"); + } + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java new file mode 100644 index 0000000000000..0f4f38b2b03be --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -0,0 +1,59 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.unsafe.types; + +import org.junit.Test; + +import static junit.framework.Assert.*; +import static org.apache.spark.unsafe.types.Interval.*; + +public class IntervalSuite { + + @Test + public void equalsTest() { + Interval i1 = new Interval(3, 123); + Interval i2 = new Interval(3, 321); + Interval i3 = new Interval(1, 123); + Interval i4 = new Interval(3, 123); + + assertNotSame(i1, i2); + assertNotSame(i1, i3); + assertNotSame(i2, i3); + assertEquals(i1, i4); + } + + @Test + public void toStringTest() { + Interval i; + + i = new Interval(34, 0); + assertEquals(i.toString(), "interval 2 years 10 months"); + + i = new Interval(-34, 0); + assertEquals(i.toString(), "interval -2 years -10 months"); + + i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); + + i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); + + i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); + } +} From c59e268d17cf10e46dbdbe760e2a7580a6364692 Mon Sep 17 00:00:00 2001 From: JPark Date: Thu, 9 Jul 2015 10:23:36 -0700 Subject: [PATCH 08/23] [SPARK-8863] [EC2] Check aws access key from aws credentials if there is no boto config 'spark_ec2.py' use boto to control ec2. And boto can support '~/.aws/credentials' which is AWS CLI default configuration file. We can check this information from ref of boto. "A boto config file is a text file formatted like an .ini configuration file that specifies values for options that control the behavior of the boto library. In Unix/Linux systems, on startup, the boto library looks for configuration files in the following locations and in the following order: /etc/boto.cfg - for site-wide settings that all users on this machine will use (if profile is given) ~/.aws/credentials - for credentials shared between SDKs (if profile is given) ~/.boto - for user-specific settings ~/.aws/credentials - for credentials shared between SDKs ~/.boto - for user-specific settings" * ref of boto: http://boto.readthedocs.org/en/latest/boto_config_tut.html * ref of aws cli : http://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html However 'spark_ec2.py' only check boto config & environment variable even if there is '~/.aws/credentials', and 'spark_ec2.py' is terminated. So I changed to check '~/.aws/credentials'. 
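For reference, the shared credentials file that boto and the AWS CLI read is a small INI-style file; a minimal example with placeholder values (not real keys):

    [default]
    aws_access_key_id = <your-access-key-id>
    aws_secret_access_key = <your-secret-access-key>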
cc rxin Jira : https://issues.apache.org/jira/browse/SPARK-8863 Author: JPark Closes #7252 from JuhongPark/master and squashes the following commits: 23c5792 [JPark] Check aws access key from aws credentials if there is no boto config --- ec2/spark_ec2.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index dd0c12d25980b..ae4f2ecc5bde7 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -325,14 +325,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) From 0cd84c86cac68600a74d84e50ad40c0c8b84822a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 9 Jul 2015 10:26:38 -0700 Subject: [PATCH 09/23] [SPARK-8703] [ML] Add CountVectorizer as a ml transformer to convert document to words count vector jira: https://issues.apache.org/jira/browse/SPARK-8703 Converts a text document to a sparse vector of token counts. I can further add an estimator to extract vocabulary from corpus if that's appropriate. Author: Yuhao Yang Closes #7084 from hhbyyh/countVectorization and squashes the following commits: 5f3f655 [Yuhao Yang] text change 24728e4 [Yuhao Yang] style improvement 576728a [Yuhao Yang] rename to model and some fix 1deca28 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization 99b0c14 [Yuhao Yang] undo extension from HashingTF 12c2dc8 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization 7ee1c31 [Yuhao Yang] extends HashingTF 809fb59 [Yuhao Yang] minor fix for ut 7c61fb3 [Yuhao Yang] add countVectorizer --- .../ml/feature/CountVectorizerModel.scala | 82 +++++++++++++++++++ .../ml/feature/CountVectorizorSuite.scala | 73 +++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala new file mode 100644 index 0000000000000..6b77de89a0330 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. + */ +@Experimental +class CountVectorizerModel (override val uid: String, val vocabulary: Array[String]) + extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { + + def this(vocabulary: Array[String]) = + this(Identifiable.randomUID("cntVec"), vocabulary) + + /** + * Corpus-specific filter to ignore scarce words in a document. For each document, terms with + * frequency (count) less than the given threshold are ignored. + * Default: 1 + * @group param + */ + val minTermFreq: IntParam = new IntParam(this, "minTermFreq", + "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + + "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) + + /** @group setParam */ + def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + + /** @group getParam */ + def getMinTermFreq: Int = $(minTermFreq) + + setDefault(minTermFreq -> 1) + + override protected def createTransformFunc: Seq[String] => Vector = { + val dict = vocabulary.zipWithIndex.toMap + document => + val termCounts = mutable.HashMap.empty[Int, Double] + document.foreach { term => + dict.get(term) match { + case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) + case None => // ignore terms not in the vocabulary + } + } + Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): CountVectorizerModel = { + val copied = new CountVectorizerModel(uid, vocabulary) + copyValues(copied, extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala new file mode 100644 index 0000000000000..e90d9d4ef21ff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) + } + + test("CountVectorizerModel common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, "a b c d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, "a b b c d a".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), + (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string + (4, "a notInDict d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary + )).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + val output = cv.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizerModel with minTermFreq") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), + (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTermFreq(3) + val output = cv.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } +} + + From 0b0b9ceaf73de472198c9804fb7ae61fa2a2e097 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 9 Jul 2015 11:11:34 -0700 Subject: [PATCH 10/23] [SPARK-8247] [SPARK-8249] [SPARK-8252] [SPARK-8254] [SPARK-8257] [SPARK-8258] [SPARK-8259] [SPARK-8261] [SPARK-8262] [SPARK-8253] [SPARK-8260] [SPARK-8267] [SQL] Add String Expressions Author: Cheng Hao Closes #6762 from chenghao-intel/str_funcs and squashes the following commits: b09a909 [Cheng Hao] update the code as feedback 7ebbf4c [Cheng Hao] Add more string expressions --- .../catalyst/analysis/FunctionRegistry.scala | 12 + .../expressions/stringOperations.scala | 306 ++++++++++++++- .../expressions/StringFunctionsSuite.scala | 138 +++++++ .../org/apache/spark/sql/functions.scala | 353 ++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 132 ++++++- 
.../apache/spark/unsafe/types/UTF8String.java | 191 ++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 94 ++++- 7 files changed, 1202 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5c25181e1cf50..f62d79f8cea6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -147,12 +147,24 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), + expression[StringInstr]("instr"), expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), expression[Levenshtein]("levenshtein"), + expression[StringLocate]("locate"), + expression[StringLPad]("lpad"), + expression[StringTrimLeft]("ltrim"), + expression[StringFormat]("printf"), + expression[StringRPad]("rpad"), + expression[StringRepeat]("repeat"), + expression[StringReverse]("reverse"), + expression[StringTrimRight]("rtrim"), + expression[StringSpace]("space"), + expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), expression[Unhex]("unhex"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 57f436485becf..f64899c1ed84c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.Pattern import org.apache.commons.lang3.StringUtils @@ -104,7 +105,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait CaseConversionExpression extends ExpectsInputTypes { +trait String2StringExpression extends ExpectsInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -119,7 +120,7 @@ trait CaseConversionExpression extends ExpectsInputTypes { /** * A function that converts the characters of a string to uppercase. */ -case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -131,7 +132,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE /** * A function that converts the characters of a string to lowercase. */ -case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -187,6 +188,301 @@ case class EndsWith(left: Expression, right: Expression) } } +/** + * A function that trim the spaces from both ends for the specified string. 
+ */ +case class StringTrim(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trim() + + override def prettyName: String = "trim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trim()") + } +} + +/** + * A function that trim the spaces from left end for given string. + */ +case class StringTrimLeft(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimLeft() + + override def prettyName: String = "ltrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + } +} + +/** + * A function that trim the spaces from right end for given string. + */ +case class StringTrimRight(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimRight() + + override def prettyName: String = "rtrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimRight()") + } +} + +/** + * A function that returns the position of the first occurrence of substr in the given string. + * Returns null if either of the arguments are null and + * returns 0 if substr could not be found in str. + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. + */ +case class StringInstr(str: Expression, substr: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = substr + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, sub: Any): Any = { + string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 + } + + override def prettyName: String = "instr" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => + s"($l).indexOf($r, 0) + 1") + } +} + +/** + * A function that returns the position of the first occurrence of substr + * in given string after position pos. + */ +case class StringLocate(substr: Expression, str: Expression, start: Expression) + extends Expression with ExpectsInputTypes { + + def this(substr: Expression, str: Expression) = { + this(substr, str, Literal(0)) + } + + override def children: Seq[Expression] = substr :: str :: start :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = substr.nullable || str.nullable + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + + override def eval(input: InternalRow): Any = { + val s = start.eval(input) + if (s == null) { + // if the start position is null, we need to return 0, (conform to Hive) + 0 + } else { + val r = substr.eval(input) + if (r == null) { + null + } else { + val l = str.eval(input) + if (l == null) { + null + } else { + l.asInstanceOf[UTF8String].indexOf( + r.asInstanceOf[UTF8String], + s.asInstanceOf[Int]) + 1 + } + } + } + } + + override def prettyName: String = "locate" +} + +/** + * Returns str, left-padded with pad to a length of len. 
+ */ +case class StringLPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.lpad(len, pad) + } + } + } + } + + override def prettyName: String = "lpad" +} + +/** + * Returns str, right-padded with pad to a length of len. + */ +case class StringRPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.rpad(len, pad) + } + } + } + } + + override def prettyName: String = "rpad" +} + +/** + * Returns the input formatted according do printf-style format strings + */ +case class StringFormat(children: Expression*) extends Expression { + + require(children.length >=1, "printf() should take at least 1 argument") + + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children(0).nullable + override def dataType: DataType = StringType + private def format: Expression = children(0) + private def args: Seq[Expression] = children.tail + + override def eval(input: InternalRow): Any = { + val pattern = format.eval(input) + if (pattern == null) { + null + } else { + val sb = new StringBuffer() + val formatter = new java.util.Formatter(sb, Locale.US) + + val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) + formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*) + + UTF8String.fromString(sb.toString) + } + } + + override def prettyName: String = "printf" +} + +/** + * Returns the string which repeat the given string value n times. 
+ */ +case class StringRepeat(str: Expression, times: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = times + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType) + + override def nullSafeEval(string: Any, n: Any): Any = { + string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) + } + + override def prettyName: String = "repeat" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") + } +} + +/** + * Returns the reversed given string. + */ +case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { + override def convert(v: UTF8String): UTF8String = v.reverse() + + override def prettyName: String = "reverse" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).reverse()") + } +} + +/** + * Returns a n spaces string. + */ +case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def nullSafeEval(s: Any): Any = { + val length = s.asInstanceOf[Integer] + + val spaces = new Array[Byte](if (length < 0) 0 else length) + java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte]) + UTF8String.fromBytes(spaces) + } + + override def prettyName: String = "space" +} + +/** + * Splits str around pat (pattern is a regular expression). + */ +case class StringSplit(str: Expression, pattern: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = pattern + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, regex: Any): Any = { + val splits = + string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1) + splits.toSeq.map(UTF8String.fromString) + } + + override def prettyName: String = "split" +} + /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. 
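One detail in the StringSplit hunk above worth spelling out: the explicit -1 limit passed to Java's String.split keeps trailing empty strings, which the default overload silently drops. A quick illustration (not part of the patch):

    "a,b,,".split(",")      // Array(a, b)        -- trailing empty strings dropped
    "a,b,,".split(",", -1)  // Array(a, b, "", "") -- trailing empty strings kept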
@@ -199,8 +495,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable override def dataType: DataType = { if (!resolved) { @@ -373,4 +668,3 @@ case class Encode(value: Expression, charset: Expression) } } - diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 69bef1c63e9dc..b19f4ee37a109 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -288,4 +288,142 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) // scalastyle:on } + + test("TRIM/LTRIM/RTRIM") { + val s = 'a.string.at(0) + checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) + + checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) + + checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + // scalastyle:on + } + + test("FORMAT") { + val f = 'f.string.at(0) + val d1 = 'd.int.at(1) + val s1 = 's.int.at(2) + + val row1 = create_row("aa%d%s", 12, "cc") + val row2 = create_row(null, 12, "cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + + checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) + checkEvaluation(StringFormat(f, d1, s1), null, row2) + } + + test("INSTR") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("aaads", "aa", "zz") + + checkEvaluation(StringInstr(Literal("aaads"), Literal("aa")), 1, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal("de")), 0, row1) + checkEvaluation(StringInstr(Literal.create(null, StringType), Literal("de")), null, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal.create(null, StringType)), null, row1) + + checkEvaluation(StringInstr(s1, s2), 1, row1) + checkEvaluation(StringInstr(s1, s3), 0, row1) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
+ checkEvaluation(StringInstr(s1, s2), 3, create_row("花花世界", "世界")) + checkEvaluation(StringInstr(s1, s2), 1, create_row("花花世界", "花")) + checkEvaluation(StringInstr(s1, s2), 0, create_row("花花世界", "小")) + // scalastyle:on + } + + test("LOCATE") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val s4 = 'd.int.at(3) + val row1 = create_row("aaads", "aa", "zz", 1) + + checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1) + checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1) + checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1) + + checkEvaluation(new StringLocate(s2, s1), 1, row1) + checkEvaluation(StringLocate(s2, s1, s4), 2, row1) + checkEvaluation(new StringLocate(s3, s1), 0, row1) + checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1) + } + + test("LPAD/RPAD") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("hi", 5, "??") + val row2 = create_row("hi", 1, "?") + val row3 = create_row(null, 1, "?") + + checkEvaluation(StringLPad(Literal("hi"), Literal(5), Literal("??")), "???hi", row1) + checkEvaluation(StringLPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringLPad(s1, s2, s3), "???hi", row1) + checkEvaluation(StringLPad(s1, s2, s3), "h", row2) + checkEvaluation(StringLPad(s1, s2, s3), null, row3) + + checkEvaluation(StringRPad(Literal("hi"), Literal(5), Literal("??")), "hi???", row1) + checkEvaluation(StringRPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringRPad(s1, s2, s3), "hi???", row1) + checkEvaluation(StringRPad(s1, s2, s3), "h", row2) + checkEvaluation(StringRPad(s1, s2, s3), null, row3) + } + + test("REPEAT") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val row1 = create_row("hi", 2) + val row2 = create_row(null, 1) + + checkEvaluation(StringRepeat(Literal("hi"), Literal(2)), "hihi", row1) + checkEvaluation(StringRepeat(Literal("hi"), Literal(-1)), "", row1) + checkEvaluation(StringRepeat(s1, s2), "hihi", row1) + checkEvaluation(StringRepeat(s1, s2), null, row2) + } + + test("REVERSE") { + val s = 'a.string.at(0) + val row1 = create_row("abccc") + checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) + checkEvaluation(StringReverse(s), "cccba", row1) + } + + test("SPACE") { + val s1 = 'b.int.at(0) + val row1 = create_row(2) + val row2 = create_row(null) + + checkEvaluation(StringSpace(Literal(2)), " ", row1) + checkEvaluation(StringSpace(Literal(-1)), "", row1) + checkEvaluation(StringSpace(Literal(0)), "", row1) + checkEvaluation(StringSpace(s1), " ", row1) + checkEvaluation(StringSpace(s1), null, row2) + } + + test("SPLIT") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val row1 = create_row("aa2bb3cc", "[1-9]+") + + checkEvaluation( + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + checkEvaluation( + StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4da9ffc495e17..08bf37a5c223c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1626,6 +1626,179 @@ object 
functions { */ def ascii(columnName: String): Column = ascii(Column(columnName)) + /** + * Trim the spaces from both ends for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(e: Column): Column = StringTrim(e.expr) + + /** + * Trim the spaces from both ends for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(columnName: String): Column = trim(Column(columnName)) + + /** + * Trim the spaces from left end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(e: Column): Column = StringTrimLeft(e.expr) + + /** + * Trim the spaces from left end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(columnName: String): Column = ltrim(Column(columnName)) + + /** + * Trim the spaces from right end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(e: Column): Column = StringTrimRight(e.expr) + + /** + * Trim the spaces from right end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(columnName: String): Column = rtrim(Column(columnName)) + + /** + * Format strings in printf-style. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: Column, arguments: Column*): Column = { + StringFormat((format +: arguments).map(_.expr): _*) + } + + /** + * Format strings in printf-style. + * NOTE: `format` is the string value of the formatter, not column name. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: String, arguNames: String*): Column = { + StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) + } + + /** + * Locate the position of the first occurrence of substr value in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub)) + + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr) + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String): Column = { + locate(Column(substr), Column(str)) + } + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column): Column = { + new StringLocate(substr.expr, str.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: String): Column = { + locate(Column(substr), Column(str), Column(pos)) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Column): Column = { + StringLocate(substr.expr, str.expr, pos.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Int): Column = { + StringLocate(substr.expr, str.expr, lit(pos).expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: Int): Column = { + locate(Column(substr), Column(str), lit(pos)) + } + /** * Computes the specified value from binary to a base64 string. * @@ -1658,6 +1831,46 @@ object functions { */ def unbase64(columnName: String): Column = unbase64(Column(columnName)) + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: String, pad: String): Column = { + lpad(Column(str), Column(len), Column(pad)) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Column, pad: Column): Column = { + StringLPad(str.expr, len.expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Int, pad: Column): Column = { + StringLPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: Int, pad: String): Column = { + lpad(Column(str), len, Column(pad)) + } + /** * Computes the first argument into a binary from a string using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). @@ -1702,6 +1915,146 @@ object functions { def decode(columnName: String, charset: String): Column = decode(Column(columnName), charset) + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: String, pad: String): Column = { + rpad(Column(str), Column(len), Column(pad)) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Column, pad: Column): Column = { + StringRPad(str.expr, len.expr, pad.expr) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: Int, pad: String): Column = { + rpad(Column(str), len, Column(pad)) + } + + /** + * Right-padded with pad to a length of len. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Int, pad: Column): Column = { + StringRPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, timesColumn: String): Column = { + repeat(Column(strColumn), Column(timesColumn)) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Column): Column = { + StringRepeat(str.expr, times.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, times: Int): Column = { + repeat(Column(strColumn), times) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Int): Column = { + StringRepeat(str.expr, lit(times).expr) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * + * @group string_funcs + * @since 1.5.0 + */ + def split(strColumnName: String, pattern: String): Column = { + split(Column(strColumnName), pattern) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * NOTE: pattern is a string represent the regular expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def split(str: Column, pattern: String): Column = { + StringSplit(str.expr, lit(pattern).expr) + } + + /** + * Reversed the string for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: String): Column = { + reverse(Column(str)) + } + + /** + * Reversed the string for the specified value. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: Column): Column = { + StringReverse(str.expr) + } + + /** + * Make a n spaces of string. + * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: String): Column = { + space(Column(n)) + } + + /** + * Make a n spaces of string. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: Column): Column = { + StringSpace(n.expr) + } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index afba28515e032..173280375c411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -209,21 +209,14 @@ class DataFrameFunctionsSuite extends QueryTest { } test("string length function") { + val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( - nullStrings.select(strlen($"s"), strlen("s")), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l, l) - }) + df.select(strlen($"a"), strlen("b")), + Row(3, 0)) checkAnswer( - nullStrings.selectExpr("length(s)"), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l) - }) + df.selectExpr("length(a)", "length(b)"), + Row(3, 0)) } test("Levenshtein distance") { @@ -273,4 +266,119 @@ class DataFrameFunctionsSuite extends QueryTest { Row(bytes, "大千世界")) // scalastyle:on } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", $"b"), instr("a", "b")), + Row(1, 1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select( + locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), + locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), + Row(1, 1, 2, 2, 2, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select( + lpad($"a", $"b", $"c"), rpad("a", "b", "c"), + lpad($"a", 1, $"c"), rpad("a", 1, "c")), + Row("???hi", "hi???", "h", "h")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select( + repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), + Row("hihi", "hihi", "hihi", "hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse("b")), + Row("ih", "ihhh")) + + checkAnswer( + 
df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + checkAnswer( + df.select(space($"a"), space("b")), + Row(" ", " ")) + + checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select( + split($"a", "[1-9]+"), + split("a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 847d80ad583f6..60d050b0a0c97 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -25,6 +25,7 @@ import static org.apache.spark.unsafe.PlatformDependent.*; + /** * A UTF-8 String for internal Spark use. *

@@ -204,6 +205,196 @@ public UTF8String toLowerCase() { return fromString(toString().toLowerCase()); } + /** + * Copy the bytes from the current UTF8String, and make a new UTF8String. + * @param start the start position of the current UTF8String in bytes. + * @param end the end position of the current UTF8String in bytes. + * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. + */ + private UTF8String copyUTF8String(int start, int end) { + int len = end - start + 1; + byte[] newBytes = new byte[len]; + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + return UTF8String.fromBytes(newBytes); + } + + public UTF8String trim() { + int s = 0; + int e = this.numBytes - 1; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (s > e) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, e); + } + } + + public UTF8String trimLeft() { + int s = 0; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + if (s == this.numBytes) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + + public UTF8String trimRight() { + int e = numBytes - 1; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (e < 0) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(0, e); + } + } + + public UTF8String reverse() { + byte[] bytes = getBytes(); + byte[] result = new byte[bytes.length]; + + int i = 0; // position in byte + while (i < numBytes) { + int len = numBytesForFirstByte(getByte(i)); + System.arraycopy(bytes, i, result, result.length - i - len, len); + + i += len; + } + + return UTF8String.fromBytes(result); + } + + public UTF8String repeat(int times) { + if (times <=0) { + return fromBytes(new byte[0]); + } + + byte[] newBytes = new byte[numBytes * times]; + System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + + int copied = 1; + while (copied < times) { + int toCopy = Math.min(copied, times - copied); + System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy); + copied += toCopy; + } + + return UTF8String.fromBytes(newBytes); + } + + /** + * Returns the position of the first occurrence of substr in + * current string from the specified position (0-based index). + * + * @param v the string to be searched + * @param start the start position of the current string for searching + * @return the position of the first occurrence of substr, if not found, -1 returned. + */ + public int indexOf(UTF8String v, int start) { + if (v.numBytes() == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + v.numBytes > numBytes) { + return -1; + } + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while(i < numBytes); + + return -1; + } + + /** + * Returns str, right-padded with pad to a length of len + * For example: + * ('hi', 5, '??') => 'hi???' 
+ * ('hi', 1, '??') => 'h' + */ + public UTF8String rpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + int offset = this.numBytes; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + + return UTF8String.fromBytes(data); + } + } + + /** + * Returns str, left-padded with pad to a length of len. + * For example: + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' + */ + public UTF8String lpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + + int offset = 0; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + offset += remain.numBytes; + System.arraycopy(getBytes(), 0, data, offset, numBytes()); + + return UTF8String.fromBytes(data); + } + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index fb463ba17f50b..694bdc29f39d1 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -121,12 +121,94 @@ public void endsWith() { @Test public void substring() { - assertEquals(fromString("hello").substring(0, 0), fromString("")); - assertEquals(fromString("hello").substring(1, 3), fromString("el")); - assertEquals(fromString("数据砖头").substring(0, 1), fromString("数")); - assertEquals(fromString("数据砖头").substring(1, 3), fromString("据砖")); - assertEquals(fromString("数据砖头").substring(3, 5), fromString("头")); - assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷")); + assertEquals(fromString(""), fromString("hello").substring(0, 0)); + assertEquals(fromString("el"), fromString("hello").substring(1, 3)); + assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); + assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + } + + @Test + public void trims() { + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + 
assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + + assertEquals(fromString(""), fromString(" ").trim()); + assertEquals(fromString(""), fromString(" ").trimLeft()); + assertEquals(fromString(""), fromString(" ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + } + + @Test + public void indexOf() { + assertEquals(0, fromString("").indexOf(fromString(""), 0)); + assertEquals(-1, fromString("").indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(fromString(""), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0)); + assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4)); + assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0)); + assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); + } + + @Test + public void reverse() { + assertEquals(fromString("olleh"), fromString("hello").reverse()); + assertEquals(fromString(""), fromString("").reverse()); + assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); + assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + } + + @Test + public void repeat() { + assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); + assertEquals(fromString("数d"), fromString("数d").repeat(1)); + assertEquals(fromString(""), fromString("数d").repeat(-1)); + } + + @Test + public void pad() { + assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????"))); + assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); + assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); + assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????"))); + + assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); + assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); + assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); + assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????"))); + + + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); + assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????"))); + assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); + assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); + assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, 
fromString("孙行者"))); + assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); + assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); + assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); + assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); } @Test From 7ce3b818fb1ba3f291eda58988e4808e999cae3a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 9 Jul 2015 13:19:36 -0700 Subject: [PATCH 11/23] [MINOR] [STREAMING] Fix log statements in ReceiverSupervisorImpl Log statements incorrectly showed that the executor was being stopped when receiver was being stopped. Author: Tathagata Das Closes #7328 from tdas/fix-log and squashes the following commits: 9cc6e99 [Tathagata Das] Fix log statements. --- .../spark/streaming/receiver/ReceiverSupervisor.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 33be067ebdaf2..eeb14ca3a49e9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -182,12 +182,12 @@ private[streaming] abstract class ReceiverSupervisor( /** Wait the thread until the supervisor is stopped */ def awaitTermination() { + logInfo("Waiting for receiver to be stopped") stopLatch.await() - logInfo("Waiting for executor stop is over") if (stoppingError != null) { - logError("Stopped executor with error: " + stoppingError) + logError("Stopped receiver with error: " + stoppingError) } else { - logWarning("Stopped executor without error") + logInfo("Stopped receiver without error") } if (stoppingError != null) { throw stoppingError From 930fe95350f8865e2af2d7afa5b717210933cd43 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Thu, 9 Jul 2015 13:21:10 -0700 Subject: [PATCH 12/23] [SPARK-8953] SPARK_EXECUTOR_CORES is not read in SparkSubmit The configuration ```SPARK_EXECUTOR_CORES``` won't put into ```SparkConf```, so it has no effect to the dynamic executor allocation. 
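In other words, the patch below extends the existing fallback chain in SparkSubmitArguments so the environment variable is also consulted. A minimal sketch of the resulting precedence follows; the resolve helper and its parameters are made up for clarity and are not the real SparkSubmit code. An explicit --executor-cores value wins, then spark.executor.cores from the properties, then SPARK_EXECUTOR_CORES from the environment.

```scala
// Illustrative sketch of the resolution order after this patch.
object ExecutorCoresResolution {
  def resolve(
      cliValue: Option[String],
      sparkProperties: Map[String, String],
      env: Map[String, String]): Option[String] = {
    cliValue
      .orElse(sparkProperties.get("spark.executor.cores"))
      .orElse(env.get("SPARK_EXECUTOR_CORES"))
  }

  def main(args: Array[String]): Unit = {
    val env = Map("SPARK_EXECUTOR_CORES" -> "4")
    // Without the new env fallback the first call would yield None, so dynamic
    // allocation would never see the requested core count.
    assert(resolve(None, Map.empty, env) == Some("4"))
    assert(resolve(None, Map("spark.executor.cores" -> "2"), env) == Some("2"))
    assert(resolve(Some("8"), Map("spark.executor.cores" -> "2"), env) == Some("8"))
  }
}
```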
Author: xutingjun Closes #7322 from XuTingjun/SPARK_EXECUTOR_CORES and squashes the following commits: 2cafa89 [xutingjun] make SPARK_EXECUTOR_CORES has effect to dynamicAllocation --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 73ab18332feb4..6e3c0b21b33c2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -162,6 +162,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) From 88bf430331eef3c02438ca441616034486e15789 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Jul 2015 13:22:17 -0700 Subject: [PATCH 13/23] [SPARK-7419] [STREAMING] [TESTS] Fix CheckpointSuite.recovery with file input stream Fix this failure: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/2886/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.3,label=centos/testReport/junit/org.apache.spark.streaming/CheckpointSuite/recovery_with_file_input_stream/ To reproduce this failure, you can add `Thread.sleep(2000)` before this line https://github.com/apache/spark/blob/a9c4e29950a14e32acaac547e9a0e8879fd37fc9/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala#L477 Author: zsxwing Closes #7323 from zsxwing/SPARK-7419 and squashes the following commits: b3caf58 [zsxwing] Fix CheckpointSuite.recovery with file input stream --- .../spark/streaming/CheckpointSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d06..6a94928076236 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -424,11 +424,11 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } + clock.advance(batchDuration.milliseconds) // Wait for a checkpoint to be written eventually(eventuallyTimeout) { assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) @@ -454,9 +454,12 @@ class CheckpointSuite extends TestSuiteBase { // recorded before failure were saved and successfully recovered logInfo("*********** RESTARTING ************") withStreamingContext(new StreamingContext(checkpointDir)) { ssc => - // So that the restarted StreamingContext's clock has gone forward in time since failure - ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.getTimeMillis() + // "batchDuration.milliseconds * 3" has gone before restarting StreamingContext. 
And because + // the recovery time is read from the checkpoint time but the original clock doesn't align + // with the batch time, we need to add the offset "batchDuration.milliseconds / 2". + ssc.conf.set("spark.streaming.manualClock.jump", + (batchDuration.milliseconds / 2 + batchDuration.milliseconds * 3).toString) + val oldClockTime = clock.getTimeMillis() // 15000ms clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -467,10 +470,10 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.getTimeMillis() === oldClockTime + assert(clock.getTimeMillis() === oldClockTime) } - // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) - val numBatchesAfterRestart = 4 + // There are 5 batches between 6000ms and 15000ms (inclusive). + val numBatchesAfterRestart = 5 eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === numBatchesAfterRestart) } @@ -483,7 +486,6 @@ class CheckpointSuite extends TestSuiteBase { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.advance(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() From ebdf58538058e57381c04b6725d4be0c37847ed3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 9 Jul 2015 13:25:11 -0700 Subject: [PATCH 14/23] [SPARK-2017] [UI] Stage page hangs with many tasks (This reopens a patch that was closed in the past: #6248) When you view the stage page while running the following: ``` sc.parallelize(1 to X, 10000).count() ``` The page never loads, the job is stalled, and you end up running into an OOM: ``` HTTP ERROR 500 Problem accessing /stages/stage/. Reason: Server Error Caused by: java.lang.OutOfMemoryError: Java heap space at java.util.Arrays.copyOf(Arrays.java:2367) at java.lang.AbstractStringBuilder.expandCapacity(AbstractStringBuilder.java:130) ``` This patch compresses Jetty responses in gzip. The correct long-term fix is to add pagination. 
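The change below wraps every UI handler in a GzipHandler before registering it with the handler collection. A condensed sketch of that pattern, assuming Jetty's org.eclipse.jetty.server.handler.GzipHandler API as used by the Spark UI:

```scala
// Condensed sketch of the JettyUtils change: each handler is wrapped in a
// GzipHandler so large UI pages are compressed on the wire.
import org.eclipse.jetty.server.Handler
import org.eclipse.jetty.server.handler.{ContextHandlerCollection, GzipHandler}

object GzipUiSketch {
  def gzipped(handlers: Seq[Handler]): ContextHandlerCollection = {
    val collection = new ContextHandlerCollection
    val gzipHandlers = handlers.map { h =>
      val gzipHandler = new GzipHandler  // transparently gzips the wrapped handler's responses
      gzipHandler.setHandler(h)
      gzipHandler
    }
    collection.setHandlers(gzipHandlers.toArray)
    collection
  }
}
```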
Author: Andrew Or Closes #7296 from andrewor14/gzip-jetty and squashes the following commits: a051c64 [Andrew Or] Use GZIP to compress Jetty responses --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 06e616220c706..f413c1d37fbb6 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -210,10 +210,16 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) From c4830598b271cc6390d127bd4cf8ab02b28792e0 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Thu, 9 Jul 2015 13:26:46 -0700 Subject: [PATCH 15/23] [SPARK-6287] [MESOS] Add dynamic allocation to the coarse-grained Mesos scheduler This is largely based on extracting the dynamic allocation parts from tnachen's #3861. Author: Iulian Dragos Closes #4984 from dragos/issue/mesos-coarse-dynamicAllocation and squashes the following commits: 39df8cd [Iulian Dragos] Update tests to latest changes in core. 9d2c9fa [Iulian Dragos] Remove adjustment of executorLimitOption in doKillExecutors. 8b00f52 [Iulian Dragos] Latest round of reviews. 0cd00e0 [Iulian Dragos] Add persistent shuffle directory 15c45c1 [Iulian Dragos] Add dynamic allocation to the Spark coarse-grained scheduler. 
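Before the diff, a sketch of the central idea may help: the coarse-grained backend now tracks a soft cap on the number of executors, which the allocation manager can raise or lower, and resource offers are accepted only while the number of launched executors is below that cap. The classes and names below are illustrative stand-ins, not the real scheduler code.

```scala
// Illustrative sketch of the executor-limit mechanism added by this patch;
// Backend is a stand-in, not CoarseMesosSchedulerBackend.
object ExecutorLimitSketch {
  final class Backend {
    // Undefined until the allocation manager requests a total.
    private var executorLimitOption: Option[Int] = None
    private var runningExecutors = 0

    def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue)

    def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
      executorLimitOption = Some(requestedTotal)
      true
    }

    /** Accept a resource offer only while we are under the current cap. */
    def acceptOffer(): Boolean =
      if (runningExecutors < executorLimit) { runningExecutors += 1; true } else false
  }

  def main(args: Array[String]): Unit = {
    val backend = new Backend
    assert(backend.acceptOffer())        // no cap yet, launch freely
    backend.doRequestTotalExecutors(1)   // allocation manager scales down to one executor
    assert(!backend.acceptOffer())       // further offers are declined
  }
}
```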
--- .../scala/org/apache/spark/SparkContext.scala | 19 +- .../mesos/CoarseMesosSchedulerBackend.scala | 136 +++++++++++--- .../cluster/mesos/MesosSchedulerUtils.scala | 4 +- .../spark/storage/DiskBlockManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 45 +++-- .../CoarseMesosSchedulerBackendSuite.scala | 175 ++++++++++++++++++ 6 files changed, 331 insertions(+), 56 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d2547eeff2b4e..82704b1ab2189 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -532,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _executorAllocationManager = if (dynamicAllocationEnabled) { assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") + "Dynamic allocation of executors is currently only supported in YARN and Mesos mode") Some(new ExecutorAllocationManager(this, listenerBus, _conf)) } else { None @@ -853,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -1364,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Return whether dynamically adjusting the amount of resources allocated to - * this application is supported. This is currently only available for YARN. + * this application is supported. This is currently only available for YARN + * and Mesos coarse-grained mode. 
*/ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) + private[spark] def supportDynamicAllocation: Boolean = { + (master.contains("yarn") + || master.contains("mesos") + || _conf.getBoolean("spark.dynamicAllocation.testing", false)) + } /** * :: DeveloperApi :: @@ -1385,7 +1388,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestTotalExecutors(numExecutors) @@ -1403,7 +1406,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1421,7 +1424,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { assert(supportDynamicAllocation, - "Killing executors is currently only supported in YARN mode") + "Killing executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b68f8c7685eba..cbade131494bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,11 +18,14 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} +import java.util.{List => JList, Collections} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.collect.HashBiMap +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} @@ -60,9 +63,27 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocatorManager calls [[doRequesTotalExecutors]]. 
+ */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + + private val pendingRemovedSlaveIds = new HashSet[String] + // private lock object protecting mutable state above. Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) @@ -86,7 +107,7 @@ private[spark] class CoarseMesosSchedulerBackend( startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -120,10 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) @@ -133,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend( command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + - s" --driver-url $driverUrl" + + s" --driver-url $driverURL" + s" --executor-id ${offer.getSlaveId.getValue}" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + @@ -142,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + - s" --driver-url $driverUrl" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -155,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend( command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { @@ -172,17 +201,18 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. 
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { + stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) - val slaveId = offer.getSlaveId.toString + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (meetsConstraints && + if (taskIdToSlaveId.size < executorLimit && totalCoresAcquired < maxCores && + meetsConstraints && mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && @@ -197,7 +227,7 @@ private[spark] class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", calculateTotalMemory(sc))) @@ -209,7 +239,9 @@ private[spark] class CoarseMesosSchedulerBackend( // accept the offer and launch the task logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.launchTasks(List(offer.getId), List(task.build()), filters) + d.launchTasks( + Collections.singleton(offer.getId), + Collections.singleton(task.build()), filters) } else { // Decline the offer logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") @@ -224,7 +256,7 @@ private[spark] class CoarseMesosSchedulerBackend( val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo("Mesos task " + taskId + " is now " + state) - synchronized { + stateLock.synchronized { if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId @@ -242,8 +274,9 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } + executorTerminated(d, slaveId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node - mesosDriver.reviveOffers() + d.reviveOffers() } } } @@ -262,18 +295,39 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. 
+ */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.contains(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo("Mesos slave lost: " + slaveId.getValue) + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -284,4 +338,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. + logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.contains(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. 
+ // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d8a8c848bb4d1..925702e63afd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import scala.util.control.NonFatal import com.google.common.base.Splitter -import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler} +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} import org.apache.mesos.Protos._ import org.apache.mesos.protobuf.GeneratedMessage import org.apache.spark.{Logging, SparkContext} @@ -39,7 +39,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private final val registerLatch = new CountDownLatch(1) // Driver for talking to Mesos - protected var mesosDriver: MesosSchedulerDriver = null + protected var mesosDriver: SchedulerDriver = null /** * Starts the MesosSchedulerDriver with the provided information. This method returns diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 91ef86389a0c3..5f537692a16c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. + */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") + Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 944560a91354a..b6b932104a94d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -733,7 +733,12 @@ private[spark] object Utils extends Logging { localRootDirs } - private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + /** + * Return the configured local directories where Spark can write files. This + * method does not create any directories on its own, it only encapsulates the + * logic of locating the local directories according to deployment mode. + */ + def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. 
Note this assumes that Yarn has @@ -749,27 +754,29 @@ private[spark] object Utils extends Logging { Option(conf.getenv("SPARK_LOCAL_DIRS")) .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) .split(",") - .flatMap { root => - try { - val rootDir = new File(root) - if (rootDir.exists || rootDir.mkdirs()) { - val dir = createTempDir(root) - chmod700(dir) - Some(dir.getAbsolutePath) - } else { - logError(s"Failed to create dir in $root. Ignoring this directory.") - None - } - } catch { - case e: IOException => - logError(s"Failed to create local root dir in $root. Ignoring this directory.") - None - } - } - .toArray } } + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + getConfiguredLocalDirs(conf).flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + }.toArray + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..3f1692917a357 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util +import java.util.Collections + +import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.Protos._ +import org.apache.mesos.SchedulerDriver +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.scalatest.mock.MockitoSugar +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} + +class CoarseMesosSchedulerBackendSuite extends SparkFunSuite + with LocalSparkContext + with MockitoSugar + with BeforeAndAfter { + + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder() + .setValue(offerId).build()) + .setFrameworkId(FrameworkID.newBuilder() + .setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setHostname(s"host${slaveId}") + .build() + } + + private def createSchedulerBackend( + taskScheduler: TaskSchedulerImpl, + driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { + mesosDriver = driver + markRegistered() + } + backend.start() + backend + } + + var sparkConf: SparkConf = _ + + before { + sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + + sc = new SparkContext(sparkConf) + } + + test("mesos supports killing and limiting executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) + + val taskID0 = TaskID.newBuilder().setValue("0").build() + + backend.resourceOffers(driver, mesosOffers) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + // simulate the allocation manager down-scaling executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("s1/0"))) + verify(driver, times(1)).killTask(taskID0) + + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) + backend.resourceOffers(driver, mesosOffers2) + + verify(driver, times(1)) + .declineOffer(OfferID.newBuilder().setValue("o2").build()) + + // Verify we didn't launch any new executor + assert(backend.slaveIdsWithExecutors.size === 1) + + backend.doRequestTotalExecutors(2) + backend.resourceOffers(driver, mesosOffers2) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + assert(backend.slaveIdsWithExecutors.size === 2) + backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) + assert(backend.slaveIdsWithExecutors.size === 1) + } + + 
test("mesos supports killing and relaunching tasks with executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + 1024 + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + val offer1 = createOffer("o1", "s1", minMem, minCpu) + mesosOffers.add(offer1) + + val offer2 = createOffer("o2", "s1", minMem, 1); + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer1.getId)), + anyObject(), + anyObject[Filters]) + + // Simulate task killed, executor no longer running + val status = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue("0").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setState(TaskState.TASK_KILLED) + .build + + backend.statusUpdate(driver, status) + assert(!backend.slaveIdsWithExecutors.contains("s1")) + + mesosOffers.clear() + mesosOffers.add(offer2) + backend.resourceOffers(driver, mesosOffers) + assert(backend.slaveIdsWithExecutors.contains("s1")) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer2.getId)), + anyObject(), + anyObject[Filters]) + + verify(driver, times(1)).reviveOffers() + } +} From 1f6b0b1234cc03aa2e07aea7fec2de7563885238 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Jul 2015 13:48:29 -0700 Subject: [PATCH 16/23] [SPARK-8701] [STREAMING] [WEBUI] Add input metadata in the batch page This PR adds `metadata` to `InputInfo`. `InputDStream` can report its metadata for a batch and it will be shown in the batch page. For example, ![screen shot](https://cloud.githubusercontent.com/assets/1000778/8403741/d6ffc7e2-1e79-11e5-9888-c78c1575123a.png) FileInputDStream will display the new files for a batch, and DirectKafkaInputDStream will display its offset ranges. 
Author: zsxwing Closes #7081 from zsxwing/input-metadata and squashes the following commits: f7abd9b [zsxwing] Revert the space changes in project/MimaExcludes.scala d906209 [zsxwing] Merge branch 'master' into input-metadata 74762da [zsxwing] Fix MiMa tests 7903e33 [zsxwing] Merge branch 'master' into input-metadata 450a46c [zsxwing] Address comments 1d94582 [zsxwing] Raname InputInfo to StreamInputInfo and change "metadata" to Map[String, Any] d496ae9 [zsxwing] Add input metadata in the batch page --- .../kafka/DirectKafkaInputDStream.scala | 23 ++++++++-- .../spark/streaming/kafka/OffsetRange.scala | 2 +- project/MimaExcludes.scala | 6 +++ .../streaming/dstream/FileInputDStream.scala | 10 ++++- .../dstream/ReceiverInputDStream.scala | 4 +- .../spark/streaming/scheduler/BatchInfo.scala | 9 ++-- .../scheduler/InputInfoTracker.scala | 38 +++++++++++++--- .../streaming/scheduler/JobGenerator.scala | 3 +- .../spark/streaming/scheduler/JobSet.scala | 4 +- .../apache/spark/streaming/ui/BatchPage.scala | 43 +++++++++++++++++-- .../spark/streaming/ui/BatchUIData.scala | 8 ++-- .../ui/StreamingJobProgressListener.scala | 5 ++- .../streaming/StreamingListenerSuite.scala | 6 +-- .../spark/streaming/TestSuiteBase.scala | 2 +- .../scheduler/InputInfoTrackerSuite.scala | 8 ++-- .../StreamingJobProgressListenerSuite.scala | 28 ++++++------ 16 files changed, 148 insertions(+), 51 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 876456c964770..48a1933d92f85 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka import scala.annotation.tailrec import scala.collection.mutable -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -119,8 +119,23 @@ class DirectKafkaInputDStream[ val rdd = KafkaRDD[K, V, U, T, R]( context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) - // Report the record number of this batch interval to InputInfoTracker. - val inputInfo = InputInfo(id, rdd.count) + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. 
+ offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 2675042666304..f326e7f1f6f8d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -75,7 +75,7 @@ final class OffsetRange private( } override def toString(): String = { - s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" } /** this is to avoid ClassNotFoundException during checkpoint restore */ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 821aadd477ef3..79089aae2a37c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -77,6 +77,12 @@ object MimaExcludes { // SPARK-8914 Remove RDDApi ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-8701 Add input metadata in the batch page. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo") ) case v if v.startsWith("1.4") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 86a8e2beff57c..dd4da9d9ca6a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** @@ -144,7 +145,14 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) batchTimeToSelectedFiles += ((validTime, newFiles)) recentlySelectedFiles ++= newFiles - Some(filesToRDD(newFiles)) + val rdds = Some(filesToRDD(newFiles)) + // Copy newFiles to immutable.List to prevent from being modified by the user + val metadata = Map( + "files" -> newFiles.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> newFiles.mkString("\n")) + val inputInfo = StreamInputInfo(id, 0, metadata) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + rdds } /** Clear the old time-to-files mappings along with old RDDs */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index e76e7eb0dea19..a50f0efc030ce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.storage.BlockId import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -70,7 +70,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker - val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum) + val inputInfo = StreamInputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) if (blockInfos.nonEmpty) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 5b9bfbf9b01e3..9922b6bc1201b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Time * :: DeveloperApi :: * Class having information on completed batches. * @param batchTime Time of the batch - * @param streamIdToNumRecords A map of input stream id to record number + * @param streamIdToInputInfo A map of input stream id to its input info * @param submissionTime Clock time of when jobs of this batch was submitted to * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing @@ -33,12 +33,15 @@ import org.apache.spark.streaming.Time @DeveloperApi case class BatchInfo( batchTime: Time, - streamIdToNumRecords: Map[Int, Long], + streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] ) { + @deprecated("Use streamIdToInputInfo instead", "1.5.0") + def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) + /** * Time taken for the first job of this batch to start processing from the time this batch * was submitted to the streaming scheduler. Essentially, it is @@ -63,5 +66,5 @@ case class BatchInfo( /** * The number of recorders received by the receivers in this batch. 
*/ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 7c0db8a863c67..363c03d431f04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -20,11 +20,34 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.streaming.{Time, StreamingContext} -/** To track the information of input stream at specified batch time. */ -private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { +/** + * :: DeveloperApi :: + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + */ +@DeveloperApi +case class StreamInputInfo( + inputStreamId: Int, numRecords: Long, metadata: Map[String, Any] = Map.empty) { require(numRecords >= 0, "numRecords must not be negative") + + def metadataDescription: Option[String] = + metadata.get(StreamInputInfo.METADATA_KEY_DESCRIPTION).map(_.toString) +} + +@DeveloperApi +object StreamInputInfo { + + /** + * The key for description in `StreamInputInfo.metadata`. + */ + val METADATA_KEY_DESCRIPTION: String = "Description" } /** @@ -34,12 +57,13 @@ private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging { // Map to track all the InputInfo related to specific batch time and input stream. 
- private val batchTimeToInputInfos = new mutable.HashMap[Time, mutable.HashMap[Int, InputInfo]] + private val batchTimeToInputInfos = + new mutable.HashMap[Time, mutable.HashMap[Int, StreamInputInfo]] /** Report the input information with batch time to the tracker */ - def reportInfo(batchTime: Time, inputInfo: InputInfo): Unit = synchronized { + def reportInfo(batchTime: Time, inputInfo: StreamInputInfo): Unit = synchronized { val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime, - new mutable.HashMap[Int, InputInfo]()) + new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + @@ -49,10 +73,10 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging } /** Get the all the input stream's information of specified batch time */ - def getInfo(batchTime: Time): Map[Int, InputInfo] = synchronized { + def getInfo(batchTime: Time): Map[Int, StreamInputInfo] = synchronized { val inputInfos = batchTimeToInputInfos.get(batchTime) // Convert mutable HashMap to immutable Map for the caller - inputInfos.map(_.toMap).getOrElse(Map[Int, InputInfo]()) + inputInfos.map(_.toMap).getOrElse(Map[Int, StreamInputInfo]()) } /** Cleanup the tracked input information older than threshold batch time */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 9f93d6cbc3c20..f5d41858646e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -244,8 +244,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } match { case Success(jobs) => val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time) - val streamIdToNumRecords = streamIdToInputInfos.mapValues(_.numRecords) - jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToNumRecords)) + jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index e6be63b2ddbdc..95833efc9417f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,7 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - streamIdToNumRecords: Map[Int, Long] = Map.empty) { + streamIdToInputInfo: Map[Int, StreamInputInfo] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -64,7 +64,7 @@ case class JobSet( def toBatchInfo: BatchInfo = { new BatchInfo( time, - streamIdToNumRecords, + streamIdToInputInfo, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index f75067669abe5..0c891662c264f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{NodeSeq, Node, Text} +import scala.xml.{NodeSeq, Node, Text, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -303,6 +301,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { batchUIData.processingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val formattedTotalDelay = batchUIData.totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + val inputMetadatas = batchUIData.streamIdToInputInfo.values.flatMap { inputInfo => + inputInfo.metadataDescription.map(desc => inputInfo.inputStreamId -> desc) + }.toSeq val summary: NodeSeq =

@@ -326,6 +327,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { <li> <strong>Total delay: </strong> {formattedTotalDelay} </li> + { + if (inputMetadatas.nonEmpty) { + <li> + <strong>Input Metadata:</strong>{generateInputMetadataTable(inputMetadatas)} + </li> + } + } </ul> </div>
@@ -340,4 +348,33 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) } + + def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { + <table class={SparkUIUtils.TABLE_CLASS_STRIPED_SORTABLE}> + <thead> + <tr> + <th>Input</th> + <th>Metadata</th> + </tr> + </thead> + <tbody> + {inputMetadatas.flatMap(generateInputMetadataRow)} + </tbody> + </table>
+ } + + def generateInputMetadataRow(inputMetadata: (Int, String)): Seq[Node] = { + val streamId = inputMetadata._1 + + <tr> + <td>{streamingListener.streamName(streamId).getOrElse(s"Stream-$streamId")}</td> + <td>{metadataDescriptionToHTML(inputMetadata._2)}</td> + </tr> + } + + private def metadataDescriptionToHTML(metadataDescription: String): Seq[Node] = { + // tab to 4 spaces and "\n" to "<br/>" + Unparsed(StringEscapeUtils.escapeHtml4(metadataDescription). + replaceAllLiterally("\t", "&nbsp;&nbsp;&nbsp;&nbsp;").replaceAllLiterally("\n", "<br/>
")) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index a5514dfd71c9f..ae508c0e9577b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -19,14 +19,14 @@ package org.apache.spark.streaming.ui import org.apache.spark.streaming.Time -import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} import org.apache.spark.streaming.ui.StreamingJobProgressListener._ private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) private[ui] case class BatchUIData( val batchTime: Time, - val streamIdToNumRecords: Map[Int, Long], + val streamIdToInputInfo: Map[Int, StreamInputInfo], val submissionTime: Long, val processingStartTime: Option[Long], val processingEndTime: Option[Long], @@ -58,7 +58,7 @@ private[ui] case class BatchUIData( /** * The number of recorders received by the receivers in this batch. */ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } private[ui] object BatchUIData { @@ -66,7 +66,7 @@ private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { new BatchUIData( batchInfo.batchTime, - batchInfo.streamIdToNumRecords, + batchInfo.streamIdToInputInfo, batchInfo.submissionTime, batchInfo.processingStartTime, batchInfo.processingEndTime diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 68e8ce98945e0..b77c555c68b8b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -192,7 +192,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { val _retainedBatches = retainedBatches val latestBatches = _retainedBatches.map { batchUIData => - (batchUIData.batchTime.milliseconds, batchUIData.streamIdToNumRecords) + (batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords)) } streamIds.map { streamId => val eventRates = latestBatches.map { @@ -205,7 +205,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def lastReceivedBatchRecords: Map[Int, Long] = synchronized { - val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.streamIdToNumRecords) + val lastReceivedBlockInfoOption = + lastReceivedBatch.map(_.streamIdToInputInfo.mapValues(_.numRecords)) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => streamIds.map { streamId => (streamId, lastReceivedBlockInfo.getOrElse(streamId, 0L)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 7bc7727a9fbe4..4bc1dd4a30fc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -59,7 +59,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosSubmitted.foreach { info => 
info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) @@ -77,7 +77,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosStarted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) @@ -98,7 +98,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosCompleted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 31b1aebf6a8ec..0d58a7b54412f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -76,7 +76,7 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], } // Report the input data's information to InputInfoTracker for testing - val inputInfo = InputInfo(id, selectedInput.length.toLong) + val inputInfo = StreamInputInfo(id, selectedInput.length.toLong) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 2e210397fe7c7..f5248acf712b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -46,8 +46,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { val streamId1 = 0 val streamId2 = 1 val time = Time(0L) - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId2, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId2, 300L) inputInfoTracker.reportInfo(time, inputInfo1) inputInfoTracker.reportInfo(time, inputInfo2) @@ -63,8 +63,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { val inputInfoTracker = new InputInfoTracker(ssc) val streamId1 = 0 - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId1, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId1, 300L) inputInfoTracker.reportInfo(Time(0), inputInfo1) inputInfoTracker.reportInfo(Time(1), inputInfo2) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index c9175d61b1f49..40dc1fb601bd0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -49,10 +49,12 @@ class StreamingJobProgressListenerSuite 
extends TestSuiteBase with Matchers { val ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -64,7 +66,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -94,7 +96,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoStarted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoStarted.processingDelay) batchUIData.get.totalDelay should be (batchInfoStarted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map(0 -> 300L, 1 -> 300L)) + batchUIData.get.streamIdToInputInfo should be (Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test")))) batchUIData.get.numRecords should be(600) batchUIData.get.outputOpIdSparkJobIdPairs should be Seq(OutputOpIdAndSparkJobId(0, 0), @@ -103,7 +107,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -141,9 +145,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -182,7 +186,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoSubmitted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoSubmitted.processingDelay) batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map.empty) + 
batchUIData.get.streamIdToInputInfo should be (Map.empty) batchUIData.get.numRecords should be (0) batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) @@ -211,14 +215,14 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) for (_ <- 0 until 2 * limit) { - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -235,7 +239,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } From 3ccebf36c5abe04702d4cf223552a94034d980fb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 9 Jul 2015 13:54:44 -0700 Subject: [PATCH 17/23] [SPARK-8389] [STREAMING] [PYSPARK] Expose KafkaRDDs offsetRange in Python This PR propose a simple way to expose OffsetRange in Python code, also the usage of offsetRanges is similar to Scala/Java way, here in Python we could get OffsetRange like: ``` dstream.foreachRDD(lambda r: KafkaUtils.offsetRanges(r)) ``` Reason I didn't follow the way what SPARK-8389 suggested is that: Python Kafka API has one more step to decode the message compared to Scala/Java, Which makes Python API return a transformed RDD/DStream, not directly wrapped so-called JavaKafkaRDD, so it is hard to backtrack to the original RDD to get the offsetRange. 
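For comparison, the existing Scala/Java pattern that this Python API mirrors looks roughly like the sketch below. It is not part of this patch and assumes `ssc`, `kafkaParams`, and `topics` are already set up as in the Kafka integration guide.

```scala
import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils}

// Assumes ssc: StreamingContext, kafkaParams: Map[String, String], topics: Set[String]
val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, topics)

stream.foreachRDD { rdd =>
  // Each RDD of a direct stream is a KafkaRDD, which carries its offset ranges;
  // the Python KafkaRDD wrapper in this patch exposes the same data via offsetRanges().
  val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
  ranges.foreach { r =>
    println(s"topic: ${r.topic}, partition: ${r.partition}, " +
      s"offsets: ${r.fromOffset} to ${r.untilOffset}")
  }
}
```

In Python the decoded stream is a transformed DStream rather than a thin wrapper over the JVM KafkaRDD, which is why this patch introduces a Python-side `KafkaRDD` class so that `rdd.offsetRanges()` also works inside `foreachRDD` and `transform`.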
Author: jerryshao Closes #7185 from jerryshao/SPARK-8389 and squashes the following commits: 4c6d320 [jerryshao] Another way to fix subclass deserialization issue e6a8011 [jerryshao] Address the comments fd13937 [jerryshao] Fix serialization bug 7debf1c [jerryshao] bug fix cff3893 [jerryshao] refactor the code according to the comments 2aabf9e [jerryshao] Style fix 848c708 [jerryshao] Add HasOffsetRanges for Python --- .../spark/streaming/kafka/KafkaUtils.scala | 13 ++ python/pyspark/streaming/kafka.py | 123 ++++++++++++++++-- python/pyspark/streaming/tests.py | 64 +++++++++ python/pyspark/streaming/util.py | 7 +- 4 files changed, 196 insertions(+), 11 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0e33362d34acd..f3b01bd60b178 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -670,4 +670,17 @@ private class KafkaUtilsPythonHelper { TopicAndPartition(topic, partition) def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq + } } diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 10a859a532e28..33dd596335b47 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -21,6 +21,8 @@ from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream +from pyspark.streaming.dstream import TransformedDStream +from pyspark.streaming.util import TransformFunction __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] @@ -122,8 +124,9 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + stream = DStream(jstream, ssc, ser) \ + .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod def createRDD(sc, kafkaParams, offsetRanges, leaders={}, @@ -161,8 +164,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - rdd = RDD(jrdd, sc, ser) - return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer) @staticmethod def _printErrorMsg(sc): @@ -200,14 +203,30 @@ def __init__(self, topic, partition, fromOffset, untilOffset): :param fromOffset: Inclusive starting offset. :param untilOffset: Exclusive ending offset. 
""" - self._topic = topic - self._partition = partition - self._fromOffset = fromOffset - self._untilOffset = untilOffset + self.topic = topic + self.partition = partition + self.fromOffset = fromOffset + self.untilOffset = untilOffset + + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self.topic == other.topic + and self.partition == other.partition + and self.fromOffset == other.fromOffset + and self.untilOffset == other.untilOffset) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \ + % (self.topic, self.partition, self.fromOffset, self.untilOffset) def _jOffsetRange(self, helper): - return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, - self._untilOffset) + return helper.createOffsetRange(self.topic, self.partition, self.fromOffset, + self.untilOffset) class TopicAndPartition(object): @@ -244,3 +263,87 @@ def __init__(self, host, port): def _jBroker(self, helper): return helper.createBroker(self._host, self._port) + + +class KafkaRDD(RDD): + """ + A Python wrapper of KafkaRDD, to provide additional information on normal RDD. + """ + + def __init__(self, jrdd, ctx, jrdd_deserializer): + RDD.__init__(self, jrdd, ctx, jrdd_deserializer) + + def offsetRanges(self): + """ + Get the OffsetRange of specific KafkaRDD. + :return: A list of OffsetRange + """ + try: + helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd()) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(self.ctx) + raise e + + ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset()) + for o in joffsetRanges] + return ranges + + +class KafkaDStream(DStream): + """ + A Python wrapper of KafkaDStream + """ + + def __init__(self, jdstream, ssc, jrdd_deserializer): + DStream.__init__(self, jdstream, ssc, jrdd_deserializer) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.__code__.co_argcount == 1: + old_func = func + func = lambda r, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.__code__.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.__code__.co_argcount == 2, "func should take one or two arguments" + + return KafkaTransformedDStream(self, func) + + +class KafkaTransformedDStream(TransformedDStream): + """ + Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. 
+ """ + + def __init__(self, prev, func): + TransformedDStream.__init__(self, prev, func) + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 188c8ff12067e..4ecae1e4bf282 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -678,6 +678,70 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_get_offsetRanges(self): + """Test Python direct Kafka RDD get OffsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 3, "b": 4, "c": 5} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self.assertEqual(offsetRanges, rdd.offsetRanges()) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_foreach_get_offsetRanges(self): + """Test the Python direct Kafka stream foreachRDD get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def getOffsetRanges(_, rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + + stream.foreachRDD(getOffsetRanges) + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_get_offsetRanges(self): + """Test the Python direct Kafka stream transform get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index a9bfec2aab8fc..b20613b1283bd 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,6 +37,11 @@ def __init__(self, ctx, func, 
*deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers + self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + + def rdd_wrapper(self, func): + self._rdd_wrapper = func + return self def call(self, milliseconds, jrdds): try: @@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) From c9e2ef52bb54f35a904427389dc492d61f29b018 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 14:43:38 -0700 Subject: [PATCH 18/23] [SPARK-7902] [SPARK-6289] [SPARK-8685] [SQL] [PYSPARK] Refactor of serialization for Python DataFrame This PR fix the long standing issue of serialization between Python RDD and DataFrame, it change to using a customized Pickler for InternalRow to enable customized unpickling (type conversion, especially for UDT), now we can support UDT for UDF, cc mengxr . There is no generated `Row` anymore. Author: Davies Liu Closes #7301 from davies/sql_ser and squashes the following commits: 81bef71 [Davies Liu] address comments e9217bd [Davies Liu] add regression tests db34167 [Davies Liu] Refactor of serialization for Python DataFrame --- python/pyspark/sql/context.py | 5 +- python/pyspark/sql/dataframe.py | 16 +- python/pyspark/sql/tests.py | 28 +- python/pyspark/sql/types.py | 419 ++++++------------ .../spark/sql/catalyst/expressions/rows.scala | 12 + .../org/apache/spark/sql/DataFrame.scala | 5 +- .../spark/sql/execution/pythonUDFs.scala | 122 ++++- 7 files changed, 292 insertions(+), 315 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 309c11faf9319..c93a15badae29 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -30,7 +30,7 @@ from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ - _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter + _infer_schema, _has_nulltype, _merge_type, _create_converter from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.utils import install_exception_handler @@ -388,8 +388,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): raise TypeError("schema should be StructType or list or None") # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) + rdd = rdd.map(schema.toInternal) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e9c657cf81b3..83e02b85f06f1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -31,7 +31,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql import since -from pyspark.sql.types import _create_cls, _parse_datatype_json_string +from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * @@ -83,15 +83,7 @@ def rdd(self): 
""" if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() - rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema - - def applySchema(it): - cls = _create_cls(schema) - return map(cls, it) - - self._lazy_rdd = rdd.mapPartitions(applySchema) - + self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) return self._lazy_rdd @property @@ -287,9 +279,7 @@ def collect(self): """ with SCCallSiteSync(self._sc) as css: port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) - rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) - cls = _create_cls(self.schema) - return [cls(r) for r in rs] + return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(1.3) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 66827d48850d9..4d7cad5a1ab88 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -151,6 +151,17 @@ def test_range(self): self.assertEqual(self.sqlCtx.range(-2).count(), 0) self.assertEqual(self.sqlCtx.range(3).count(), 3) + def test_duplicated_column_names(self): + df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"]) + row = df.select('*').first() + self.assertEqual(1, row[0]) + self.assertEqual(2, row[1]) + self.assertEqual("Row(c=1, c=2)", str(row)) + # Cannot access columns + self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) + self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) + self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) + def test_explode(self): from pyspark.sql.functions import explode d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] @@ -401,6 +412,14 @@ def test_apply_schema_with_udt(self): point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_udf_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sc.parallelize([row]).toDF() + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) @@ -693,12 +712,9 @@ def test_time_with_timezone(self): utcnow = datetime.datetime.fromtimestamp(ts, utc) df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) day1, now1, utcnow1 = df.first() - # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version - self.assertEqual(day1.date(), day) - # Pyrolite does not support microsecond, the error should be - # less than 1 millisecond - self.assertTrue(now - now1 < datetime.timedelta(0.001)) - self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + self.assertEqual(day1, day) + self.assertEqual(now, now1) + self.assertEqual(now, utcnow1) def test_decimal(self): from decimal import Decimal diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fecfe6d71e9a7..d63857691675a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -20,13 +20,9 @@ import time import datetime import calendar -import keyword -import warnings import json import re -import weakref from array import array -from operator import itemgetter if sys.version >= "3": long = int @@ -71,6 +67,26 @@ def json(self): separators=(',', ':'), sort_keys=True) + def needConversion(self): + """ + Does this 
type need to conversion between Python object and internal SQL object. + + This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. + """ + return False + + def toInternal(self, obj): + """ + Converts a Python object into an internal SQL object. + """ + return obj + + def fromInternal(self, obj): + """ + Converts an internal SQL object into a native Python object. + """ + return obj + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -143,6 +159,17 @@ class DateType(AtomicType): __metaclass__ = DataTypeSingleton + EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def needConversion(self): + return True + + def toInternal(self, d): + return d and d.toordinal() - self.EPOCH_ORDINAL + + def fromInternal(self, v): + return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. @@ -150,6 +177,19 @@ class TimestampType(AtomicType): __metaclass__ = DataTypeSingleton + def needConversion(self): + return True + + def toInternal(self, dt): + if dt is not None: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e6 + dt.microsecond) + + def fromInternal(self, ts): + if ts is not None: + return datetime.datetime.fromtimestamp(ts / 1e6) + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -259,6 +299,19 @@ def fromJson(cls, json): return ArrayType(_parse_datatype_json_value(json["elementType"]), json["containsNull"]) + def needConversion(self): + return self.elementType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.toInternal(v) for v in obj] + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.fromInternal(v) for v in obj] + class MapType(DataType): """Map data type. @@ -304,6 +357,21 @@ def fromJson(cls, json): _parse_datatype_json_value(json["valueType"]), json["valueContainsNull"]) + def needConversion(self): + return self.keyType.needConversion() or self.valueType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) + for k, v in obj.items()) + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) + for k, v in obj.items()) + class StructField(DataType): """A field in :class:`StructType`. @@ -311,7 +379,7 @@ class StructField(DataType): :param name: string, name of the field. :param dataType: :class:`DataType` of the field. :param nullable: boolean, whether the field can be null (None) or not. - :param metadata: a dict from string to simple type that can be serialized to JSON automatically + :param metadata: a dict from string to simple type that can be toInternald to JSON automatically """ def __init__(self, name, dataType, nullable=True, metadata=None): @@ -351,6 +419,15 @@ def fromJson(cls, json): json["nullable"], json["metadata"]) + def needConversion(self): + return self.dataType.needConversion() + + def toInternal(self, obj): + return self.dataType.toInternal(obj) + + def fromInternal(self, obj): + return self.dataType.fromInternal(obj) + class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. 
@@ -371,10 +448,13 @@ def __init__(self, fields=None): """ if not fields: self.fields = [] + self.names = [] else: self.fields = fields + self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" + self._needSerializeFields = None def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -406,6 +486,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ if isinstance(field, StructField): self.fields.append(field) + self.names.append(field.name) else: if isinstance(field, str) and data_type is None: raise ValueError("Must specify DataType if passing name of struct_field to create.") @@ -415,6 +496,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) + self.names.append(field) return self def simpleString(self): @@ -432,6 +514,41 @@ def jsonValue(self): def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) + def needConversion(self): + # We need convert Row()/namedtuple into tuple() + return True + + def toInternal(self, obj): + if obj is None: + return + + if self._needSerializeFields is None: + self._needSerializeFields = any(f.needConversion() for f in self.fields) + + if self._needSerializeFields: + if isinstance(obj, dict): + return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + elif isinstance(obj, (tuple, list)): + return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + else: + if isinstance(obj, dict): + return tuple(obj.get(n) for n in self.names) + elif isinstance(obj, (list, tuple)): + return tuple(obj) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + + def fromInternal(self, obj): + if obj is None: + return + if isinstance(obj, Row): + # it's already converted by pickler + return obj + values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + return _create_row(self.names, values) + class UserDefinedType(DataType): """User-defined type (UDT). @@ -464,17 +581,35 @@ def scalaUDT(cls): """ raise NotImplementedError("UDT must have a paired Scala UDT.") + def needConversion(self): + return True + + @classmethod + def _cachedSqlType(cls): + """ + Cache the sqlType() into class, because it's heavy used in `toInternal`. + """ + if not hasattr(cls, "_cached_sql_type"): + cls._cached_sql_type = cls.sqlType() + return cls._cached_sql_type + + def toInternal(self, obj): + return self._cachedSqlType().toInternal(self.serialize(obj)) + + def fromInternal(self, obj): + return self.deserialize(self._cachedSqlType().fromInternal(obj)) + def serialize(self, obj): """ Converts the a user-type object into a SQL datum. """ - raise NotImplementedError("UDT must implement serialize().") + raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum): """ Converts a SQL datum into a user-type object. """ - raise NotImplementedError("UDT must implement deserialize().") + raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self): return 'udt' @@ -671,117 +806,6 @@ def _infer_schema(row): return StructType(fields) -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. 
- - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - True - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - # convert namedtuple or Row into tuple - return True - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - elif isinstance(dataType, (DateType, TimestampType)): - return True - else: - return False - - -EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - if any(_need_python_to_sql_conversion(t) for t in types): - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - elif obj is not None: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - else: - def converter(obj): - if isinstance(obj, dict): - return tuple(obj.get(n) for n in names) - else: - return tuple(obj) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: a and [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - - elif isinstance(dataType, UserDefinedType): - return lambda obj: obj and dataType.serialize(obj) - - elif isinstance(dataType, DateType): - return lambda d: d and d.toordinal() - EPOCH_ORDINAL - - elif isinstance(dataType, TimestampType): - - def to_posix_timstamp(dt): - if dt: - seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo - else time.mktime(dt.timetuple())) - return 
int(seconds * 1e6 + dt.microsecond) - return to_posix_timstamp - - else: - raise ValueError("Unexpected type %r" % dataType) - - def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): @@ -1076,7 +1100,7 @@ def _verify_type(obj, dataType): if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) + _verify_type(dataType.toInternal(obj), dataType.sqlType()) return _type = type(dataType) @@ -1086,7 +1110,7 @@ def _verify_type(obj, dataType): if not isinstance(obj, (tuple, list)): raise TypeError("StructType can not accept object in type %s" % type(obj)) else: - # subclass of them can not be deserialized in JVM + # subclass of them can not be fromInternald in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) @@ -1106,159 +1130,10 @@ def _verify_type(obj, dataType): for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) -_cached_cls = weakref.WeakValueDictionary() - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None or cls.__datatype is not dataType: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - cls.__datatype = dataType - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. 
- - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for atomic types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __datatype = dataType - __fields__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__fields__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__fields__)) - - def __reduce__(self): - return (_restore_object, (self.__datatype, tuple(self))) - return Row +# This is used to unpickle a Row from JVM +def _create_row_inbound_converter(dataType): + return lambda *a: dataType.fromInternal(a) def _create_row(fields, values): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8b472a529e5c9..094904bbf9c15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -132,6 +132,18 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) override def copy(): InternalRow = this } +/** + * This is used for serialization of Python DataFrame + */ +class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType) + extends GenericInternalRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) + + override def fieldIndex(name: String): Int = schema.fieldIndex(name) +} + class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d9f987ae0252f..d7966651b1948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -30,7 +30,6 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -1550,8 +1549,8 @@ class DataFrame private[sql]( */ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { val structType = schema // capture it for closure - val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) + val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + EvaluatePython.javaToPython(rdd) } //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 1c8130b07c7fb..6d6e67dace177 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.execution +import java.io.OutputStream import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import net.razorvine.pickle.{Pickler, Unpickler} +import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} +import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Accumulator, Logging => SparkLogging} @@ -130,8 +130,13 @@ object EvaluatePython { case (null, _) => null case (row: InternalRow, struct: StructType) => - val fields = struct.fields.map(field => field.dataType) - rowToArray(row, fields) + val values = new Array[Any](row.size) + var i = 0 + while (i < row.size) { + values(i) = toJava(row(i), struct.fields(i).dataType) + i += 1 + } + new GenericInternalRowWithSchema(values, struct) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava @@ -142,9 +147,6 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) - case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) - case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) - case (d: Decimal, _) => d.toJavaBigDecimal case (s: UTF8String, StringType) => s.toString @@ -152,14 +154,6 @@ object EvaluatePython { case (other, _) => other } - /** - * Convert Row into Java Array (for pickled into Python) - */ - def rowToArray(row: 
InternalRow, fields: Seq[DataType]): Array[Any] = {
-    // TODO: this is slow!
-    row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
-  }
-
   /**
    * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
    * unexpected. Because Python doesn't enforce the type.
@@ -220,6 +214,96 @@ object EvaluatePython {
     // TODO(davies): we could improve this by try to cast the object to expected type
     case (c, _) => null
   }
+
+
+  private val module = "pyspark.sql.types"
+
+  /**
+   * Pickler for StructType
+   */
+  private class StructTypePickler extends IObjectPickler {
+
+    private val cls = classOf[StructType]
+
+    def register(): Unit = {
+      Pickler.registerCustomPickler(cls, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      out.write(Opcodes.GLOBAL)
+      out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8"))
+      val schema = obj.asInstanceOf[StructType]
+      pickler.save(schema.json)
+      out.write(Opcodes.TUPLE1)
+      out.write(Opcodes.REDUCE)
+    }
+  }
+
+  /**
+   * Pickler for InternalRow
+   */
+  private class RowPickler extends IObjectPickler {
+
+    private val cls = classOf[GenericInternalRowWithSchema]
+
+    // register this to Pickler and Unpickler
+    def register(): Unit = {
+      Pickler.registerCustomPickler(this.getClass, this)
+      Pickler.registerCustomPickler(cls, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      if (obj == this) {
+        out.write(Opcodes.GLOBAL)
+        out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8"))
+      } else {
+        // it will be memoized by the Pickler to save some bytes
+        pickler.save(this)
+        val row = obj.asInstanceOf[GenericInternalRowWithSchema]
+        // schema should always be same object for memoization
+        pickler.save(row.schema)
+        out.write(Opcodes.TUPLE1)
+        out.write(Opcodes.REDUCE)
+
+        out.write(Opcodes.MARK)
+        var i = 0
+        while (i < row.values.size) {
+          pickler.save(row.values(i))
+          i += 1
+        }
+        out.write(Opcodes.TUPLE)
+        out.write(Opcodes.REDUCE)
+      }
+    }
+  }
+
+  private[this] var registered = false
+  /**
+   * This should be called before trying to serialize any of the above classes in cluster mode;
+   * it should be put in the closure
+   */
+  def registerPicklers(): Unit = {
+    synchronized {
+      if (!registered) {
+        SerDeUtil.initialize()
+        new StructTypePickler().register()
+        new RowPickler().register()
+        registered = true
+      }
+    }
+  }
+
+  /**
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+   * PySpark.
+ */ + def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { + rdd.mapPartitions { iter => + registerPicklers() // let it called in executor + new SerDeUtil.AutoBatchedPickler(iter) + } + } } /** @@ -254,12 +338,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val childResults = child.execute().map(_.copy()) val parent = childResults.mapPartitions { iter => + EvaluatePython.registerPicklers() // register pickler for Row val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) - iter.grouped(1000).map { inputRows => + val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) + iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => - EvaluatePython.rowToArray(currentRow(row), fields) + EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } From 897700369f3aedf1a8fdb0984dd3d6d8e498e3af Mon Sep 17 00:00:00 2001 From: guowei2 Date: Thu, 9 Jul 2015 15:01:53 -0700 Subject: [PATCH 19/23] [SPARK-8865] [STREAMING] FIX BUG: check key in kafka params Author: guowei2 Closes #7254 from guowei2/spark-8865 and squashes the following commits: 48ca17a [guowei2] fix contains key --- .../scala/org/apache/spark/streaming/kafka/KafkaCluster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 3e6b937af57b0..8465432c5850f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -410,7 +410,7 @@ object KafkaCluster { } Seq("zookeeper.connect", "group.id").foreach { s => - if (!props.contains(s)) { + if (!props.containsKey(s)) { props.setProperty(s, "") } } From 69165330303a71ea1da748eca7a780ec172b326f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 9 Jul 2015 15:14:14 -0700 Subject: [PATCH 20/23] Closes #6837 Closes #7321 Closes #2634 Closes #4963 Closes #2137 From e29ce319fa6ffb9c8e5110814d4923d433aa1b76 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 9 Jul 2015 15:49:30 -0700 Subject: [PATCH 21/23] [SPARK-8963][ML] cleanup tests in linear regression suite Simplify model weight assertions to use vector comparision, switch to using absTol when comparing with 0.0 intercepts Author: Holden Karau Closes #7327 from holdenk/SPARK-8913-cleanup-tests-from-SPARK-8700-logistic-regression and squashes the following commits: 5bac185 [Holden Karau] Simplify model weight assertions to use vector comparision, switch to using absTol when comparing with 0.0 intercepts --- .../ml/regression/LinearRegressionSuite.scala | 57 ++++++++----------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 5f39d44f37352..4f6a57739558b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import 
org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -75,11 +75,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 - val weightsR = Array(4.700706, 7.199082) + val weightsR = Vectors.dense(4.700706, 7.199082) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -104,11 +103,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V2. 6.995908 as.numeric.data.V3. 5.275131 */ - val weightsR = Array(6.995908, 5.275131) + val weightsR = Vectors.dense(6.995908, 5.275131) - assert(model.intercept ~== 0 relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== 0 absTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) /* Then again with the data with no intercept: > weightsWithoutIntercept @@ -118,11 +116,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data3.V2. 4.70011 as.numeric.data3.V3. 7.19943 */ - val weightsWithoutInterceptR = Array(4.70011, 7.19943) + val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) - assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3) - assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3) - assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3) + assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3) } test("linear regression with intercept with L1 regularization") { @@ -139,11 +136,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 6.679841 */ val interceptR = 6.24300 - val weightsR = Array(4.024821, 6.679841) + val weightsR = Vectors.dense(4.024821, 6.679841) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -169,11 +165,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.772913 */ val interceptR = 0.0 - val weightsR = Array(6.299752, 4.772913) + val weightsR = Vectors.dense(6.299752, 4.772913) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-5) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -197,11 +192,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
4.926260 */ val interceptR = 5.269376 - val weightsR = Array(3.736216, 5.712356) + val weightsR = Vectors.dense(3.736216, 5.712356) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -227,11 +221,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.214502 */ val interceptR = 0.0 - val weightsR = Array(5.522875, 4.214502) + val weightsR = Vectors.dense(5.522875, 4.214502) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + assert(model.weights ~== weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -255,11 +248,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 - val weightsR = Array(3.670489, 6.001122) + val weightsR = Vectors.dense(3.670489, 6.001122) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~== weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -285,11 +277,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.dataM.V3. 4.322251 */ val interceptR = 0.0 - val weightsR = Array(5.673348, 4.322251) + val weightsR = Vectors.dense(5.673348, 4.322251) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => From a0cc3e5aa3fcfd0fce6813c520152657d327aaf2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 9 Jul 2015 16:21:21 -0700 Subject: [PATCH 22/23] [SPARK-8538] [SPARK-8539] [ML] Linear Regression Training and Testing Results Adds results (e.g. objective value at each iteration, residuals) on training and user-specified test sets for LinearRegressionModel. Notes to Reviewers: * Are the `*TrainingResults` and `Results` classes too specialized for `LinearRegressionModel`? Where would be an appropriate level of abstraction? * Please check `transient` annotations are correct; the datasets should not be copied and kept during serialization. * Any thoughts on `RDD`s versus `DataFrame`s? If using `DataFrame`s, suggested schemas for each intermediate step? Also, how to create a "local DataFrame" without a `sqlContext`? 
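
As a rough illustration of how the training-summary API introduced in this patch is meant to be used (names follow the Scala code below; `df` stands for an assumed DataFrame of labels and features):

    import org.apache.spark.ml.regression.LinearRegression

    val model = new LinearRegression().setMaxIter(50).fit(df)
    if (model.hasSummary) {
      val summary = model.summary                       // LinearRegressionTrainingSummary
      println(summary.totalIterations)                  // iterations until the optimizer terminated
      println(summary.objectiveHistory.mkString(", "))  // scaled loss + regularization per iteration
      println(summary.meanSquaredError)                 // metrics computed on the training data
      println(summary.r2)
      summary.residuals.show()                          // single-column DataFrame of residuals
    }

The `evaluate` method shown in the diff stays `private[regression]` for now, so test-set summaries are only reachable from within the package.
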
Author: Feynman Liang Closes #7099 from feynmanliang/SPARK-8538 and squashes the following commits: d219fa4 [Feynman Liang] Update docs 4a42680 [Feynman Liang] Change Summary to hold values, move transient annotations down to metrics and predictions DF 6300031 [Feynman Liang] Code review changes 0a5e762 [Feynman Liang] Fix build error e71102d [Feynman Liang] Merge branch 'master' into SPARK-8538 3367489 [Feynman Liang] Merge branch 'master' into SPARK-8538 70f267c [Feynman Liang] Make TrainingSummary transient and remove Serializable from *Summary and RegressionMetrics 1d9ea42 [Feynman Liang] Fix failing Java test a65dfda [Feynman Liang] Make TrainingSummary and metrics serializable, prediction dataframe transient 0a605d8 [Feynman Liang] Replace Params from LinearRegression*Summary with private constructor vals c2fe835 [Feynman Liang] Optimize imports 02d8a70 [Feynman Liang] Add Params to LinearModel*Summary, refactor tests and add test for evaluate() 8f999f4 [Feynman Liang] Refactor from jkbradley code review 072e948 [Feynman Liang] Style 509ae36 [Feynman Liang] Use DFs and localize serialization to LinearRegressionModel 9509c79 [Feynman Liang] Fix imports b2bbaa3 [Feynman Liang] Refactored LinearRegressionResults API to be more private ffceaec [Feynman Liang] Merge branch 'master' into SPARK-8538 1cedb2b [Feynman Liang] Add test for decreasing objective trace dab0aff [Feynman Liang] Add LinearRegressionTrainingResults tests, make test suite code copy+pasteable 97b0a81 [Feynman Liang] Add LinearRegressionModel.evaluate() to get results on test sets dc51bce [Feynman Liang] Style guide fixes 521f397 [Feynman Liang] Use RDD[(Double, Double)] instead of DF 2ff5710 [Feynman Liang] Add training results and model summary to ML LinearRegression --- .../ml/regression/LinearRegression.scala | 139 +++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 59 ++++++++ 2 files changed, 192 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f672c96576a33..8fc986056657d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,18 +22,20 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV, norm => brzNorm} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -139,7 +141,16 @@ class LinearRegression(override val uid: String) logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + s"and the 
intercept will be the mean of the label; as a result, training is not needed.") if (handlePersistence) instances.unpersist() - return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean) + val weights = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, weights, intercept) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset).select($(predictionCol), $(labelCol)), + $(predictionCol), + $(labelCol), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) } val featuresMean = summarizer.mean.toArray @@ -178,7 +189,6 @@ class LinearRegression(override val uid: String) state = states.next() arrayBuilder += state.adjustedValue } - if (state == null) { val msg = s"${optimizer.getClass.getName} failed." logError(msg) @@ -209,7 +219,13 @@ class LinearRegression(override val uid: String) if (handlePersistence) instances.unpersist() - copyValues(new LinearRegressionModel(uid, weights, intercept)) + val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset).select($(predictionCol), $(labelCol)), + $(predictionCol), + $(labelCol), + objectiveHistory) + model.setSummary(trainingSummary) } override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) @@ -227,13 +243,124 @@ class LinearRegressionModel private[ml] ( extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams { + private var trainingSummary: Option[LinearRegressionTrainingSummary] = None + + /** + * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + def summary: LinearRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LinearRegressionModel", + new NullPointerException()) + } + + private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + val t = udf { features: Vector => predict(features) } + val predictionAndObservations = dataset + .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + + new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) + } + override protected def predict(features: Vector): Double = { dot(features, weights) + intercept } override def copy(extra: ParamMap): LinearRegressionModel = { - copyValues(new LinearRegressionModel(uid, weights, intercept), extra) + val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel + } +} + +/** + * :: Experimental :: + * Linear regression training results. + * @param predictions predictions outputted by the model's `transform` method. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. 
+ */ +@Experimental +class LinearRegressionTrainingSummary private[regression] ( + predictions: DataFrame, + predictionCol: String, + labelCol: String, + val objectiveHistory: Array[Double]) + extends LinearRegressionSummary(predictions, predictionCol, labelCol) { + + /** Number of training iterations until termination */ + val totalIterations = objectiveHistory.length + +} + +/** + * :: Experimental :: + * Linear regression results evaluated on a dataset. + * @param predictions predictions outputted by the model's `transform` method. + */ +@Experimental +class LinearRegressionSummary private[regression] ( + @transient val predictions: DataFrame, + val predictionCol: String, + val labelCol: String) extends Serializable { + + @transient private val metrics = new RegressionMetrics( + predictions + .select(predictionCol, labelCol) + .map { case Row(pred: Double, label: Double) => (pred, label) } ) + + /** + * Returns the explained variance regression score. + * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + */ + val explainedVariance: Double = metrics.explainedVariance + + /** + * Returns the mean absolute error, which is a risk function corresponding to the + * expected value of the absolute error loss or l1-norm loss. + */ + val meanAbsoluteError: Double = metrics.meanAbsoluteError + + /** + * Returns the mean squared error, which is a risk function corresponding to the + * expected value of the squared error loss or quadratic loss. + */ + val meanSquaredError: Double = metrics.meanSquaredError + + /** + * Returns the root mean squared error, which is defined as the square root of + * the mean squared error. + */ + val rootMeanSquaredError: Double = metrics.rootMeanSquaredError + + /** + * Returns R^2^, the coefficient of determination. 
+ * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + */ + val r2: Double = metrics.r2 + + /** Residuals (predicted value - label value) */ + @transient lazy val residuals: DataFrame = { + val t = udf { (pred: Double, label: Double) => pred - label} + predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) } + } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 4f6a57739558b..cf120cf2a4b47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -289,4 +289,63 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(prediction1 ~== prediction2 relTol 1E-5) } } + + test("linear regression model training summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Training results for the model should be available + assert(model.hasSummary) + + // Residuals in [[LinearRegressionResults]] should equal those manually computed + val expectedResiduals = dataset.select("features", "label") + .map { case Row(features: DenseVector, label: Double) => + val prediction = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + prediction - label + } + .zip(model.summary.residuals.map(_.getDouble(0))) + .collect() + .foreach { case (manualResidual: Double, resultResidual: Double) => + assert(manualResidual ~== resultResidual relTol 1E-5) + } + + /* + Use the following R code to generate model training results. + + predictions <- predict(fit, newx=features) + residuals <- predictions - label + > mean(residuals^2) # MSE + [1] 0.009720325 + > mean(abs(residuals)) # MAD + [1] 0.07863206 + > cor(predictions, label)^2# r^2 + [,1] + s0 0.9998749 + */ + assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) + assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) + assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + + // Objective function should be monotonically decreasing for linear regression + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } + + test("linear regression model testset evaluation summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Evaluating on training dataset should yield results summary equal to training summary + val testSummary = model.evaluate(dataset) + assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) + assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) + model.summary.residuals.select("residuals").collect() + .zip(testSummary.residuals.select("residuals").collect()) + .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + } + } From 2d45571fcb002cc9f03056c5a3f14493b83315a4 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 9 Jul 2015 17:09:16 -0700 Subject: [PATCH 23/23] [SPARK-8959] [SQL] [HOTFIX] Removes parquet-thrift and libthrift dependencies These two dependencies were introduced in #7231 to help testing Parquet compatibility with `parquet-thrift`. However, they somehow crash the Scala compiler in Maven builds. This PR fixes this issue by: 1. Removing these two dependencies, and 2. 
Instead of generating the testing Parquet file programmatically, checking in an actual testing Parquet file generated by `parquet-thrift` as a test resource.

This is just a quick fix to bring back Maven builds. Need to figure out the root cause, as binary Parquet files are harder to maintain.

Author: Cheng Lian

Closes #7330 from liancheng/spark-8959 and squashes the following commits:

cf69512 [Cheng Lian] Brings back Maven builds

---
 pom.xml                                            |   14 -
 sql/core/pom.xml                                   |   10 -
 .../spark/sql/parquet/test/thrift/Nested.java      |  541 ----
 .../test/thrift/ParquetThriftCompat.java           | 2808 -----------------
 .../spark/sql/parquet/test/thrift/Suit.java        |   51 -
 .../parquet-thrift-compat.snappy.parquet           |  Bin 0 -> 10550 bytes
 .../ParquetThriftCompatibilitySuite.scala          |   78 +-
 7 files changed, 8 insertions(+), 3494 deletions(-)
 delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java
 delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java
 delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java
 create mode 100755 sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet

diff --git a/pom.xml b/pom.xml
index 529e47f8b5253..1eda108dc065b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -161,7 +161,6 @@
     2.4.4
     1.1.1.7
     1.1.2
-    0.9.2
     false
@@ -181,7 +180,6 @@
     compile
     compile
     test
-    test