[pyspark] support pred_contribs #8633
Conversation
@@ -331,3 +334,25 @@ def split_params() -> Tuple[Dict[str, Any], Dict[str, Union[int, float, bool]]]:
    assert dvalid.num_col() == dtrain.num_col()

    return dtrain, dvalid


def pred_contribs(
@trivialfis Do you think it's better to move the pred_contribs function to XGBModel?
Maybe we can merge it into the XGBModel/XGBClassifier.predict method for consistency?
wow, good suggestion.
@trivialfis can you file a follow-up PR to merge it into the XGBModel/XGBClassifier.predict method?
sure.
python-package/xgboost/spark/core.py
if pred_contrib_col_name:
    contribs = pred_contribs(model, X, base_margin)
    assert len(contribs.shape) == 2
This is not necessarily true. See doc/prediction.rst.
Done.
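For context, a minimal sketch of what such a helper can do with the Booster API, assuming the sklearn-style wrapper exposes get_booster(); this is not the exact implementation merged in this PR:

import xgboost as xgb


def pred_contribs(model, X, base_margin=None):
    """Sketch: compute per-feature contribution predictions (SHAP values)."""
    data = xgb.DMatrix(X, base_margin=base_margin)
    booster = model.get_booster()
    # strict_shape=True always yields (n_samples, n_groups, n_features + 1),
    # so binary and multi-class models can be handled uniformly.
    contribs = booster.predict(
        data, pred_contribs=True, validate_features=False, strict_shape=True
    )
    # Flatten the trailing axes so each row becomes one contribution vector.
    return contribs.reshape(contribs.shape[0], -1)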
python-package/xgboost/spark/core.py
preds = model.predict(
    X,
    base_margin=base_margin,
    validate_features=False,
    **predict_params,
)
yield pd.Series(preds)
data["prediction"] = pd.Series(preds)
Can it handle multiple prediction types? For instance, normal prediction + contribs at the same time?
Do you mean handling "normal prediction + contribs" in the same "predict" of XGBModel, or in a single pandas UDF?
Right now, this PR handles the latter: it can predict "normal prediction + contribs" in a single pandas UDF. If "predict" of XGBModel can support multiple prediction types, I can change it accordingly.
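For illustration, a sketch of handling several prediction types in one pandas UDF by returning a struct column (one pandas DataFrame per batch). The toy model, schema, and column names here are assumptions, not the exact code in this PR:

from typing import Iterator

import numpy as np
import pandas as pd
import xgboost as xgb
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, struct

spark = SparkSession.builder.getOrCreate()

# Toy model trained on the driver; the real code distributes the fitted booster.
X = np.random.rand(32, 3)
y = (X[:, 0] > 0.5).astype(int)
booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), 4)

df = spark.createDataFrame([(list(map(float, row)),) for row in X], ["features"])

@pandas_udf("prediction double, pred_contrib array<double>")
def predict_udf(it: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for batch in it:
        # Build the DMatrix once per batch and reuse it for both outputs.
        m = xgb.DMatrix(np.stack(batch["features"].to_list()))
        out = pd.DataFrame()
        out["prediction"] = booster.predict(m).astype(np.float64)
        contribs = booster.predict(m, pred_contribs=True, strict_shape=True)
        out["pred_contrib"] = list(contribs.reshape(contribs.shape[0], -1).astype(np.float64))
        yield out

result = df.withColumn("pred", predict_udf(struct("features")))
result.select("pred.prediction", "pred.pred_contrib").show(3, truncate=False)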
@trivialfis This PR is ready for review. Please help to review it. Thx
@hcho3 @trivialfis please help to start the CI. Thx
@trivialfis seems the failure case is not caused by this PR.
=================================== FAILURES ===================================
__________________________ TestLinear.test_coordinate __________________________
@WeichenXu123 @trivialfis please help to review it. Thx
python-package/xgboost/spark/core.py
Pred = namedtuple(
    "Pred", ("prediction", "raw_prediction", "probability", "pred_contrib")
)
pred = Pred("prediction", "rawPrediction", "probability", "pred_contrib")
Do we need to keep a consistent naming scheme? predContrib vs. pred_contrib, based on the use of rawPrediction.
Good suggestion. Done
if pred_contrib_col_name:
    dataset = dataset.withColumn(
        pred_contrib_col_name,
        array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
Does this work with cuDF?
It seems not, since the prediction UDF does not yet have cuDF support.
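As a side note on the array_to_vector call above, a minimal CPU-only sketch of how an array<double> column is converted into an ML vector column (column names are illustrative):

from pyspark.ml.functions import array_to_vector
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([0.1, -0.2, 0.3],)], ["pred_contrib_arr"])
# array_to_vector turns an array<double> column into a DenseVector (VectorUDT) column.
df = df.withColumn("pred_contrib", array_to_vector(col("pred_contrib_arr")))
df.show(truncate=False)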
)
if pred_contrib_col_name:
    # We will force setting strict_shape to True when predicting contribs,
Excellent!
@pytest.fixture
def reg_data(spark: SparkSession) -> Generator[RegData, None, None]:
Do we need tests for multi-class as well?
Yeah, good suggestion.
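A possible shape for such a multi-class fixture, following the existing reg_data fixture and its spark fixture; the fixture name, dataset, and return type here are illustrative only:

from typing import Generator

import numpy as np
import pytest
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame, SparkSession


@pytest.fixture
def multi_cls_data(spark: SparkSession) -> Generator[DataFrame, None, None]:
    # Three rows, three features, three classes: enough to exercise the
    # per-class contribution output of a multi-class model.
    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 0.0, 1.0]])
    y = np.array([0, 1, 2])
    rows = [(Vectors.dense(X[i, :]), int(y[i])) for i in range(len(y))]
    yield spark.createDataFrame(rows, ["features", "label"])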
    (Vectors.dense(X[0, :]), int(y[0])),
    (Vectors.sparse(3, {1: float(X[1, 1]), 2: float(X[1, 2])}), int(y[1])),
]
cls_df_train = spark.createDataFrame(reg_df_train_data, ["features", "label"])
cls usually means classifier.
cls usually means classifier.
Based on my experience, "cls" is usually used as an abbreviation for "class", while "clf" is used as an abbreviation for "classifier".
No specific request here; it's just an arbitrary convention in this project. Both are used.
Thx, Done.
data[pred.prediction] = pd.Series(preds)

if pred_contrib_col_name:
    contribs = pred_contribs(model, X, base_margin)
Q: Possible optimization: can we compute pred_contribs, proba, and prediction in one pass?
AFAIK, the xgboost API can predict one type at a time. @trivialfis can you correct me?
That's correct. However, the Spark API can iterate through multiple types of prediction to meet the Spark convention if needed.
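In other words, each Booster.predict call produces one output type, but the DMatrix can be built once and reused, so "prediction + contribs" only pays the data-construction cost a single time. A small standalone sketch with toy data and illustrative names:

import numpy as np
import xgboost as xgb

X = np.random.rand(16, 3)
y = np.random.randint(0, 2, size=16)
booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), 4)

m = xgb.DMatrix(X)  # constructed once, reused for every prediction type
preds = booster.predict(m)
contribs = booster.predict(m, pred_contribs=True, strict_shape=True)
print(preds.shape, contribs.shape)  # (16,) and (16, 1, 4)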
pred_contrib_col: "Param[str]" = Param(
    Params._dummy(),
    "pred_contrib_col",
    "contribution prediction column name.",
We can explain a bit more about the contribution prediction here.
Done
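For example, the param description could spell out what the column contains; the wording below is only a suggestion, not the text that was merged:

from pyspark.ml.param import Param, Params, TypeConverters

pred_contrib_col: "Param[str]" = Param(
    Params._dummy(),
    "pred_contrib_col",
    "feature contributions to individual predictions (one value per feature "
    "plus a final bias term), stored as an ML vector column.",
    typeConverter=TypeConverters.toString,
)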
"pred_contrib_col", | ||
"contribution prediction column name.", | ||
typeConverter=TypeConverters.toString, | ||
) |
We need to add this param into the _pyspark_specific_params dict.
Thx, Done.
To fix #8449: this PR supports pred_contribs for PySpark.