From 1433b1139e510457419f2ffd29195bf29f6e556e Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Wed, 29 Oct 2014 10:06:07 +0000 Subject: [PATCH] complete suite tests --- .../export/ModelExportFactorySuite.scala | 9 +++++-- .../pmml/KMeansPMMLModelExportSuite.scala | 25 +++++++++++++------ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala index 9b6b4160d6120..bdc0239e94993 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala @@ -26,25 +26,30 @@ class ModelExportFactorySuite extends FunSuite{ test("ModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { + //arrange val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) ) - val kmeansModel = new KMeansModel(clusterCenters); + //act val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML) + //assert assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) } - test("ModelExportFactory throws IllegalArgumentException when passing an unsupported model") { + test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") { + //arrange val invalidModel = new Object; + //assert intercept[IllegalArgumentException] { + //act ModelExportFactory.createModelExport(invalidModel, ModelExportType.PMML) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala index 02339b0e20e28..4bfd60906a670 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala @@ -22,29 +22,38 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.export.ModelExportFactory import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.export.ModelExportType +import org.dmg.pmml.ClusteringModel +import javax.xml.parsers.DocumentBuilderFactory +import java.io.ByteArrayOutputStream class KMeansPMMLModelExportSuite extends FunSuite{ test("KMeansPMMLModelExport generate PMML format") { + //arrange model to test val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) ) - val kmeansModel = new KMeansModel(clusterCenters); + //act by exporting the model to the PMML format val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML) - + + //assert that the PMML format is as expected assert(modelExport.isInstanceOf[PMMLModelExport]) + var pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml() + assert(pmml.getHeader().getDescription() === "k-means clustering") + //check that the number of fields match the single vector size + assert(pmml.getDataDictionary().getNumberOfFields() === clusterCenters(0).size) + //this verify that there is a model attached to the pmml object and the model is a clustering one + //it also verifies that the pmml model has the same number of clusters of the spark model + assert(pmml.getModels().get(0).asInstanceOf[ClusteringModel].getNumberOfClusters() === clusterCenters.size) - //TODO: asserts - //compare pmml fields to strings - modelExport.asInstanceOf[PMMLModelExport].getPmml() - //use document builder to load the xml generated and validated the notes by looking for them - modelExport.asInstanceOf[PMMLModelExport].save(System.out) - //saveLocalFile too??? search how to unit test file creating in java + //manual checking + //modelExport.asInstanceOf[PMMLModelExport].save(System.out) + //modelExport.asInstanceOf[PMMLModelExport].saveLocalFile("/tmp/kmeans.xml") }