Skip to content

Commit

Permalink
[SPARK-7104] [MLLIB] Support model save/load in Python's Word2Vec
Browse files Browse the repository at this point in the history
Author: Yu ISHIKAWA <[email protected]>

Closes apache#6821 from yu-iskw/SPARK-7104 and squashes the following commits:

975136b [Yu ISHIKAWA] Organize import
0ef58b6 [Yu ISHIKAWA] Use rmtree, instead of removedirs
cb21653 [Yu ISHIKAWA] Add an explicit type for `Word2VecModelWrapper.save`
1d468ef [Yu ISHIKAWA] [SPARK-7104][MLlib] Support model save/load in Python's Word2Vec
  • Loading branch information
yu-iskw authored and jkbradley committed Jul 2, 2015
1 parent fc7aebd commit 488bad3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import scala.reflect.ClassTag

import net.razorvine.pickle._

import org.apache.spark.SparkContext
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
Expand Down Expand Up @@ -641,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable {
/**
 * Returns the wrapped model's word vectors as Java collections.
 *
 * Each entry of `model.getVectors` keeps its word as the key, while its
 * vector is turned into a Scala list and then exposed as a `java.util.List`;
 * the resulting map is finally exposed as a `java.util.Map`.
 */
def getVectors: JMap[String, JList[Float]] = {
  val javaFriendly = model.getVectors.map { case (word, vector) =>
    word -> vector.toList.asJava
  }
  javaFriendly.asJava
}

/**
 * Saves the wrapped model by delegating to `model.save`.
 *
 * @param sc   Spark context forwarded unchanged to the underlying model.
 * @param path destination path forwarded unchanged to the underlying model.
 */
def save(sc: SparkContext, path: String): Unit = {
  model.save(sc, path)
}
}

/**
Expand Down
21 changes: 20 additions & 1 deletion python/pyspark/mllib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pyspark.mllib.linalg import (
Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import JavaLoader, JavaSaveable

__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
Expand Down Expand Up @@ -416,7 +417,7 @@ def fit(self, dataset):
return IDFModel(jmodel)


class Word2VecModel(JavaVectorTransformer):
class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
"""
class for Word2Vec model
"""
Expand Down Expand Up @@ -455,6 +456,12 @@ def getVectors(self):
"""
return self.call("getVectors")

@classmethod
def load(cls, sc, path):
    """
    Load a previously saved Word2Vec model from ``path``.

    :param sc: context object; its ``_jvm`` gateway and ``_jsc.sc()``
        handle are used to invoke the JVM-side loader.
    :param path: location passed unchanged to the JVM
        ``Word2VecModel.load``.
    :return: a :class:`Word2VecModel` built around the loaded JVM object.

    NOTE(review): the object handed to ``Word2VecModel`` is whatever the
    JVM ``org.apache.spark.mllib.feature.Word2VecModel.load`` returns;
    confirm that every Python-side method works on it, not just
    ``transform``.
    """
    jvm_word2vec = sc._jvm.org.apache.spark.mllib.feature.Word2VecModel
    jmodel = jvm_word2vec.load(sc._jsc.sc(), path)
    return Word2VecModel(jmodel)


@ignore_unicode_prefix
class Word2Vec(object):
Expand Down Expand Up @@ -488,6 +495,18 @@ class Word2Vec(object):
>>> syms = model.findSynonyms(vec, 2)
>>> [s[0] for s in syms]
[u'b', u'c']
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = Word2VecModel.load(sc, path)
>>> model.transform("a") == sameModel.transform("a")
True
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
"""
def __init__(self):
"""
Expand Down

0 comments on commit 488bad3

Please sign in to comment.