From 488bad319a70975733e83c83490240a70beb0c90 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 2 Jul 2015 15:55:16 -0700 Subject: [PATCH] [SPARK-7104] [MLLIB] Support model save/load in Python's Word2Vec Author: Yu ISHIKAWA Closes #6821 from yu-iskw/SPARK-7104 and squashes the following commits: 975136b [Yu ISHIKAWA] Organize import 0ef58b6 [Yu ISHIKAWA] Use rmtree, instead of removedirs cb21653 [Yu ISHIKAWA] Add an explicit type for `Word2VecModelWrapper.save` 1d468ef [Yu ISHIKAWA] [SPARK-7104][MLlib] Support model save/load in Python's Word2Vec --- .../mllib/api/python/PythonMLLibAPI.scala | 3 +++ python/pyspark/mllib/feature.py | 21 ++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 458fab48fef5a..e628059c4af8e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -28,6 +28,7 @@ import scala.reflect.ClassTag import net.razorvine.pickle._ +import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ @@ -641,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable { def getVectors: JMap[String, JList[Float]] = { model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) } /** diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index b5138773fd61b..f921e3ad1a314 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -36,6 +36,7 @@ from pyspark.mllib.linalg import ( Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', @@ -416,7 +417,7 @@ def fit(self, dataset): return IDFModel(jmodel) -class Word2VecModel(JavaVectorTransformer): +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model """ @@ -455,6 +456,12 @@ def getVectors(self): """ return self.call("getVectors") + @classmethod + def load(cls, sc, path): + jmodel = sc._jvm.org.apache.spark.mllib.feature \ + .Word2VecModel.load(sc._jsc.sc(), path) + return Word2VecModel(jmodel) + @ignore_unicode_prefix class Word2Vec(object): @@ -488,6 +495,18 @@ class Word2Vec(object): >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] [u'b', u'c'] + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = Word2VecModel.load(sc, path) + >>> model.transform("a") == sameModel.transform("a") + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def __init__(self): """