From f6effa17344aeb79a7739f562abe9acc8c9bf488 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 17 Dec 2022 00:15:15 +0800 Subject: [PATCH] Support `Series` and Python primitives in `inplace_predict` and QDM (#8547) --- python-package/xgboost/core.py | 57 +++++++++++++++----------- python-package/xgboost/data.py | 13 +++--- python-package/xgboost/testing/data.py | 6 +++ tests/python/test_predict.py | 17 +++++++- tests/python/test_with_pandas.py | 33 +++++++++------ 5 files changed, 82 insertions(+), 44 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 3c467e9fe2c4..61c9038982b1 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2220,17 +2220,15 @@ def inplace_predict( preds = ctypes.POINTER(ctypes.c_float)() # once caching is supported, we can pass id(data) as cache id. - args = { - "type": 0, - "training": False, - "iteration_begin": iteration_range[0], - "iteration_end": iteration_range[1], - "missing": missing, - "strict_shape": strict_shape, - "cache_id": 0, - } - if predict_type == "margin": - args["type"] = 1 + args = make_jcargs( + type=1 if predict_type == "margin" else 0, + training=False, + iteration_begin=iteration_range[0], + iteration_end=iteration_range[1], + missing=missing, + strict_shape=strict_shape, + cache_id=0, + ) shape = ctypes.POINTER(c_bst_ulong)() dims = c_bst_ulong() @@ -2243,30 +2241,39 @@ def inplace_predict( proxy = None p_handle = ctypes.c_void_p() assert proxy is None or isinstance(proxy, _ProxyDMatrix) - if validate_features: - if not hasattr(data, "shape"): - raise TypeError( - "`shape` attribute is required when `validate_features` is True." - ) - if len(data.shape) != 1 and self.num_features() != data.shape[1]: - raise ValueError( - f"Feature shape mismatch, expected: {self.num_features()}, " - f"got {data.shape[1]}" - ) from .data import ( _array_interface, _is_cudf_df, _is_cupy_array, + _is_list, _is_pandas_df, + _is_pandas_series, + _is_tuple, _transform_pandas_df, ) enable_categorical = True + if _is_pandas_series(data): + import pandas as pd + data = pd.DataFrame(data) if _is_pandas_df(data): data, fns, _ = _transform_pandas_df(data, enable_categorical) if validate_features: self._validate_features(fns) + if _is_list(data) or _is_tuple(data): + data = np.array(data) + + if validate_features: + if not hasattr(data, "shape"): + raise TypeError( + "`shape` attribute is required when `validate_features` is True." + ) + if len(data.shape) != 1 and self.num_features() != data.shape[1]: + raise ValueError( + f"Feature shape mismatch, expected: {self.num_features()}, " + f"got {data.shape[1]}" + ) if isinstance(data, np.ndarray): from .data import _ensure_np_dtype @@ -2276,7 +2283,7 @@ def inplace_predict( _LIB.XGBoosterPredictFromDense( self.handle, _array_interface(data), - from_pystr_to_cstr(json.dumps(args)), + args, p_handle, ctypes.byref(shape), ctypes.byref(dims), @@ -2293,7 +2300,7 @@ def inplace_predict( _array_interface(csr.indices), _array_interface(csr.data), c_bst_ulong(csr.shape[1]), - from_pystr_to_cstr(json.dumps(args)), + args, p_handle, ctypes.byref(shape), ctypes.byref(dims), @@ -2310,7 +2317,7 @@ def inplace_predict( _LIB.XGBoosterPredictFromCudaArray( self.handle, interface_str, - from_pystr_to_cstr(json.dumps(args)), + args, p_handle, ctypes.byref(shape), ctypes.byref(dims), @@ -2331,7 +2338,7 @@ def inplace_predict( _LIB.XGBoosterPredictFromCudaColumnar( self.handle, interfaces_str, - from_pystr_to_cstr(json.dumps(args)), + args, p_handle, ctypes.byref(shape), ctypes.byref(dims), diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index d06fa27e6502..6daed547095d 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -958,12 +958,12 @@ def dispatch_data_backend( return _from_list(data, missing, threads, feature_names, feature_types) if _is_tuple(data): return _from_tuple(data, missing, threads, feature_names, feature_types) - if _is_pandas_df(data): - return _from_pandas_df(data, enable_categorical, missing, threads, - feature_names, feature_types) if _is_pandas_series(data): - return _from_pandas_series( - data, missing, threads, enable_categorical, feature_names, feature_types + import pandas as pd + data = pd.DataFrame(data) + if _is_pandas_df(data): + return _from_pandas_df( + data, enable_categorical, missing, threads, feature_names, feature_types ) if _is_cudf_df(data) or _is_cudf_ser(data): return _from_cudf_df( @@ -1205,6 +1205,9 @@ def _proxy_transform( return data, None, feature_names, feature_types if _is_scipy_csr(data): return data, None, feature_names, feature_types + if _is_pandas_series(data): + import pandas as pd + data = pd.DataFrame(data) if _is_pandas_df(data): arr, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index 7d63097dc055..d6aad54e6d24 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -40,6 +40,7 @@ def np_dtypes( for dtype in dtypes: X = np.array(orig, dtype=dtype) yield orig, X + yield orig.tolist(), X.tolist() for dtype in dtypes: X = np.array(orig, dtype=dtype) @@ -101,6 +102,11 @@ def pd_dtypes() -> Generator: {"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, dtype=dtype ) yield orig, df + ser_orig = orig["f0"] + ser = df["f0"] + assert isinstance(ser, pd.Series) + assert isinstance(ser_orig, pd.Series) + yield ser_orig, ser # Categorical orig = orig.astype("category") diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index 787188b11dac..63c0ff9d753b 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -5,7 +5,7 @@ import pandas as pd import pytest from scipy import sparse -from xgboost.testing.data import np_dtypes +from xgboost.testing.data import np_dtypes, pd_dtypes from xgboost.testing.shared import validate_leaf_output import xgboost as xgb @@ -231,6 +231,7 @@ def test_base_margin(self): from_dmatrix = booster.predict(dtrain) np.testing.assert_allclose(from_dmatrix, from_inplace) + @pytest.mark.skipif(**tm.no_pandas()) def test_dtypes(self) -> None: for orig, x in np_dtypes(self.rows, self.cols): predt_orig = self.booster.inplace_predict(orig) @@ -246,3 +247,17 @@ def test_dtypes(self) -> None: X: np.ndarray = np.array(orig, dtype=dtype) with pytest.raises(ValueError): self.booster.inplace_predict(X) + + @pytest.mark.skipif(**tm.no_pandas()) + def test_pd_dtypes(self) -> None: + from pandas.api.types import is_bool_dtype + for orig, x in pd_dtypes(): + dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes] + if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]): + continue + y = np.arange(x.shape[0]) + Xy = xgb.DMatrix(orig, y, enable_categorical=True) + booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=1) + predt_orig = booster.inplace_predict(orig) + predt = booster.inplace_predict(x) + np.testing.assert_allclose(predt, predt_orig) diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 863569691274..9bb81c6581d5 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -298,22 +298,29 @@ def test_cv_as_pandas(self): assert 'auc' not in cv.columns[0] assert 'error' in cv.columns[0] - def test_nullable_type(self) -> None: + @pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix]) + def test_nullable_type(self, DMatrixT) -> None: from pandas.api.types import is_categorical - for DMatrixT in (xgb.DMatrix, xgb.QuantileDMatrix): - for orig, df in pd_dtypes(): + for orig, df in pd_dtypes(): + if hasattr(df.dtypes, "__iter__"): enable_categorical = any(is_categorical for dtype in df.dtypes) - - m_orig = DMatrixT(orig, enable_categorical=enable_categorical) - # extension types - m_etype = DMatrixT(df, enable_categorical=enable_categorical) - # different from pd.BooleanDtype(), None is converted to False with bool - if any(dtype == "bool" for dtype in orig.dtypes): - assert not tm.predictor_equal(m_orig, m_etype) - else: - assert tm.predictor_equal(m_orig, m_etype) - + else: + # series + enable_categorical = is_categorical(df.dtype) + + m_orig = DMatrixT(orig, enable_categorical=enable_categorical) + # extension types + m_etype = DMatrixT(df, enable_categorical=enable_categorical) + # different from pd.BooleanDtype(), None is converted to False with bool + if hasattr(orig.dtypes, "__iter__") and any( + dtype == "bool" for dtype in orig.dtypes + ): + assert not tm.predictor_equal(m_orig, m_etype) + else: + assert tm.predictor_equal(m_orig, m_etype) + + if isinstance(df, pd.DataFrame): f0 = df["f0"] with pytest.raises(ValueError, match="Label contains NaN"): xgb.DMatrix(df, f0, enable_categorical=enable_categorical)