Skip to content

Commit

Permalink
[PySpark] Expose Training and Validation Metrics (#11133)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayoub317 authored Jan 13, 2025
1 parent c3aa7fe commit 461d27c
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 17 deletions.
26 changes: 18 additions & 8 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
HasFeaturesCols,
HasQueryIdCol,
)
from .summary import XGBoostTrainingSummary
from .utils import (
CommunicatorContext,
_get_default_params_from_func,
Expand Down Expand Up @@ -704,8 +705,10 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
"""
raise NotImplementedError()

def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
return self._pyspark_model_cls()(xgb_model)
def _create_pyspark_model(
self, xgb_model: XGBModel, training_summary: XGBoostTrainingSummary
) -> "_SparkXGBModel":
return self._pyspark_model_cls()(xgb_model, training_summary)

def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
xgb_sklearn_params = self._gen_xgb_params_dict(
Expand Down Expand Up @@ -1148,7 +1151,7 @@ def _train_booster(
if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")]
else:
dval = None
dval = [(dtrain, "training")]
booster = worker_train(
params=booster_params,
dtrain=dtrain,
Expand All @@ -1159,6 +1162,7 @@ def _train_booster(
context.barrier()

if context.partitionId() == 0:
yield pd.DataFrame({"data": [json.dumps(dict(evals_result))]})
config = booster.save_config()
yield pd.DataFrame({"data": [config]})
booster_json = booster.save_raw("json").decode("utf-8")
Expand All @@ -1167,7 +1171,7 @@ def _train_booster(
booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
yield pd.DataFrame({"data": [booster_chunk]})

def _run_job() -> Tuple[str, str]:
def _run_job() -> Tuple[str, str, str]:
rdd = (
dataset.mapInPandas(
_train_booster, # type: ignore
Expand All @@ -1179,7 +1183,7 @@ def _run_job() -> Tuple[str, str]:
rdd_with_resource = self._try_stage_level_scheduling(rdd)
ret = rdd_with_resource.collect()
data = [v[0] for v in ret]
return data[0], "".join(data[1:])
return data[0], data[1], "".join(data[2:])

get_logger(_LOG_TAG).info(
"Running xgboost-%s on %s workers with"
Expand All @@ -1192,13 +1196,14 @@ def _run_job() -> Tuple[str, str]:
train_call_kwargs_params,
dmatrix_kwargs,
)
(config, booster) = _run_job()
(evals_result, config, booster) = _run_job()
get_logger(_LOG_TAG).info("Finished xgboost training!")

result_xgb_model = self._convert_to_sklearn_model(
bytearray(booster, "utf-8"), config
)
spark_model = self._create_pyspark_model(result_xgb_model)
training_summary = XGBoostTrainingSummary.from_metrics(json.loads(evals_result))
spark_model = self._create_pyspark_model(result_xgb_model, training_summary)
# According to pyspark ML convention, the model uid should be the same
# with estimator uid.
spark_model._resetUid(self.uid)
Expand All @@ -1219,9 +1224,14 @@ def read(cls) -> "SparkXGBReader":


class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
def __init__(self, xgb_sklearn_model: Optional[XGBModel] = None) -> None:
def __init__(
self,
xgb_sklearn_model: Optional[XGBModel] = None,
training_summary: Optional[XGBoostTrainingSummary] = None,
) -> None:
super().__init__()
self._xgb_sklearn_model = xgb_sklearn_model
self.training_summary = training_summary

@classmethod
def _xgb_cls(cls) -> Type[XGBModel]:
Expand Down
43 changes: 43 additions & 0 deletions python-package/xgboost/spark/summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Xgboost training summary integration submodule."""

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class XGBoostTrainingSummary:
"""
A class that holds the training and validation objective history
of an XGBoost model during its training process.
"""

train_objective_history: Dict[str, List[float]] = field(default_factory=dict)
validation_objective_history: Dict[str, List[float]] = field(default_factory=dict)

@staticmethod
def from_metrics(
metrics: Dict[str, Dict[str, List[float]]]
) -> "XGBoostTrainingSummary":
"""
Create an XGBoostTrainingSummary instance from a nested dictionary of metrics.
Parameters
----------
metrics : dict of str to dict of str to list of float
A dictionary containing training and validation metrics.
Example format:
{
"training": {"logloss": [0.1, 0.08]},
"validation": {"logloss": [0.12, 0.1]}
}
Returns
-------
A new instance of XGBoostTrainingSummary.
"""
train_objective_history = metrics.get("training", {})
validation_objective_history = metrics.get("validation", {})
return XGBoostTrainingSummary(
train_objective_history, validation_objective_history
)
Loading

0 comments on commit 461d27c

Please sign in to comment.