[python] bring pandas support to the sklearn wrapper back (#904)

* added test for sklearn handle categorical features

* use raw X, y in sklearn wrapper in case of pandas.DataFrame

* fixed probs
This commit is contained in:
Nikita Titov 2017-09-19 09:55:03 +03:00 коммит произвёл Guolin Ke
Родитель 7689a4d064
Коммит 0350a9a6ff
3 изменённых файлов: 54 добавлений и 9 удалений

Просмотреть файл

@ -5,6 +5,11 @@ from __future__ import absolute_import
import numpy as np
import warnings
try:
import pandas as pd
_IS_PANDAS_INSTALLED = True
except ImportError:
_IS_PANDAS_INSTALLED = False
from .basic import Dataset, LightGBMError
from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, LGBMDeprecated,
@ -332,7 +337,7 @@ class LGBMModel(_LGBMModelBase):
categorical_feature : list of strings or int, or 'auto', optional (default="auto")
Categorical features.
If list of int, interpreted as indices.
If list of strings, interpreted as feature names (need to specify feature_name as well).
If list of strings, interpreted as feature names (need to specify ``feature_name`` as well).
If 'auto' and data is pandas DataFrame, pandas categorical columns are used.
callbacks : list of callback functions or None, optional (default=None)
List of callback functions that are applied at each iteration.
@ -407,8 +412,10 @@ class LGBMModel(_LGBMModelBase):
feval = None
params['metric'] = eval_metric
X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(X, y, sample_weight)
if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame):
X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(X, y, sample_weight)
self._n_features = X.shape[1]
def _construct_dataset(X, y, sample_weight, init_score, group, params):
@ -482,7 +489,8 @@ class LGBMModel(_LGBMModelBase):
"""
if self._n_features is None:
raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1]
if self._n_features != n_features:
raise ValueError("Number of features of the model must "
@ -508,7 +516,8 @@ class LGBMModel(_LGBMModelBase):
"""
if self._n_features is None:
raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1]
if self._n_features != n_features:
raise ValueError("Number of features of the model must "
@ -686,7 +695,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
"""
if self._n_features is None:
raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
if not _IS_PANDAS_INSTALLED or not isinstance(X, pd.DataFrame):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1]
if self._n_features != n_features:
raise ValueError("Number of features of the model must "

Просмотреть файл

@ -78,7 +78,7 @@ class TestEngine(unittest.TestCase):
self.assertLess(ret, 0.25)
self.assertAlmostEqual(evals_result['valid_0']['binary_logloss'][-1], ret, places=5)
def test_regreesion(self):
def test_regression(self):
X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
@ -444,7 +444,6 @@ class TestEngine(unittest.TestCase):
gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=['A', 'B', 'C', 'D'])
pred3 = list(gbm3.predict(X_test))
lgb_train = lgb.Dataset(X, y)
gbm3.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = list(gbm4.predict(X_test))

Просмотреть файл

@ -19,6 +19,11 @@ try:
sklearn_at_least_019 = True
except ImportError:
sklearn_at_least_019 = False
try:
import pandas as pd
IS_PANDAS_INSTALLED = True
except ImportError:
IS_PANDAS_INSTALLED = False
def multi_error(y_true, y_pred):
@ -40,7 +45,7 @@ class TestSklearn(unittest.TestCase):
self.assertLess(ret, 0.15)
self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['binary_logloss'][gbm.best_iteration_ - 1], places=5)
def test_regreesion(self):
def test_regression(self):
X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
gbm = lgb.LGBMRegressor(n_estimators=50, silent=True)
@ -194,3 +199,34 @@ class TestSklearn(unittest.TestCase):
check(name, estimator)
except SkipTest as message:
warnings.warn(message, SkipTestWarning)
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed')
def test_pandas_categorical(self):
    """Check that LGBMClassifier handles pandas categorical columns.

    Every column of the train and test frames is cast to the pandas
    ``category`` dtype. The classifier is then fitted four times: with the
    default ``categorical_feature``, with an index list, with a single
    column name, and with all column names. All four fits must predict the
    same labels. The last model is also saved to disk and reloaded as a
    plain ``Booster``; its predictions must match the positive-class column
    of ``predict_proba`` from the first fit.
    """
    # Local imports: only this test needs them, and the module-level import
    # block is left untouched.
    import os
    import shutil
    import tempfile

    # A private, seeded generator keeps the fixture reproducible without
    # touching the global NumPy random state used by other tests.
    rng = np.random.RandomState(42)
    X = pd.DataFrame({"A": rng.permutation(['a', 'b', 'c', 'd'] * 75),  # str
                      "B": rng.permutation([1, 2, 3] * 100),  # int
                      "C": rng.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60),  # float
                      "D": rng.permutation([True, False] * 150)})  # bool
    y = rng.permutation([0, 1] * 150)
    # The test frame deliberately differs from the train frame: column "A"
    # contains the level 'e', and "B" lacks the level 2.
    X_test = pd.DataFrame({"A": rng.permutation(['a', 'b', 'e'] * 20),
                           "B": rng.permutation([1, 3] * 30),
                           "C": rng.permutation([0.1, -0.1, 0.2, 0.2] * 15),
                           "D": rng.permutation([True, False] * 30)})
    for col in ["A", "B", "C", "D"]:
        X[col] = X[col].astype('category')
        X_test[col] = X_test[col].astype('category')

    gbm0 = lgb.sklearn.LGBMClassifier().fit(X, y)
    pred0 = list(gbm0.predict(X_test))
    gbm1 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=[0])
    pred1 = list(gbm1.predict(X_test))
    gbm2 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A'])
    pred2 = list(gbm2.predict(X_test))
    gbm3 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A', 'B', 'C', 'D'])
    pred3 = list(gbm3.predict(X_test))

    # Round-trip the model through a file inside a throwaway directory so the
    # test leaves nothing behind in the working directory, even on failure.
    model_dir = tempfile.mkdtemp()
    try:
        model_path = os.path.join(model_dir, 'categorical.model')
        gbm3.booster_.save_model(model_path)
        gbm4 = lgb.Booster(model_file=model_path)
        pred4 = list(gbm4.predict(X_test))
    finally:
        shutil.rmtree(model_dir, ignore_errors=True)

    # The reloaded Booster is compared against the positive-class column of
    # predict_proba rather than against the predicted labels.
    pred_prob = list(gbm0.predict_proba(X_test)[:, 1])
    np.testing.assert_almost_equal(pred0, pred1)
    np.testing.assert_almost_equal(pred0, pred2)
    np.testing.assert_almost_equal(pred0, pred3)
    np.testing.assert_almost_equal(pred_prob, pred4)