diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h
index 1ab4b6a031c9..30fe08882eab 100644
--- a/include/xgboost/generic_parameters.h
+++ b/include/xgboost/generic_parameters.h
@@ -35,6 +35,7 @@ struct GenericParameter : public dmlc::Parameter {
   DMLC_DECLARE_PARAMETER(GenericParameter) {
     DMLC_DECLARE_FIELD(seed).set_default(0).describe(
         "Random number seed during training.");
+    DMLC_DECLARE_ALIAS(seed, random_state);
     DMLC_DECLARE_FIELD(seed_per_iteration)
         .set_default(false)
         .describe(
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 829724258fe7..9e34840d29b4 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -62,17 +62,18 @@ class XGBModel(XGBModelBase):
         Number of trees to fit.
     verbosity : int
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    silent : boolean
-        Whether to print messages while running boosting. Deprecated. Use verbosity instead.
     objective : string or callable
         Specify the learning task and the corresponding learning objective or
         a custom objective function to be used (see note below).
     booster: string
         Specify which booster to use: gbtree, gblinear or dart.
-    nthread : int
-        Number of parallel threads used to run xgboost.  (Deprecated, please use ``n_jobs``)
+    tree_method: string
+        Specify which tree method to use.  Defaults to 'auto'.  When left at
+        the default, XGBoost will choose the most conservative option
+        available.  It's recommended to study this option in the parameters
+        document.
     n_jobs : int
-        Number of parallel threads used to run xgboost.  (replaces ``nthread``)
+        Number of parallel threads used to run xgboost.
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf node of the tree.
     min_child_weight : int
@@ -95,13 +96,17 @@ class XGBModel(XGBModelBase):
         Balancing of positive and negative weights.
     base_score:
         The initial prediction score of all instances, global bias.
-    seed : int
-        Random number seed.  (Deprecated, please use random_state)
     random_state : int
-        Random number seed.  (replaces seed)
+        Random number seed.
+
+        .. note:: Using gblinear booster with shotgun updater is
+           nondeterministic as it uses the Hogwild algorithm.
+
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
+    num_parallel_tree: int
+        Used for boosting random forests.
     importance_type: string, default "gain"
         The feature importance type for the feature_importances\\_ property:
         either "gain", "weight", "cover", "total_gain" or "total_cover".
@@ -131,25 +136,27 @@ class XGBModel(XGBModelBase):
             The value of the gradient for each sample point.
         hess: array_like of shape [n_samples]
             The value of the second derivative for each sample point
+
     """

     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, silent=None, objective="reg:squarederror",
-                 booster='gbtree', n_jobs=1, nthread=None, gamma=0,
+                 verbosity=1, objective="reg:squarederror",
+                 booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
                  min_child_weight=1, max_delta_step=0, subsample=1,
                  colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, seed=None, missing=None,
+                 random_state=0, missing=None, num_parallel_tree=1,
                  importance_type="gain", **kwargs):
         if not SKLEARN_INSTALLED:
-            raise XGBoostError('sklearn needs to be installed in order to use this module')
+            raise XGBoostError(
+                'sklearn needs to be installed in order to use this module')
         self.max_depth = max_depth
         self.learning_rate = learning_rate
         self.n_estimators = n_estimators
         self.verbosity = verbosity
-        self.silent = silent
         self.objective = objective
         self.booster = booster
+        self.tree_method = tree_method
         self.gamma = gamma
         self.min_child_weight = min_child_weight
         self.max_delta_step = max_delta_step
@@ -162,11 +169,10 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
         self.scale_pos_weight = scale_pos_weight
         self.base_score = base_score
         self.missing = missing if missing is not None else np.nan
+        self.num_parallel_tree = num_parallel_tree
         self.kwargs = kwargs
         self._Booster = None
-        self.seed = seed
         self.random_state = random_state
-        self.nthread = nthread
         self.n_jobs = n_jobs
         self.importance_type = importance_type

@@ -227,33 +233,6 @@ def get_params(self, deep=False):
     def get_xgb_params(self):
         """Get xgboost type parameters."""
         xgb_params = self.get_params()
-        random_state = xgb_params.pop('random_state')
-        if 'seed' in xgb_params and xgb_params['seed'] is not None:
-            warnings.warn('The seed parameter is deprecated as of version .6.'
-                          'Please use random_state instead.'
-                          'seed is deprecated.', DeprecationWarning)
-        else:
-            xgb_params['seed'] = random_state
-        n_jobs = xgb_params.pop('n_jobs')
-        if 'nthread' in xgb_params and xgb_params['nthread'] is not None:
-            warnings.warn('The nthread parameter is deprecated as of version .6.'
-                          'Please use n_jobs instead.'
-                          'nthread is deprecated.', DeprecationWarning)
-        else:
-            xgb_params['nthread'] = n_jobs
-
-        if 'silent' in xgb_params and xgb_params['silent'] is not None:
-            warnings.warn('The silent parameter is deprecated.'
-                          'Please use verbosity instead.'
-                          'silent is depreated', DeprecationWarning)
-            # TODO(canonizer): set verbosity explicitly if silent is removed from xgboost,
-            # but remains in this API
-        else:
-            # silent=None shouldn't be passed to xgboost
-            xgb_params.pop('silent', None)
-
-        if xgb_params['nthread'] <= 0:
-            xgb_params.pop('nthread', None)
         return xgb_params

     def get_num_boosting_rounds(self):
@@ -301,7 +280,7 @@ def load_model(self, fname):
             Input file name or memory buffer(see also save_raw)
         """
         if self._Booster is None:
-            self._Booster = Booster({'nthread': self.n_jobs})
+            self._Booster = Booster({'n_jobs': self.n_jobs})
         self._Booster.load_model(fname)

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
@@ -364,13 +343,17 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
         """
         if sample_weight is not None:
             trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
-                                   missing=self.missing, nthread=self.n_jobs)
+                                   missing=self.missing,
+                                   nthread=self.n_jobs)
         else:
-            trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
+            trainDmatrix = DMatrix(X, label=y, missing=self.missing,
+                                   nthread=self.n_jobs)

         evals_result = {}

         if eval_set is not None:
+            if not isinstance(eval_set[0], (list, tuple)):
+                raise TypeError('Unexpected input type for `eval_set`')
             if sample_weight_eval_set is None:
                 sample_weight_eval_set = [None] * len(eval_set)
             evals = list(
@@ -610,22 +593,27 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     __doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
         + '\n'.join(XGBModel.__doc__.split('\n')[2:])

-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, silent=None,
+    def __init__(self, max_depth=3, learning_rate=0.1,
+                 n_estimators=100, verbosity=1,
                  objective="binary:logistic", booster='gbtree',
-                 n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
-                 subsample=1, colsample_bytree=1, colsample_bylevel=1,
-                 colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
-                 base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
+                 tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
+                 min_child_weight=1, max_delta_step=0, subsample=1,
+                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
+                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
+                 random_state=0, missing=None, **kwargs):
         super(XGBClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
-            verbosity=verbosity, silent=silent, objective=objective, booster=booster,
-            n_jobs=n_jobs, nthread=nthread, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, seed=seed, missing=missing,
+            max_depth=max_depth, learning_rate=learning_rate,
+            n_estimators=n_estimators, verbosity=verbosity,
+            objective=objective, booster=booster, tree_method=tree_method,
+            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
+            min_child_weight=min_child_weight,
+            max_delta_step=max_delta_step, subsample=subsample,
+            colsample_bytree=colsample_bytree,
+            colsample_bylevel=colsample_bylevel,
+            colsample_bynode=colsample_bynode,
+            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
+            scale_pos_weight=scale_pos_weight,
+            base_score=base_score, random_state=random_state, missing=missing,
             **kwargs)

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
@@ -676,6 +664,11 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
         else:
             evals = ()

+        if len(X.shape) != 2:
+            # Simply raise an error here since there might be many
+            # different ways of reshaping
+            raise ValueError(
+                'Please reshape the input data X into 2-dimensional matrix.')
         self._features_count = X.shape[1]

         if sample_weight is not None:
@@ -846,26 +839,27 @@ def evals_result(self):

 class XGBRFClassifier(XGBClassifier):
     # pylint: disable=missing-docstring
-    __doc__ = "Experimental implementation of the scikit-learn API "\
-        + "for XGBoost random forest classification.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
+    __doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\
+        + '\n'.join(XGBModel.__doc__.split('\n')[2:])

     def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, silent=None,
-                 objective="binary:logistic", n_jobs=1, nthread=None, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
-                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
-                 scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
+                 verbosity=1, objective="binary:logistic", n_jobs=1,
+                 gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0,
+                 subsample=0.8, colsample_bytree=1, colsample_bylevel=1,
+                 colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
+                 scale_pos_weight=1, base_score=0.5, random_state=0,
                  missing=None, **kwargs):
         super(XGBRFClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
-            verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
-            n_jobs=n_jobs, nthread=nthread, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
+            max_depth=max_depth, learning_rate=learning_rate,
+            n_estimators=n_estimators, verbosity=verbosity,
+            objective=objective, booster='gbtree', n_jobs=n_jobs,
+            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
+            max_delta_step=max_delta_step,
             subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, seed=seed, missing=missing,
+            colsample_bylevel=colsample_bylevel,
+            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
+            base_score=base_score, random_state=random_state, missing=missing,
             **kwargs)

     def get_xgb_params(self):
@@ -885,26 +879,28 @@ class XGBRegressor(XGBModel, XGBRegressorBase):

 class XGBRFRegressor(XGBRegressor):
     # pylint: disable=missing-docstring
-    __doc__ = "Experimental implementation of the scikit-learn API "\
-        + "for XGBoost random forest regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
+    __doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
+        + '\n'.join(XGBModel.__doc__.split('\n')[2:])

     def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, silent=None,
-                 objective="reg:squarederror", n_jobs=1, nthread=None, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
-                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
-                 scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
-                 missing=None, **kwargs):
+                 verbosity=1, objective="reg:squarederror", n_jobs=1,
+                 gpu_id=-1, gamma=0, min_child_weight=1,
+                 max_delta_step=0, subsample=0.8, colsample_bytree=1,
+                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
+                 reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
+                 random_state=0, missing=None, **kwargs):
         super(XGBRFRegressor, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
-            verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
-            n_jobs=n_jobs, nthread=nthread, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, seed=seed, missing=missing,
+            max_depth=max_depth, learning_rate=learning_rate,
+            n_estimators=n_estimators, verbosity=verbosity,
+            objective=objective, booster='gbtree', n_jobs=n_jobs,
+            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
+            max_delta_step=max_delta_step, subsample=subsample,
+            colsample_bytree=colsample_bytree,
+            colsample_bylevel=colsample_bylevel,
+            colsample_bynode=colsample_bynode,
+            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
+            scale_pos_weight=scale_pos_weight,
+            base_score=base_score, random_state=random_state, missing=missing,
             **kwargs)

     def get_xgb_params(self):
@@ -930,17 +926,13 @@ class XGBRanker(XGBModel):
         Number of boosted trees to fit.
     verbosity : int
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    silent : boolean
-        Whether to print messages while running boosting. Deprecated. Use verbosity instead.
     objective : string
         Specify the learning task and the corresponding learning objective.
         The objective name must start with "rank:".
     booster: string
         Specify which booster to use: gbtree, gblinear or dart.
-    nthread : int
-        Number of parallel threads used to run xgboost.  (Deprecated, please use ``n_jobs``)
     n_jobs : int
-        Number of parallel threads used to run xgboost.  (replaces ``nthread``)
+        Number of parallel threads used to run xgboost.
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf node of the tree.
     min_child_weight : int
@@ -963,10 +955,12 @@ class XGBRanker(XGBModel):
         Balancing of positive and negative weights.
     base_score:
         The initial prediction score of all instances, global bias.
-    seed : int
-        Random number seed.  (Deprecated, please use random_state)
     random_state : int
-        Random number seed.  (replaces seed)
+        Random number seed.
+
+        .. note:: Using gblinear booster with shotgun updater is
+           nondeterministic as it uses the Hogwild algorithm.
+
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
@@ -1015,33 +1009,39 @@ class XGBRanker(XGBModel):
         +-------+-----------+---------------+

         then your group array should be ``[3, 4]``.
- """ + + """ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, - verbosity=1, silent=None, objective="rank:pairwise", booster='gbtree', - n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0, - subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1, - reg_alpha=0, reg_lambda=1, scale_pos_weight=1, - base_score=0.5, random_state=0, seed=None, missing=None, **kwargs): + verbosity=1, objective="rank:pairwise", booster='gbtree', + tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0, + min_child_weight=1, max_delta_step=0, subsample=1, + colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1, + reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5, + random_state=0, missing=None, **kwargs): super(XGBRanker, self).__init__( - max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators, - verbosity=verbosity, silent=silent, objective=objective, booster=booster, - n_jobs=n_jobs, nthread=nthread, gamma=gamma, + max_depth=max_depth, learning_rate=learning_rate, + n_estimators=n_estimators, verbosity=verbosity, + objective=objective, booster=booster, tree_method=tree_method, + n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight, max_delta_step=max_delta_step, subsample=subsample, colsample_bytree=colsample_bytree, - colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode, - reg_alpha=reg_alpha, reg_lambda=reg_lambda, - scale_pos_weight=scale_pos_weight, base_score=base_score, - random_state=random_state, seed=seed, missing=missing, **kwargs) + colsample_bylevel=colsample_bylevel, + colsample_bynode=colsample_bynode, reg_alpha=reg_alpha, + reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight, + base_score=base_score, random_state=random_state, missing=missing, + **kwargs) if callable(self.objective): - raise ValueError("custom objective function not supported by XGBRanker") + raise ValueError( + "custom objective function not supported by XGBRanker") if "rank:" not in self.objective: raise ValueError("please use XGBRanker for ranking task") - def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None, - eval_group=None, eval_metric=None, early_stopping_rounds=None, - verbose=False, xgb_model=None, callbacks=None): + def fit(self, X, y, group, sample_weight=None, eval_set=None, + sample_weight_eval_set=None, eval_group=None, eval_metric=None, + early_stopping_rounds=None, verbose=False, xgb_model=None, + callbacks=None): # pylint: disable = attribute-defined-outside-init,arguments-differ """ Fit gradient boosting ranker @@ -1132,11 +1132,13 @@ def _dmat_init(group, **params): return ret if sample_weight is not None: - train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight, - missing=self.missing, nthread=self.n_jobs) + train_dmatrix = _dmat_init( + group, data=X, label=y, weight=sample_weight, + missing=self.missing, nthread=self.n_jobs) else: - train_dmatrix = _dmat_init(group, data=X, label=y, - missing=self.missing, nthread=self.n_jobs) + train_dmatrix = _dmat_init( + group, data=X, label=y, + missing=self.missing, nthread=self.n_jobs) evals_result = {} diff --git a/src/learner.cc b/src/learner.cc index c43ae744bf83..b427405e0d6d 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -601,7 +601,7 @@ class LearnerImpl : public Learner { gbm_->Configure(args); if (this->gbm_->UseGPU()) { - if (cfg_.find("gpu_id") == cfg_.cend()) { + if (generic_param_.gpu_id == -1) { generic_param_.gpu_id = 0; } } diff --git 
diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py
new file mode 100644
index 000000000000..fd465b4194f9
--- /dev/null
+++ b/tests/python-gpu/test_gpu_with_sklearn.py
@@ -0,0 +1,31 @@
+import xgboost as xgb
+import pytest
+import sys
+import numpy as np
+
+sys.path.append("tests/python")
+import testing as tm
+
+pytestmark = pytest.mark.skipif(**tm.no_sklearn())
+
+rng = np.random.RandomState(1994)
+
+
+def test_gpu_binary_classification():
+    from sklearn.datasets import load_digits
+    from sklearn.model_selection import KFold
+
+    digits = load_digits(2)
+    y = digits['target']
+    X = digits['data']
+    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
+    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
+        for train_index, test_index in kf.split(X, y):
+            xgb_model = cls(
+                random_state=42, tree_method='gpu_hist',
+                n_estimators=4, gpu_id='0').fit(X[train_index], y[train_index])
+            preds = xgb_model.predict(X[test_index])
+            labels = y[test_index]
+            err = sum(1 for i in range(len(preds))
+                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+            assert err < 0.1
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 013d4195199a..43c1f2767b01 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -175,6 +175,21 @@ def test_feature_importances_gain():
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)


+def test_num_parallel_tree():
+    from sklearn.datasets import load_boston
+    reg = xgb.XGBRegressor(n_estimators=4, num_parallel_tree=4,
+                           tree_method='hist')
+    boston = load_boston()
+    bst = reg.fit(X=boston['data'], y=boston['target'])
+    dump = bst.get_booster().get_dump(dump_format='json')
+    assert len(dump) == 16
+
+    reg = xgb.XGBRFRegressor(n_estimators=4)
+    bst = reg.fit(X=boston['data'], y=boston['target'])
+    dump = bst.get_booster().get_dump(dump_format='json')
+    assert len(dump) == 4
+
+
 def test_boston_housing_regression():
     from sklearn.metrics import mean_squared_error
     from sklearn.datasets import load_boston
@@ -430,18 +445,18 @@ def test_split_value_histograms():

 def test_sklearn_random_state():
     clf = xgb.XGBClassifier(random_state=402)
-    assert clf.get_xgb_params()['seed'] == 402
+    assert clf.get_xgb_params()['random_state'] == 402

-    clf = xgb.XGBClassifier(seed=401)
-    assert clf.get_xgb_params()['seed'] == 401
+    clf = xgb.XGBClassifier(random_state=401)
+    assert clf.get_xgb_params()['random_state'] == 401


 def test_sklearn_n_jobs():
     clf = xgb.XGBClassifier(n_jobs=1)
-    assert clf.get_xgb_params()['nthread'] == 1
+    assert clf.get_xgb_params()['n_jobs'] == 1

-    clf = xgb.XGBClassifier(nthread=2)
-    assert clf.get_xgb_params()['nthread'] == 2
+    clf = xgb.XGBClassifier(n_jobs=2)
+    assert clf.get_xgb_params()['n_jobs'] == 2


 def test_kwargs():
@@ -482,7 +497,7 @@ def test_kwargs_error():


 def test_sklearn_clone():
     from sklearn.base import clone
-    clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
+    clf = xgb.XGBClassifier(n_jobs=2)
     clf.n_jobs = -1
     clone(clf)
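For reference, a minimal usage sketch of the scikit-learn wrapper as it behaves with this patch applied. The dataset and parameter values are illustrative only; the sketch mirrors the updated tests above rather than introducing new behaviour.

import numpy as np
import xgboost as xgb

# Synthetic binary classification data, for illustration only.
X = np.random.RandomState(0).rand(100, 10)
y = (X[:, 0] > 0.5).astype(int)

# tree_method is now an explicit constructor argument; the deprecated
# silent, seed and nthread arguments are no longer part of the signature.
clf = xgb.XGBClassifier(n_estimators=4, tree_method='hist',
                        n_jobs=2, random_state=7)
clf.fit(X, y)

# get_xgb_params() now returns the parameters unchanged; mapping
# random_state onto the native seed parameter is handled by the
# DMLC_DECLARE_ALIAS added in generic_parameters.h.
params = clf.get_xgb_params()
assert params['random_state'] == 7
assert params['n_jobs'] == 2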