-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10264][Documentation, ML] Added Since annotation for ml.recommendation #8532
Changes from 7 commits
3820f41
57b366d
4f32939
f3cc61a
3988055
adf64e1
624d85d
7b65e62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} | |
import org.netlib.util.intW | ||
|
||
import org.apache.spark.{Logging, Partitioner} | ||
import org.apache.spark.annotation.{DeveloperApi, Experimental} | ||
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} | ||
import org.apache.spark.ml.{Estimator, Model} | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared._ | ||
|
@@ -178,6 +178,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w | |
* @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` | ||
*/ | ||
@Experimental | ||
@Since("1.3.0") | ||
class ALSModel private[ml] ( | ||
override val uid: String, | ||
val rank: Int, | ||
|
@@ -186,14 +187,18 @@ class ALSModel private[ml] ( | |
extends Model[ALSModel] with ALSModelParams { | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setUserCol(value: String): this.type = set(userCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setItemCol(value: String): this.type = set(itemCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setPredictionCol(value: String): this.type = set(predictionCol, value) | ||
|
||
@Since("1.3.0") | ||
override def transform(dataset: DataFrame): DataFrame = { | ||
// Register a UDF for DataFrame, and then | ||
// create a new column named map(predictionCol) by running the predict UDF. | ||
|
@@ -211,12 +216,14 @@ class ALSModel private[ml] ( | |
predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) | ||
} | ||
|
||
@Since("1.3.0") | ||
override def transformSchema(schema: StructType): StructType = { | ||
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) | ||
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) | ||
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) | ||
} | ||
|
||
@Since("1.5.0") | ||
override def copy(extra: ParamMap): ALSModel = { | ||
val copied = new ALSModel(uid, rank, userFactors, itemFactors) | ||
copyValues(copied, extra).setParent(parent) | ||
|
@@ -255,64 +262,82 @@ class ALSModel private[ml] ( | |
* preferences rather than explicit ratings given to items. | ||
*/ | ||
@Experimental | ||
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { | ||
@Since("1.3.0") | ||
class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams { | ||
|
||
import org.apache.spark.ml.recommendation.ALS.Rating | ||
|
||
@Since("1.4.0") | ||
def this() = this(Identifiable.randomUID("als")) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please annotate the auxiliary constructor. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added the annotation to the auxiliary constructor. |
||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setRank(value: Int): this.type = set(rank, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setAlpha(value: Double): this.type = set(alpha, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setUserCol(value: String): this.type = set(userCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setItemCol(value: String): this.type = set(itemCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setRatingCol(value: String): this.type = set(ratingCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setPredictionCol(value: String): this.type = set(predictionCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setMaxIter(value: Int): this.type = set(maxIter, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setRegParam(value: Double): this.type = set(regParam, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setNonnegative(value: Boolean): this.type = set(nonnegative, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setSeed(value: Long): this.type = set(seed, value) | ||
|
||
/** | ||
* Sets both numUserBlocks and numItemBlocks to the specific value. | ||
* @group setParam | ||
*/ | ||
@Since("1.3.0") | ||
def setNumBlocks(value: Int): this.type = { | ||
setNumUserBlocks(value) | ||
setNumItemBlocks(value) | ||
this | ||
} | ||
|
||
@Since("1.3.0") | ||
override def fit(dataset: DataFrame): ALSModel = { | ||
import dataset.sqlContext.implicits._ | ||
val ratings = dataset | ||
|
@@ -332,10 +357,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { | |
copyValues(model) | ||
} | ||
|
||
@Since("1.3.0") | ||
override def transformSchema(schema: StructType): StructType = { | ||
validateAndTransformSchema(schema) | ||
} | ||
|
||
@Since("1.5.0") | ||
override def copy(extra: ParamMap): ALS = defaultCopy(extra) | ||
} | ||
|
||
|
@@ -348,14 +375,19 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { | |
* than 2 billion. | ||
*/ | ||
@DeveloperApi | ||
@Since("1.3.0") | ||
object ALS extends Logging { | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Rating class for better code readability. | ||
*/ | ||
@DeveloperApi | ||
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) | ||
@Since("1.3.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need |
||
case class Rating[@specialized(Int, Long) ID] @Since("1.3.0")( | ||
@Since("1.3.0") user: ID, | ||
@Since("1.3.0") item: ID, | ||
@Since("1.3.0") rating: Float) | ||
|
||
/** Trait for least squares solvers applied to the normal equation. */ | ||
private[recommendation] trait LeastSquaresNESolver extends Serializable { | ||
|
@@ -426,6 +458,7 @@ object ALS extends Logging { | |
* min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^ | ||
* subject to x >= 0 | ||
*/ | ||
@Since("1.3.0") | ||
override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { | ||
val rank = ne.k | ||
initialize(rank) | ||
|
@@ -519,6 +552,7 @@ object ALS extends Logging { | |
* Implementation of the ALS algorithm. | ||
*/ | ||
@DeveloperApi | ||
@Since("1.3.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
def train[ID: ClassTag]( // scalastyle:ignore | ||
ratings: RDD[Rating[ID]], | ||
rank: Int = 10, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document public params in the constructor. Though the constructor is private, those variables are public.