Skip to content

Commit

Permalink
update example code
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Oct 28, 2014
1 parent 4ce0506 commit fdeed9a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.examples.mllib

import java.io.File

import com.google.common.io.Files
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
Expand Down Expand Up @@ -96,15 +99,21 @@ object DatasetExample {
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

schemaRDD.saveAsParquetFile("/tmp/dataset")
val newDataset = sqlContext.parquetFile("/tmp/dataset")
val tmpDir = Files.createTempDir()
tmpDir.deleteOnExit()
val outputDir = new File(tmpDir, "dataset").toString
println(s"Saving to $outputDir as Parquet file.")
schemaRDD.saveAsParquetFile(outputDir)

println(s"Loading Parquet file with UDT from $outputDir.")
val newDataset = sqlContext.parquetFile(outputDir)

println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")

sc.stop()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,4 +348,4 @@ private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {
val values = row.getAs[Seq[Double]](2).toArray
new SparseVector(vSize, indices, values)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
}

override def deserialize(row: Row): MyDenseVector = {
val features = row.getAs[Seq[Double]](0).toArray
new MyDenseVector(features)
val data = row.getAs[Seq[Double]](0).toArray
new MyDenseVector(data)
}
}

Expand Down

0 comments on commit fdeed9a

Please sign in to comment.