Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pyspark] support pred_contribs #8633

Merged
merged 5 commits into from
Jan 11, 2023
Merged

Conversation

wbo4958
Copy link
Contributor

@wbo4958 wbo4958 commented Jan 4, 2023

To fix #8449. This PR supports pred_contribs for pyspark

@@ -331,3 +334,25 @@ def split_params() -> Tuple[Dict[str, Any], Dict[str, Union[int, float, bool]]]:
assert dvalid.num_col() == dtrain.num_col()

return dtrain, dvalid


def pred_contribs(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@trivialfis Do you think it's better to move pred_contribs function to XGBModel?

Copy link
Member

@trivialfis trivialfis Jan 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can merge it into the XGBModel/XGBClassifier.predict method for consistency?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow, good suggestion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@trivialfis can you file a followup PR to merge it into XGBModel/XGBClassifier.predict method

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure.


if pred_contrib_col_name:
contribs = pred_contribs(model, X, base_margin)
assert len(contribs.shape) == 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessarily true. See doc/prediction.rst.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

preds = model.predict(
X,
base_margin=base_margin,
validate_features=False,
**predict_params,
)
yield pd.Series(preds)
data["prediction"] = pd.Series(preds)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it handle multiple prediction types? For instance, normal prediction + contribs at the same time

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean handling "normal prediction + contribs" in the same "predict" of XGBModel or in a single pandas udf?

Right now, this PR can handle the latter, can predict the "normal prediction + contribs" in a single pandas udf. If "predict" of XGBModel can support predicting multiple prediction types, I can change it accordingly.

@wbo4958 wbo4958 marked this pull request as ready for review January 5, 2023 09:43
@wbo4958
Copy link
Contributor Author

wbo4958 commented Jan 5, 2023

@trivialfis This PR is ready for review. Please help to review it. Thx

@wbo4958
Copy link
Contributor Author

wbo4958 commented Jan 5, 2023

@hcho3 @trivialfis please help to start the CI. Thx

@wbo4958
Copy link
Contributor Author

wbo4958 commented Jan 6, 2023

@trivialfis seems the failure case is not caused by this PR.

=================================== FAILURES ===================================
--
  | __________________________ TestLinear.test_coordinate __________________________

@wbo4958
Copy link
Contributor Author

wbo4958 commented Jan 8, 2023

@WeichenXu123 @trivialfis please help to review it. Thx

Pred = namedtuple(
"Pred", ("prediction", "raw_prediction", "probability", "pred_contrib")
)
pred = Pred("prediction", "rawPrediction", "probability", "pred_contrib")
Copy link
Member

@trivialfis trivialfis Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to keep a consistent naming scheme? predContrib v.s. pred_constrib, based on the use of rawPrediction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion. Done

if pred_contrib_col_name:
dataset = dataset.withColumn(
pred_contrib_col_name,
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work with cuDF?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems not, since the prediction udf has not added cudf support.

)
if pred_contrib_col_name:
# We will force setting strict_shape to True when predicting contribs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent!



@pytest.fixture
def reg_data(spark: SparkSession) -> Generator[RegData, None, None]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need tests for multi-class as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good suggestion.

(Vectors.dense(X[0, :]), int(y[0])),
(Vectors.sparse(3, {1: float(X[1, 1]), 2: float(X[1, 2])}), int(y[1])),
]
cls_df_train = spark.createDataFrame(reg_df_train_data, ["features", "label"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cls usually mean classifier.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cls usually mean classifier.

Based on my experience, "cls" is usually used as an abbreviation for "class", while "clf" is used as an abbreviation for "classifier".

Copy link
Member

@trivialfis trivialfis Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no specific request here. it's just an arbitrary convention in this project. Both are used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, Done.

data[pred.prediction] = pd.Series(preds)

if pred_contrib_col_name:
contribs = pred_contribs(model, X, base_margin)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Possible optimization: can we compute pred_contribs, proba, and prediction in one pass ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, the xgboost API can predict 1 type at one time. @trivialfis can you correct me?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. However, the Spark API can iterate through multiple types of prediction for meeting the spark convention if needed.

pred_contrib_col: "Param[str]" = Param(
Params._dummy(),
"pred_contrib_col",
"contribution prediction column name.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can explain a bit more about the contribution prediction here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"pred_contrib_col",
"contribution prediction column name.",
typeConverter=TypeConverters.toString,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add this param into _pyspark_specific_params dict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, Done.

@trivialfis trivialfis merged commit 72ec0c5 into dmlc:master Jan 11, 2023
@wbo4958 wbo4958 deleted the pred_contribs branch January 11, 2023 10:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[pyspark] pred_contribs (shap) support
4 participants