diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index bf6e8ec8f37b8..5049ef924561c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -131,7 +131,7 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo * Common params for ALS. */ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam - with HasPredictionCol with HasCheckpointInterval with HasSeed { + with HasCheckpointInterval with HasSeed { /** * Param for rank of the matrix factorization (positive). diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 676662da8c316..df9c765457ec1 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -28,8 +28,143 @@ @inherit_doc -class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed, - JavaMLWritable, JavaMLReadable): +class _ALSModelParams(HasPredictionCol): + """ + Params for :py:class:`ALS` and :py:class:`ALSModel`. + + .. versionadded:: 3.0.0 + """ + + userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " + + "the integer value range.", typeConverter=TypeConverters.toString) + itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " + + "the integer value range.", typeConverter=TypeConverters.toString) + coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " + + "unknown or new users/items at prediction time. This may be useful " + + "in cross-validation or production scenarios, for handling " + + "user/item ids the model has not seen in the training data. " + + "Supported values: 'nan', 'drop'.", + typeConverter=TypeConverters.toString) + + @since("1.4.0") + def getUserCol(self): + """ + Gets the value of userCol or its default value. + """ + return self.getOrDefault(self.userCol) + + @since("1.4.0") + def getItemCol(self): + """ + Gets the value of itemCol or its default value. + """ + return self.getOrDefault(self.itemCol) + + @since("2.2.0") + def getColdStartStrategy(self): + """ + Gets the value of coldStartStrategy or its default value. + """ + return self.getOrDefault(self.coldStartStrategy) + + +@inherit_doc +class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval, HasSeed): + """ + Params for :py:class:`ALS`. + + .. versionadded:: 3.0.0 + """ + + rank = Param(Params._dummy(), "rank", "rank of the factorization", + typeConverter=TypeConverters.toInt) + numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks", + typeConverter=TypeConverters.toInt) + numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks", + typeConverter=TypeConverters.toInt) + implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference", + typeConverter=TypeConverters.toBoolean) + alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference", + typeConverter=TypeConverters.toFloat) + + ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings", + typeConverter=TypeConverters.toString) + nonnegative = Param(Params._dummy(), "nonnegative", + "whether to use nonnegative constraint for least squares", + typeConverter=TypeConverters.toBoolean) + intermediateStorageLevel = Param(Params._dummy(), "intermediateStorageLevel", + "StorageLevel for intermediate datasets. Cannot be 'NONE'.", + typeConverter=TypeConverters.toString) + finalStorageLevel = Param(Params._dummy(), "finalStorageLevel", + "StorageLevel for ALS model factors.", + typeConverter=TypeConverters.toString) + + @since("1.4.0") + def getRank(self): + """ + Gets the value of rank or its default value. + """ + return self.getOrDefault(self.rank) + + @since("1.4.0") + def getNumUserBlocks(self): + """ + Gets the value of numUserBlocks or its default value. + """ + return self.getOrDefault(self.numUserBlocks) + + @since("1.4.0") + def getNumItemBlocks(self): + """ + Gets the value of numItemBlocks or its default value. + """ + return self.getOrDefault(self.numItemBlocks) + + @since("1.4.0") + def getImplicitPrefs(self): + """ + Gets the value of implicitPrefs or its default value. + """ + return self.getOrDefault(self.implicitPrefs) + + @since("1.4.0") + def getAlpha(self): + """ + Gets the value of alpha or its default value. + """ + return self.getOrDefault(self.alpha) + + @since("1.4.0") + def getRatingCol(self): + """ + Gets the value of ratingCol or its default value. + """ + return self.getOrDefault(self.ratingCol) + + @since("1.4.0") + def getNonnegative(self): + """ + Gets the value of nonnegative or its default value. + """ + return self.getOrDefault(self.nonnegative) + + @since("2.0.0") + def getIntermediateStorageLevel(self): + """ + Gets the value of intermediateStorageLevel or its default value. + """ + return self.getOrDefault(self.intermediateStorageLevel) + + @since("2.0.0") + def getFinalStorageLevel(self): + """ + Gets the value of finalStorageLevel or its default value. + """ + return self.getOrDefault(self.finalStorageLevel) + + +@inherit_doc +class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable): """ Alternating Least Squares (ALS) matrix factorization. @@ -79,6 +214,12 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha ... ["user", "item", "rating"]) >>> als = ALS(rank=10, maxIter=5, seed=0) >>> model = als.fit(df) + >>> model.getUserCol() + 'user' + >>> model.getItemCol() + 'item' + >>> model.setPredictionCol("newPrediction") + ALS... >>> model.rank 10 >>> model.userFactors.orderBy("id").collect() @@ -86,11 +227,11 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, prediction=0.6929101347923279) + Row(user=0, item=2, newPrediction=0.6929101347923279) >>> predictions[1] - Row(user=1, item=0, prediction=3.47356915473938) + Row(user=1, item=0, newPrediction=3.47356915473938) >>> predictions[2] - Row(user=2, item=0, prediction=-0.8991986513137817) + Row(user=2, item=0, newPrediction=-0.8991986513137817) >>> user_recs = model.recommendForAllUsers(3) >>> user_recs.where(user_recs.user == 0)\ .select("recommendations.item", "recommendations.rating").collect() @@ -125,38 +266,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha .. versionadded:: 1.4.0 """ - rank = Param(Params._dummy(), "rank", "rank of the factorization", - typeConverter=TypeConverters.toInt) - numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks", - typeConverter=TypeConverters.toInt) - numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks", - typeConverter=TypeConverters.toInt) - implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference", - typeConverter=TypeConverters.toBoolean) - alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference", - typeConverter=TypeConverters.toFloat) - userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " + - "the integer value range.", typeConverter=TypeConverters.toString) - itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " + - "the integer value range.", typeConverter=TypeConverters.toString) - ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings", - typeConverter=TypeConverters.toString) - nonnegative = Param(Params._dummy(), "nonnegative", - "whether to use nonnegative constraint for least squares", - typeConverter=TypeConverters.toBoolean) - intermediateStorageLevel = Param(Params._dummy(), "intermediateStorageLevel", - "StorageLevel for intermediate datasets. Cannot be 'NONE'.", - typeConverter=TypeConverters.toString) - finalStorageLevel = Param(Params._dummy(), "finalStorageLevel", - "StorageLevel for ALS model factors.", - typeConverter=TypeConverters.toString) - coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " + - "unknown or new users/items at prediction time. This may be useful " + - "in cross-validation or production scenarios, for handling " + - "user/item ids the model has not seen in the training data. " + - "Supported values: 'nan', 'drop'.", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, @@ -208,13 +317,6 @@ def setRank(self, value): """ return self._set(rank=value) - @since("1.4.0") - def getRank(self): - """ - Gets the value of rank or its default value. - """ - return self.getOrDefault(self.rank) - @since("1.4.0") def setNumUserBlocks(self, value): """ @@ -222,13 +324,6 @@ def setNumUserBlocks(self, value): """ return self._set(numUserBlocks=value) - @since("1.4.0") - def getNumUserBlocks(self): - """ - Gets the value of numUserBlocks or its default value. - """ - return self.getOrDefault(self.numUserBlocks) - @since("1.4.0") def setNumItemBlocks(self, value): """ @@ -236,13 +331,6 @@ def setNumItemBlocks(self, value): """ return self._set(numItemBlocks=value) - @since("1.4.0") - def getNumItemBlocks(self): - """ - Gets the value of numItemBlocks or its default value. - """ - return self.getOrDefault(self.numItemBlocks) - @since("1.4.0") def setNumBlocks(self, value): """ @@ -258,13 +346,6 @@ def setImplicitPrefs(self, value): """ return self._set(implicitPrefs=value) - @since("1.4.0") - def getImplicitPrefs(self): - """ - Gets the value of implicitPrefs or its default value. - """ - return self.getOrDefault(self.implicitPrefs) - @since("1.4.0") def setAlpha(self, value): """ @@ -272,13 +353,6 @@ def setAlpha(self, value): """ return self._set(alpha=value) - @since("1.4.0") - def getAlpha(self): - """ - Gets the value of alpha or its default value. - """ - return self.getOrDefault(self.alpha) - @since("1.4.0") def setUserCol(self, value): """ @@ -286,13 +360,6 @@ def setUserCol(self, value): """ return self._set(userCol=value) - @since("1.4.0") - def getUserCol(self): - """ - Gets the value of userCol or its default value. - """ - return self.getOrDefault(self.userCol) - @since("1.4.0") def setItemCol(self, value): """ @@ -300,13 +367,6 @@ def setItemCol(self, value): """ return self._set(itemCol=value) - @since("1.4.0") - def getItemCol(self): - """ - Gets the value of itemCol or its default value. - """ - return self.getOrDefault(self.itemCol) - @since("1.4.0") def setRatingCol(self, value): """ @@ -314,13 +374,6 @@ def setRatingCol(self, value): """ return self._set(ratingCol=value) - @since("1.4.0") - def getRatingCol(self): - """ - Gets the value of ratingCol or its default value. - """ - return self.getOrDefault(self.ratingCol) - @since("1.4.0") def setNonnegative(self, value): """ @@ -328,13 +381,6 @@ def setNonnegative(self, value): """ return self._set(nonnegative=value) - @since("1.4.0") - def getNonnegative(self): - """ - Gets the value of nonnegative or its default value. - """ - return self.getOrDefault(self.nonnegative) - @since("2.0.0") def setIntermediateStorageLevel(self, value): """ @@ -342,13 +388,6 @@ def setIntermediateStorageLevel(self, value): """ return self._set(intermediateStorageLevel=value) - @since("2.0.0") - def getIntermediateStorageLevel(self): - """ - Gets the value of intermediateStorageLevel or its default value. - """ - return self.getOrDefault(self.intermediateStorageLevel) - @since("2.0.0") def setFinalStorageLevel(self, value): """ @@ -356,13 +395,6 @@ def setFinalStorageLevel(self, value): """ return self._set(finalStorageLevel=value) - @since("2.0.0") - def getFinalStorageLevel(self): - """ - Gets the value of finalStorageLevel or its default value. - """ - return self.getOrDefault(self.finalStorageLevel) - @since("2.2.0") def setColdStartStrategy(self, value): """ @@ -370,21 +402,35 @@ def setColdStartStrategy(self, value): """ return self._set(coldStartStrategy=value) - @since("2.2.0") - def getColdStartStrategy(self): - """ - Gets the value of coldStartStrategy or its default value. - """ - return self.getOrDefault(self.coldStartStrategy) - -class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): +class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable): """ Model fitted by ALS. .. versionadded:: 1.4.0 """ + @since("3.0.0") + def setUserCol(self, value): + """ + Sets the value of :py:attr:`userCol`. + """ + return self._set(userCol=value) + + @since("3.0.0") + def setItemCol(self, value): + """ + Sets the value of :py:attr:`itemCol`. + """ + return self._set(itemCol=value) + + @since("3.0.0") + def setColdStartStrategy(self, value): + """ + Sets the value of :py:attr:`coldStartStrategy`. + """ + return self._set(coldStartStrategy=value) + @property @since("1.4.0") def rank(self):