diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
index 344da2c90c94b..f8d83f4ec7327 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.examples.mllib
 
+import java.io.File
+
+import com.google.common.io.Files
 import scopt.OptionParser
 
 import org.apache.spark.{SparkConf, SparkContext}
@@ -96,15 +99,21 @@ object DatasetExample {
       (sum1, sum2) => sum1.merge(sum2))
     println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
 
-    schemaRDD.saveAsParquetFile("/tmp/dataset")
-    val newDataset = sqlContext.parquetFile("/tmp/dataset")
+    val tmpDir = Files.createTempDir()
+    tmpDir.deleteOnExit()
+    val outputDir = new File(tmpDir, "dataset").toString
+    println(s"Saving to $outputDir as Parquet file.")
+    schemaRDD.saveAsParquetFile(outputDir)
+
+    println(s"Loading Parquet file with UDT from $outputDir.")
+    val newDataset = sqlContext.parquetFile(outputDir)
 
     println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
     val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
     val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
       (summary, feat) => summary.add(feat),
       (sum1, sum2) => sum1.merge(sum2))
-    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+    println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
 
     sc.stop()
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index c593156f30233..17d7684b1ddf5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -348,4 +348,4 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {
     val values = row.getAs[Seq[Double]](2).toArray
     new SparseVector(vSize, indices, values)
   }
-
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index a470c2765f19e..3208c910a5bc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -47,8 +47,8 @@ class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
   }
 
   override def deserialize(row: Row): MyDenseVector = {
-    val features = row.getAs[Seq[Double]](0).toArray
-    new MyDenseVector(features)
+    val data = row.getAs[Seq[Double]](0).toArray
+    new MyDenseVector(data)
   }
 }
 
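
Note on the DatasetExample change above: replacing the hardcoded /tmp/dataset path with a per-run temp directory avoids collisions between concurrent runs of the example. A minimal standalone sketch of the same pattern, assuming only Guava on the classpath (the TempDirSketch object and the printed message are illustrative, not part of the patch):

import java.io.File

import com.google.common.io.Files

object TempDirSketch {
  def main(args: Array[String]): Unit = {
    // Guava creates a uniquely named directory under java.io.tmpdir,
    // so concurrent runs cannot clobber each other's output.
    val tmpDir: File = Files.createTempDir()
    // Best-effort cleanup: deleteOnExit() removes the directory at JVM
    // shutdown only if it is empty by then, so files written inside it
    // (e.g. Parquet output) may outlive the run.
    tmpDir.deleteOnExit()

    val outputDir = new File(tmpDir, "dataset").toString
    println(s"Output would be written under: $outputDir")
  }
}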