NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. The hunks for engine.py and sklearn.py keep
their original tokens, but the indentation and blank-line placement of their
context lines are a best guess -- verify them against the target files (or
apply with `git apply --ignore-whitespace`). The new test file was revised, so
its `index` blob id no longer matches the content below.

diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index 65eb598ea..5a1341e58 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -208,7 +208,7 @@ def train(params, train_data, num_boost_round=100,
         booster.best_iteration = int(booster.attr('best_iteration'))
     else:
         booster.best_iteration = num_boost_round - 1
-    return num_boost_round
+    return booster
 
 
 class CVBooster(object):
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index b73a9be21..3d996a128 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -207,6 +207,9 @@ class LGBMModel(LGBMModelBase):
         evals_result = {}
         params = self.get_params()
 
+        if other_params is not None:
+            params.update(other_params)
+
         if callable(self.objective):
             fobj = _objective_decorator(self.objective)
             params["objective"] = "None"
@@ -215,14 +218,13 @@ class LGBMModel(LGBMModelBase):
             fobj = None
         if callable(eval_metric):
             feval = eval_metric
-        else:
+        elif is_str(eval_metric) or isinstance(eval_metric, list):
             feval = None
             params.update({'metric': eval_metric})
+        else:
+            feval = None
         feval = eval_metric if callable(eval_metric) else None
 
-        if other_params is not None:
-            params.update(other_params)
-
         self._Booster = train(params, (X, y),
                               self.n_estimators, valid_datas=eval_set,
                               early_stopping_rounds=early_stopping_rounds,
@@ -296,10 +298,12 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
         other_params = {}
         if self.n_classes_ > 2:
             # Switch to using a multiclass objective in the underlying LGBM instance
-            other_params["objective"] = "multiclass"
+            if not callable(self.objective):
+                self.objective = "multiclass"
             other_params['num_class'] = self.n_classes_
         else:
-            other_params["objective"] = "binary"
+            if not callable(self.objective):
+                self.objective = "binary"
 
         self._le = LGBMLabelEncoder().fit(y)
         training_labels = self._le.transform(y)
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
new file mode 100644
index 000000000..3c57deec2
--- /dev/null
+++ b/tests/python_package_test/test_sklearn.py
@@ -0,0 +1,122 @@
+import numpy as np
+import random
+import lightgbm as lgb
+
+
+rng = np.random.RandomState(2016)
+
+
+def _error_rate(pred_labels, true_labels):
+    """Return the fraction of predicted labels that differ from the truth."""
+    return float(np.mean(np.asarray(pred_labels) != np.asarray(true_labels)))
+
+
+def _binary_error_rate(preds, labels):
+    """Return the error rate of binary ``preds`` thresholded at 0.5."""
+    return _error_rate((np.asarray(preds) > 0.5).astype(int), labels)
+
+
+def _fit_with_eval(model, x_train, x_test, y_train, y_test, eval_metric):
+    """Fit ``model`` with the train and test splits as evaluation sets."""
+    return model.fit(x_train, y_train,
+                     eval_set=[[x_train, y_train], (x_test, y_test)],
+                     eval_metric=eval_metric)
+
+
+def test_binary_classification():
+    """LGBMClassifier stays under 10% error on the two-class digits data."""
+    from sklearn import datasets, model_selection
+
+    # Smoke-test fitting on a larger synthetic problem first.
+    X, y = datasets.make_classification(n_samples=10000, n_features=100)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    _fit_with_eval(lgb.LGBMClassifier(), x_train, x_test, y_train, y_test, 'binary_logloss')
+
+    # Pass n_class by keyword: newer scikit-learn rejects it positionally.
+    digits = datasets.load_digits(n_class=2)
+    y = digits['target']
+    X = digits['data']
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
+    lgb_model = _fit_with_eval(lgb.LGBMClassifier(), x_train, x_test, y_train, y_test, 'binary_logloss')
+    preds = lgb_model.predict(x_test)
+    assert _binary_error_rate(preds, y_test) < 0.1
+
+
+def test_multiclass_classification():
+    """LGBMClassifier keeps the four-class error rate under 0.7."""
+    from sklearn import datasets, model_selection
+
+    X, y = datasets.make_classification(n_samples=10000, n_features=100, n_classes=4, n_informative=3)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    lgb_model = _fit_with_eval(lgb.LGBMClassifier(), x_train, x_test, y_train, y_test, 'multi_logloss')
+    preds = lgb_model.predict(x_test)
+    # Compare class labels directly: thresholding at 0.5 would collapse the
+    # four classes onto {0, 1} and always count classes 2 and 3 as errors.
+    assert _error_rate(preds, y_test) < 0.7
+
+
+def test_regression():
+    """LGBMRegressor reaches a held-out MSE below 30 on the Boston data."""
+    from sklearn import model_selection
+    from sklearn.datasets import load_boston
+    from sklearn.metrics import mean_squared_error
+
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    lgb_model = _fit_with_eval(lgb.LGBMRegressor(), x_train, x_test, y_train, y_test, 'l2')
+    preds = lgb_model.predict(x_test)
+    assert mean_squared_error(preds, y_test) < 30
+
+
+def test_regression_with_custom_objective():
+    """LGBMRegressor accepts a callable least-squares objective."""
+    from sklearn import model_selection
+    from sklearn.datasets import load_boston
+    from sklearn.metrics import mean_squared_error
+
+    def objective_ls(y_true, y_pred):
+        grad = (y_pred - y_true)
+        hess = np.ones(len(y_true))
+        return grad, hess
+
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    lgb_model = _fit_with_eval(lgb.LGBMRegressor(objective=objective_ls), x_train, x_test, y_train, y_test, 'l2')
+    preds = lgb_model.predict(x_test)
+    assert mean_squared_error(preds, y_test) < 30
+
+
+def test_binary_classification_with_custom_objective():
+    """LGBMClassifier accepts a callable logistic objective."""
+    from sklearn import datasets, model_selection
+
+    def logregobj(y_true, y_pred):
+        y_pred = 1.0 / (1.0 + np.exp(-y_pred))
+        grad = y_pred - y_true
+        hess = y_pred * (1.0 - y_pred)
+        return grad, hess
+
+    X, y = datasets.make_classification(n_samples=10000, n_features=100)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    _fit_with_eval(lgb.LGBMClassifier(objective=logregobj), x_train, x_test, y_train, y_test, 'binary_logloss')
+
+    digits = datasets.load_digits(n_class=2)
+    y = digits['target']
+    X = digits['data']
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
+    lgb_model = _fit_with_eval(lgb.LGBMClassifier(objective=logregobj), x_train, x_test, y_train, y_test, 'binary_logloss')
+    preds = lgb_model.predict(x_test)
+    assert _binary_error_rate(preds, y_test) < 0.1
+
+
+if __name__ == '__main__':
+    # Allow running the file as a script without executing tests on import.
+    test_binary_classification()
+    test_multiclass_classification()
+    test_regression()
+    test_regression_with_custom_objective()
+    test_binary_classification_with_custom_objective()