Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10264][Documentation] Added @Since to ml.recomendation #10756

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,27 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
*/
@Experimental
@Since("1.3.0")
class ALSModel private[ml] (
override val uid: String,
val rank: Int,
@Since("1.4.0") override val uid: String,
@Since("1.4.0") val rank: Int,
@transient val userFactors: DataFrame,
@transient val itemFactors: DataFrame)
extends Model[ALSModel] with ALSModelParams with MLWritable {

/** @group setParam */
@Since("1.4.0")
def setUserCol(value: String): this.type = set(userCol, value)

/** @group setParam */
@Since("1.4.0")
def setItemCol(value: String): this.type = set(itemCol, value)

/** @group setParam */
@Since("1.3.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

@Since("1.3.0")
override def transform(dataset: DataFrame): DataFrame = {
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
Expand All @@ -213,13 +218,15 @@ class ALSModel private[ml] (
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}

@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}

@Since("1.5.0")
override def copy(extra: ParamMap): ALSModel = {
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
copyValues(copied, extra).setParent(parent)
Expand Down Expand Up @@ -303,65 +310,83 @@ object ALSModel extends MLReadable[ALSModel] {
* preferences rather than explicit ratings given to items.
*/
@Experimental
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams
@Since("1.3.0")
class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams
with DefaultParamsWritable {

import org.apache.spark.ml.recommendation.ALS.Rating

@Since("1.4.0")
def this() = this(Identifiable.randomUID("als"))

/** @group setParam */
@Since("1.3.0")
def setRank(value: Int): this.type = set(rank, value)

/** @group setParam */
@Since("1.3.0")
def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)

/** @group setParam */
@Since("1.3.0")
def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)

/** @group setParam */
@Since("1.3.0")
def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)

/** @group setParam */
@Since("1.3.0")
def setAlpha(value: Double): this.type = set(alpha, value)

/** @group setParam */
@Since("1.3.0")
def setUserCol(value: String): this.type = set(userCol, value)

/** @group setParam */
@Since("1.3.0")
def setItemCol(value: String): this.type = set(itemCol, value)

/** @group setParam */
@Since("1.3.0")
def setRatingCol(value: String): this.type = set(ratingCol, value)

/** @group setParam */
@Since("1.3.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/** @group setParam */
@Since("1.3.0")
def setMaxIter(value: Int): this.type = set(maxIter, value)

/** @group setParam */
@Since("1.3.0")
def setRegParam(value: Double): this.type = set(regParam, value)

/** @group setParam */
@Since("1.3.0")
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)

/** @group setParam */
@Since("1.4.0")
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/** @group setParam */
@Since("1.3.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
*/
@Since("1.3.0")
def setNumBlocks(value: Int): this.type = {
setNumUserBlocks(value)
setNumItemBlocks(value)
this
}

@Since("1.3.0")
override def fit(dataset: DataFrame): ALSModel = {
import dataset.sqlContext.implicits._
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
Expand All @@ -381,10 +406,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams
copyValues(model)
}

@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

@Since("1.5.0")
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
}

Expand Down