Mirror of https://github.com/microsoft/LightGBM.git

add more tests

Parent: 44fcf16c6f
Commit: 452b41f015
@@ -208,7 +208,7 @@ def train(params, train_data, num_boost_round=100,
         booster.best_iteration = int(booster.attr('best_iteration'))
     else:
         booster.best_iteration = num_boost_round - 1
-    return num_boost_round
+    return booster


 class CVBooster(object):
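Note: with this hunk, train() hands back the fitted Booster itself (with best_iteration populated) rather than the round count. A minimal usage sketch, assuming the engine signature visible in the hunks on this page (an (X, y) pair as training data, valid_datas, early_stopping_rounds); the variable names are illustrative:

    import lightgbm as lgb

    booster = lgb.train(params, (x_train, y_train), num_boost_round=100,
                        valid_datas=[(x_test, y_test)],
                        early_stopping_rounds=5)
    # best_iteration comes from the booster's 'best_iteration' attribute when
    # early stopping fires, and falls back to num_boost_round - 1 otherwise.
    print(booster.best_iteration)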
@@ -207,6 +207,9 @@ class LGBMModel(LGBMModelBase):
         evals_result = {}
         params = self.get_params()

+        if other_params is not None:
+            params.update(other_params)
+
         if callable(self.objective):
             fobj = _objective_decorator(self.objective)
             params["objective"] = "None"
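The other_params merge now happens right after get_params(), so per-call overrides win over the constructor defaults before the objective and metric handling below runs. A toy illustration of the dict semantics (values are hypothetical):

    params = {'n_estimators': 100, 'objective': 'binary'}        # as from self.get_params()
    params.update({'objective': 'multiclass', 'num_class': 4})   # other_params wins on conflict
    assert params['objective'] == 'multiclass'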
@@ -215,14 +218,13 @@
             fobj = None
         if callable(eval_metric):
             feval = eval_metric
-        else:
+        elif is_str(eval_metric) or isinstance(eval_metric, list):
             feval = None
+            params.update({'metric': eval_metric})
+        else:
+            feval = None
-        feval = eval_metric if callable(eval_metric) else None
-
-        if other_params is not None:
-            params.update(other_params)

         self._Booster = train(params, (X, y),
                               self.n_estimators, valid_datas=eval_set,
                               early_stopping_rounds=early_stopping_rounds,
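In short: a callable eval_metric passes straight through as feval, while a string is routed into params['metric'] for LightGBM's built-in metrics. A sketch of the string path, using a metric name taken from the new tests below (dataset setup omitted); the list form is an assumption supported only by the elif above:

    # one built-in metric by name
    lgb.LGBMRegressor().fit(x_train, y_train,
                            eval_set=[(x_test, y_test)], eval_metric='l2')

    # several built-in metrics at once, via the list branch
    lgb.LGBMRegressor().fit(x_train, y_train,
                            eval_set=[(x_test, y_test)], eval_metric=['l1', 'l2'])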
@@ -296,10 +298,12 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
         other_params = {}
         if self.n_classes_ > 2:
             # Switch to using a multiclass objective in the underlying LGBM instance
             other_params["objective"] = "multiclass"
             if not callable(self.objective):
                 self.objective = "multiclass"
             other_params['num_class'] = self.n_classes_
         else:
             other_params["objective"] = "binary"
             if not callable(self.objective):
                 self.objective = "binary"

         self._le = LGBMLabelEncoder().fit(y)
         training_labels = self._le.transform(y)
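The wrapper now picks the objective from the label cardinality unless the user supplied a callable objective of their own. A minimal sketch of the resulting behaviour, reusing the synthetic-data pattern from the new tests:

    import lightgbm as lgb
    from sklearn import datasets

    X, y = datasets.make_classification(n_samples=1000, n_features=20,
                                        n_classes=4, n_informative=3)
    clf = lgb.LGBMClassifier().fit(X, y)
    # with 4 classes, the branch above selects the multiclass objective
    assert clf.objective == 'multiclass'
    assert clf.n_classes_ == 4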
@@ -0,0 +1,103 @@
import numpy as np

import lightgbm as lgb


rng = np.random.RandomState(2016)

def test_binary_classification():
    from sklearn import datasets, model_selection
    from sklearn.datasets import load_digits

    # smoke test on a synthetic two-class problem
    X, y = datasets.make_classification(n_samples=10000, n_features=100)
    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
    lgb_model = lgb.LGBMClassifier().fit(x_train, y_train,
                                         eval_set=[(x_train, y_train), (x_test, y_test)],
                                         eval_metric='binary_logloss')

    # error-rate check on the two-class digits subset
    digits = load_digits(n_class=2)
    y = digits['target']
    X = digits['data']
    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
    lgb_model = lgb.LGBMClassifier().fit(x_train, y_train,
                                         eval_set=[(x_train, y_train), (x_test, y_test)],
                                         eval_metric='binary_logloss')
    preds = lgb_model.predict(x_test)
    err = sum(1 for i in range(len(preds))
              if int(preds[i] > 0.5) != y_test[i]) / float(len(preds))
    assert err < 0.1

def test_multiclass_classification():
    from sklearn import datasets, model_selection

    def check_pred(preds, labels):
        # misclassification rate over predicted class labels
        err = sum(1 for i in range(len(preds))
                  if preds[i] != labels[i]) / float(len(preds))
        assert err < 0.7

    X, y = datasets.make_classification(n_samples=10000, n_features=100,
                                        n_classes=4, n_informative=3)
    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
    lgb_model = lgb.LGBMClassifier().fit(x_train, y_train,
                                         eval_set=[(x_train, y_train), (x_test, y_test)],
                                         eval_metric='multi_logloss')
    preds = lgb_model.predict(x_test)
    check_pred(preds, y_test)

def test_regression():
    from sklearn.metrics import mean_squared_error
    from sklearn.datasets import load_boston
    from sklearn import model_selection

    boston = load_boston()
    y = boston['target']
    X = boston['data']
    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
    lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,
                                        eval_set=[(x_train, y_train), (x_test, y_test)],
                                        eval_metric='l2')
    preds = lgb_model.predict(x_test)
    assert mean_squared_error(y_test, preds) < 30

def test_regression_with_custom_objective():
    from sklearn.metrics import mean_squared_error
    from sklearn.datasets import load_boston
    from sklearn import model_selection

    def objective_ls(y_true, y_pred):
        # least-squares objective: gradient (y_pred - y_true), constant hessian
        grad = (y_pred - y_true)
        hess = np.ones(len(y_true))
        return grad, hess

    boston = load_boston()
    y = boston['target']
    X = boston['data']
    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
    lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(
        x_train, y_train,
        eval_set=[(x_train, y_train), (x_test, y_test)],
        eval_metric='l2')
    preds = lgb_model.predict(x_test)
    assert mean_squared_error(y_test, preds) < 30

def test_binary_classification_with_custom_objective():
    from sklearn import datasets, model_selection
    from sklearn.datasets import load_digits

    def logregobj(y_true, y_pred):
        # logistic loss on raw scores: p = sigmoid(score), grad = p - y, hess = p * (1 - p)
        y_pred = 1.0 / (1.0 + np.exp(-y_pred))
        grad = y_pred - y_true
        hess = y_pred * (1.0 - y_pred)
        return grad, hess

    X, y = datasets.make_classification(n_samples=10000, n_features=100)
    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
    lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(
        x_train, y_train,
        eval_set=[(x_train, y_train), (x_test, y_test)],
        eval_metric='binary_logloss')

    digits = load_digits(n_class=2)
    y = digits['target']
    X = digits['data']
    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
    lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(
        x_train, y_train,
        eval_set=[(x_train, y_train), (x_test, y_test)],
        eval_metric='binary_logloss')
    preds = lgb_model.predict(x_test)
    err = sum(1 for i in range(len(preds))
              if int(preds[i] > 0.5) != y_test[i]) / float(len(preds))
    assert err < 0.1

if __name__ == '__main__':
    test_binary_classification()
    test_multiclass_classification()
    test_regression()
    test_regression_with_custom_objective()
    test_binary_classification_with_custom_objective()