-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10264][Documentation, ML] Added Since annotation for ml.recommendation #8532
Closed
Closed
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
3820f41
Modified for adding @Since annotation
tijoparacka 57b366d
updated to add Since Annotation
4f32939
Modified for checkstyle fix
f3cc61a
Added Auxiliary constructor
3988055
Added annotation for case class
adf64e1
Fixed compilation error due to earlier fix
624d85d
removed unwanted space
7b65e62
updated review comments
tijoparacka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} | |
import org.netlib.util.intW | ||
|
||
import org.apache.spark.{Logging, Partitioner} | ||
import org.apache.spark.annotation.{DeveloperApi, Experimental} | ||
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} | ||
import org.apache.spark.ml.{Estimator, Model} | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared._ | ||
|
@@ -178,22 +178,27 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w | |
* @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` | ||
*/ | ||
@Experimental | ||
@Since("1.3.0") | ||
class ALSModel private[ml] ( | ||
override val uid: String, | ||
val rank: Int, | ||
@Since("1.4.0") override val uid: String, | ||
@Since("1.4.0") val rank: Int, | ||
@transient val userFactors: DataFrame, | ||
@transient val itemFactors: DataFrame) | ||
extends Model[ALSModel] with ALSModelParams { | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setUserCol(value: String): this.type = set(userCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setItemCol(value: String): this.type = set(itemCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setPredictionCol(value: String): this.type = set(predictionCol, value) | ||
|
||
@Since("1.3.0") | ||
override def transform(dataset: DataFrame): DataFrame = { | ||
// Register a UDF for DataFrame, and then | ||
// create a new column named map(predictionCol) by running the predict UDF. | ||
|
@@ -211,12 +216,14 @@ class ALSModel private[ml] ( | |
predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) | ||
} | ||
|
||
@Since("1.3.0") | ||
override def transformSchema(schema: StructType): StructType = { | ||
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) | ||
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) | ||
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) | ||
} | ||
|
||
@Since("1.5.0") | ||
override def copy(extra: ParamMap): ALSModel = { | ||
val copied = new ALSModel(uid, rank, userFactors, itemFactors) | ||
copyValues(copied, extra).setParent(parent) | ||
|
@@ -255,64 +262,82 @@ class ALSModel private[ml] ( | |
* preferences rather than explicit ratings given to items. | ||
*/ | ||
@Experimental | ||
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { | ||
@Since("1.3.0") | ||
class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams { | ||
|
||
import org.apache.spark.ml.recommendation.ALS.Rating | ||
|
||
@Since("1.4.0") | ||
def this() = this(Identifiable.randomUID("als")) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setRank(value: Int): this.type = set(rank, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setAlpha(value: Double): this.type = set(alpha, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setUserCol(value: String): this.type = set(userCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setItemCol(value: String): this.type = set(itemCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setRatingCol(value: String): this.type = set(ratingCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setPredictionCol(value: String): this.type = set(predictionCol, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setMaxIter(value: Int): this.type = set(maxIter, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setRegParam(value: Double): this.type = set(regParam, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.3.0") | ||
def setNonnegative(value: Boolean): this.type = set(nonnegative, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) | ||
|
||
/** @group setParam */ | ||
@Since("1.4.0") | ||
def setSeed(value: Long): this.type = set(seed, value) | ||
|
||
/** | ||
* Sets both numUserBlocks and numItemBlocks to the specific value. | ||
* @group setParam | ||
*/ | ||
@Since("1.3.0") | ||
def setNumBlocks(value: Int): this.type = { | ||
setNumUserBlocks(value) | ||
setNumItemBlocks(value) | ||
this | ||
} | ||
|
||
@Since("1.3.0") | ||
override def fit(dataset: DataFrame): ALSModel = { | ||
import dataset.sqlContext.implicits._ | ||
val ratings = dataset | ||
|
@@ -332,10 +357,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { | |
copyValues(model) | ||
} | ||
|
||
@Since("1.3.0") | ||
override def transformSchema(schema: StructType): StructType = { | ||
validateAndTransformSchema(schema) | ||
} | ||
|
||
@Since("1.5.0") | ||
override def copy(extra: ParamMap): ALS = defaultCopy(extra) | ||
} | ||
|
||
|
@@ -348,14 +375,19 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { | |
* than 2 billion. | ||
*/ | ||
@DeveloperApi | ||
@Since("1.3.0") | ||
object ALS extends Logging { | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Rating class for better code readability. | ||
*/ | ||
@DeveloperApi | ||
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) | ||
@Since("1.3.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need |
||
case class Rating[@specialized(Int, Long) ID] @Since("1.3.0")( | ||
@Since("1.3.0") user: ID, | ||
@Since("1.3.0") item: ID, | ||
@Since("1.3.0") rating: Float) | ||
|
||
/** Trait for least squares solvers applied to the normal equation. */ | ||
private[recommendation] trait LeastSquaresNESolver extends Serializable { | ||
|
@@ -426,6 +458,7 @@ object ALS extends Logging { | |
* min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^ | ||
* subject to x >= 0 | ||
*/ | ||
@Since("1.3.0") | ||
override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { | ||
val rank = ne.k | ||
initialize(rank) | ||
|
@@ -519,6 +552,7 @@ object ALS extends Logging { | |
* Implementation of the ALS algorithm. | ||
*/ | ||
@DeveloperApi | ||
@Since("1.3.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
def train[ID: ClassTag]( // scalastyle:ignore | ||
ratings: RDD[Rating[ID]], | ||
rank: Int = 10, | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please annotate the auxiliary constructor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added Auxiliary constructor