Skip to content

Commit

Permalink
[pyspark] factor out pyspark to support typing
Browse files Browse the repository at this point in the history
Currently, pyspark runs into a circular import issue when typing is
enabled for model.py, so this PR refactors the pyspark module a little
to avoid it.
  • Loading branch information
wbo4958 committed May 17, 2023
1 parent cb370c4 commit 57b7724
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 336 deletions.
271 changes: 189 additions & 82 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
import base64

# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
import json
import os
from collections import namedtuple
from typing import Iterator, List, Optional, Tuple

import numpy as np
import pandas as pd
from pyspark import cloudpickle
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
Expand All @@ -21,7 +25,14 @@
HasValidationIndicatorCol,
HasWeightCol,
)
from pyspark.ml.util import MLReadable, MLWritable
from pyspark.ml.util import (
DefaultParamsReader,
DefaultParamsWriter,
MLReadable,
MLReader,
MLWritable,
MLWriter,
)
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
from pyspark.sql.types import (
Expand All @@ -36,7 +47,7 @@
from scipy.special import expit, softmax # pylint: disable=no-name-in-module

import xgboost
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
from xgboost.sklearn import DEFAULT_N_ESTIMATORS
Expand All @@ -49,12 +60,6 @@
pred_contribs,
stack_series,
)
from .model import (
SparkXGBModelReader,
SparkXGBModelWriter,
SparkXGBReader,
SparkXGBWriter,
)
from .params import (
HasArbitraryParamsDict,
HasBaseMarginCol,
Expand All @@ -71,8 +76,11 @@
_get_rabit_args,
_get_spark_session,
_is_local,
deserialize_booster,
deserialize_xgb_model,
get_class_name,
get_logger,
serialize_booster,
)

# Put pyspark specific params here, they won't be passed to XGBoost.
Expand Down Expand Up @@ -156,6 +164,8 @@
)
pred = Pred("prediction", "rawPrediction", "probability", "predContrib")

_INIT_BOOSTER_SAVE_PATH = "init_booster.json"


class _SparkXGBParams(
HasFeaturesCol,
Expand Down Expand Up @@ -1122,31 +1132,7 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
return dataset


class SparkXGBRegressorModel(_SparkXGBModel):
    """
    Fitted model produced by :func:`xgboost.spark.SparkXGBRegressor.fit`.

    .. Note:: This API is experimental.
    """

    @classmethod
    def _xgb_cls(cls):
        # The XGBoost sklearn class associated with this Spark model class.
        return XGBRegressor


class SparkXGBRankerModel(_SparkXGBModel):
    """
    Fitted model produced by :func:`xgboost.spark.SparkXGBRanker.fit`.

    .. Note:: This API is experimental.
    """

    @classmethod
    def _xgb_cls(cls):
        # The XGBoost sklearn class associated with this Spark model class.
        return XGBRanker


class SparkXGBClassifierModel(
class _ClassificationModel( # pylint: disable=abstract-method
_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol, HasContribPredictionCol
):
"""
Expand All @@ -1155,10 +1141,6 @@ class SparkXGBClassifierModel(
.. Note:: This API is experimental.
"""

@classmethod
def _xgb_cls(cls):
return XGBClassifier

def _transform(self, dataset):
# pylint: disable=too-many-statements, too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
Expand Down Expand Up @@ -1286,53 +1268,178 @@ def predict_udf(
return dataset.drop(pred_struct_col)


def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class):
params_dict = pyspark_estimator_class._get_xgb_params_default()

def param_value_converter(v):
if isinstance(v, np.generic):
# convert numpy scalar values to corresponding python scalar values
return np.array(v).item()
if isinstance(v, dict):
return {k: param_value_converter(nv) for k, nv in v.items()}
if isinstance(v, list):
return [param_value_converter(nv) for nv in v]
return v

def set_param_attrs(attr_name, param_obj_):
param_obj_.typeConverter = param_value_converter
setattr(pyspark_estimator_class, attr_name, param_obj_)
setattr(pyspark_model_class, attr_name, param_obj_)

for name in params_dict.keys():
doc = (
f"Refer to XGBoost doc of "
f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}"
class _SparkXGBSharedReadWrite:
@staticmethod
def saveMetadata(instance, path, sc, logger, extraMetadata=None):
"""
Save the metadata of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
"""
instance._validate_params()
skipParams = ["callbacks", "xgb_model"]
jsonParams = {}
for p, v in instance._paramMap.items(): # pylint: disable=protected-access
if p.name not in skipParams:
jsonParams[p.name] = v

extraMetadata = extraMetadata or {}
callbacks = instance.getOrDefault(instance.callbacks)
if callbacks is not None:
logger.warning(
"The callbacks parameter is saved using cloudpickle and it "
"is not a fully self-contained format. It may fail to load "
"with different versions of dependencies."
)
serialized_callbacks = base64.encodebytes(
cloudpickle.dumps(callbacks)
).decode("ascii")
extraMetadata["serialized_callbacks"] = serialized_callbacks
init_booster = instance.getOrDefault(instance.xgb_model)
if init_booster is not None:
extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
DefaultParamsWriter.saveMetadata(
instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
)
if init_booster is not None:
ser_init_booster = serialize_booster(init_booster)
save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
_get_spark_session().createDataFrame(
[(ser_init_booster,)], ["init_booster"]
).write.parquet(save_path)

@staticmethod
def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
"""
Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)

fit_params_dict = pyspark_estimator_class._get_fit_params_default()
for name in fit_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
f".fit() for this param {name}"
:return: a tuple of (metadata, instance)
"""
metadata = DefaultParamsReader.loadMetadata(
path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
)
if name == "callbacks":
doc += (
"The callbacks can be arbitrary functions. It is saved using cloudpickle "
"which is not a fully self-contained format. It may fail to load with "
"different versions of dependencies."
pyspark_xgb = pyspark_xgb_cls()
DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)

if "serialized_callbacks" in metadata:
serialized_callbacks = metadata["serialized_callbacks"]
try:
callbacks = cloudpickle.loads(
base64.decodebytes(serialized_callbacks.encode("ascii"))
)
pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
except Exception as e: # pylint: disable=W0703
logger.warning(
f"Fails to load the callbacks param due to {e}. Please set the "
"callbacks param manually for the loaded estimator."
)

if "init_booster" in metadata:
load_path = os.path.join(path, metadata["init_booster"])
ser_init_booster = (
_get_spark_session().read.parquet(load_path).collect()[0].init_booster
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)

predict_params_dict = pyspark_estimator_class._get_predict_params_default()
for name in predict_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
f".predict() for this param {name}"
init_booster = deserialize_booster(ser_init_booster)
pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)

pyspark_xgb._resetUid(metadata["uid"]) # pylint: disable=protected-access
return metadata, pyspark_xgb


class SparkXGBWriter(MLWriter):
    """
    Writer for a Spark XGBoost estimator.

    Persists the estimator by delegating to
    ``_SparkXGBSharedReadWrite.saveMetadata``.
    """

    def __init__(self, instance):
        """
        :param instance: the estimator object that will be written out.
        """
        super().__init__()
        # Logger is named after the concrete writer class and set to WARN level.
        self.logger = get_logger(type(self).__name__, level="WARN")
        self.instance = instance

    def saveImpl(self, path):
        """
        Write the wrapped estimator's metadata under ``path``.

        :param path: destination directory handed over by the ML writer API.
        """
        _SparkXGBSharedReadWrite.saveMetadata(
            instance=self.instance,
            path=path,
            sc=self.sc,
            logger=self.logger,
        )


class SparkXGBReader(MLReader):
    """
    Reader for a Spark XGBoost estimator.

    Restores the estimator by delegating to
    ``_SparkXGBSharedReadWrite.loadMetadataAndInstance``.
    """

    def __init__(self, cls):
        """
        :param cls: the estimator class that should be instantiated on load.
        """
        super().__init__()
        # Logger is named after the concrete reader class and set to WARN level.
        self.logger = get_logger(type(self).__name__, level="WARN")
        self.cls = cls

    def load(self, path):
        """
        Load an estimator previously saved under ``path``.

        :param path: directory the estimator was saved to.
        :return: the restored estimator instance; the metadata that is loaded
            alongside it is discarded.
        """
        _metadata, estimator = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
            self.cls, path, self.sc, self.logger
        )
        return estimator


class SparkXGBModelWriter(MLWriter):
    """
    Writer for a fitted Spark XGBoost model.
    """

    def __init__(self, instance):
        """
        :param instance: the fitted model object that will be written out.
        """
        super().__init__()
        # Logger is named after the concrete writer class and set to WARN level.
        self.logger = get_logger(type(self).__name__, level="WARN")
        self.instance = instance

    def saveImpl(self, path):
        """
        Save metadata and the booster for a :py:class:`_SparkXGBModel`.

        - metadata is written through ``_SparkXGBSharedReadWrite.saveMetadata``
        - the booster, serialized as a JSON string, is written as a
          single-partition text file under ``path/model``

        :param path: destination directory handed over by the ML writer API.
        """
        # Fetch the sklearn model before anything is written, so a missing
        # attribute fails before metadata has been saved.
        sklearn_model = self.instance._xgb_sklearn_model
        _SparkXGBSharedReadWrite.saveMetadata(
            instance=self.instance,
            path=path,
            sc=self.sc,
            logger=self.logger,
        )

        target_dir = os.path.join(path, "model")
        booster_json = sklearn_model.get_booster().save_raw("json").decode("utf-8")
        # One element in one partition: the whole booster lands in a single record.
        booster_rdd = _get_spark_session().sparkContext.parallelize([booster_json], 1)
        booster_rdd.saveAsTextFile(target_dir)


class SparkXGBModelReader(MLReader):
"""
Spark Xgboost model reader.
"""

def __init__(self, cls):
super().__init__()
self.cls = cls
self.logger = get_logger(self.__class__.__name__, level="WARN")

def load(self, path):
"""
Load metadata and model for a :py:class:`_SparkXGBModel`
:return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
"""
_, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
self.cls, path, self.sc, self.logger
)

xgb_sklearn_params = py_model._gen_xgb_params_dict(
gen_xgb_sklearn_estimator_param=True
)
model_load_path = os.path.join(path, "model")

ser_xgb_model = (
_get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)

def create_xgb_model():
return self.cls._xgb_cls()(**xgb_sklearn_params)

xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
py_model._xgb_sklearn_model = xgb_model
return py_model
Loading

0 comments on commit 57b7724

Please sign in to comment.