This commit is contained in:
Guolin Ke 2016-11-30 11:40:31 +08:00
Родитель 44fcf16c6f
Коммит 452b41f015
3 изменённых файлов: 114 добавлений и 7 удалений

Просмотреть файл

@ -208,7 +208,7 @@ def train(params, train_data, num_boost_round=100,
booster.best_iteration = int(booster.attr('best_iteration'))
else:
booster.best_iteration = num_boost_round - 1
return num_boost_round
return booster
class CVBooster(object):

Просмотреть файл

@ -207,6 +207,9 @@ class LGBMModel(LGBMModelBase):
evals_result = {}
params = self.get_params()
if other_params is not None:
params.update(other_params)
if callable(self.objective):
fobj = _objective_decorator(self.objective)
params["objective"] = "None"
@ -215,14 +218,13 @@ class LGBMModel(LGBMModelBase):
fobj = None
if callable(eval_metric):
feval = eval_metric
else:
elif is_str(eval_metric) or isinstance(eval_metric, list):
feval = None
params.update({'metric': eval_metric})
else:
feval = None
feval = eval_metric if callable(eval_metric) else None
if other_params is not None:
params.update(other_params)
self._Booster = train(params, (X, y),
self.n_estimators, valid_datas=eval_set,
early_stopping_rounds=early_stopping_rounds,
@ -296,10 +298,12 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
other_params = {}
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying LGBM instance
other_params["objective"] = "multiclass"
if not callable(self.objective):
self.objective = "multiclass"
other_params['num_class'] = self.n_classes_
else:
other_params["objective"] = "binary"
if not callable(self.objective):
self.objective = "binary"
self._le = LGBMLabelEncoder().fit(y)
training_labels = self._le.transform(y)

Просмотреть файл

@ -0,0 +1,103 @@
import numpy as np
import random
import lightgbm as lgb
rng = np.random.RandomState(2016)
def test_binary_classification():
from sklearn import datasets, metrics, model_selection
X, y = datasets.make_classification(n_samples=10000, n_features=100)
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
lgb_model = lgb.LGBMClassifier().fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss')
from sklearn.datasets import load_digits
digits = load_digits(2)
y = digits['target']
X = digits['data']
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
lgb_model = lgb.LGBMClassifier().fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss')
preds = lgb_model.predict(x_test)
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != y_test[i]) / float(len(preds))
assert err < 0.1
def test_multiclass_classification():
from sklearn.datasets import load_iris
from sklearn import datasets, metrics, model_selection
def check_pred(preds, labels):
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
assert err < 0.7
X, y = datasets.make_classification(n_samples=10000, n_features=100, n_classes=4, n_informative=3)
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
lgb_model = lgb.LGBMClassifier().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='multi_logloss')
preds = lgb_model.predict(x_test)
check_pred(preds, y_test)
def test_regression():
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_boston
from sklearn.cross_validation import KFold
from sklearn import datasets, metrics, model_selection
boston = load_boston()
y = boston['target']
X = boston['data']
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2')
preds = lgb_model.predict(x_test)
assert mean_squared_error(preds, y_test) < 30
def test_regression_with_custom_objective():
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_boston
from sklearn.cross_validation import KFold
from sklearn import datasets, metrics, model_selection
def objective_ls(y_true, y_pred):
grad = (y_pred - y_true)
hess = np.ones(len(y_true))
return grad, hess
boston = load_boston()
y = boston['target']
X = boston['data']
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2')
preds = lgb_model.predict(x_test)
assert mean_squared_error(preds, y_test) < 30
def test_binary_classification_with_custom_objective():
from sklearn import datasets, metrics, model_selection
def logregobj(y_true, y_pred):
y_pred = 1.0 / (1.0 + np.exp(-y_pred))
grad = y_pred - y_true
hess = y_pred * (1.0 - y_pred)
return grad, hess
X, y = datasets.make_classification(n_samples=10000, n_features=100)
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss')
from sklearn.datasets import load_digits
digits = load_digits(2)
y = digits['target']
X = digits['data']
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss')
preds = lgb_model.predict(x_test)
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != y_test[i]) / float(len(preds))
assert err < 0.1
test_binary_classification()
test_multiclass_classification()
test_regression()
test_regression_with_custom_objective()
test_binary_classification_with_custom_objective()