зеркало из https://github.com/microsoft/LightGBM.git
refine early stopping and add a test case (#369)
This commit is contained in:
Родитель
1141ed9de0
Коммит
6ed335df29
|
@ -201,7 +201,6 @@ def early_stopping(stopping_rounds, verbose=True):
|
|||
env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
|
||||
best_msg[i] = best_msg_buffer
|
||||
elif env.iteration - best_iter[i] >= stopping_rounds:
|
||||
env.model.set_attr(best_iteration=str(best_iter[i]))
|
||||
if verbose:
|
||||
print('Early stopping, best iteration is:\n' + best_msg[i])
|
||||
raise EarlyStopException(best_iter[i])
|
||||
|
|
|
@ -165,6 +165,7 @@ def train(params, train_set, num_boost_round=100,
|
|||
booster.set_train_data_name(train_data_name)
|
||||
for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
|
||||
booster.add_valid(valid_set, name_valid_set)
|
||||
booster.best_iteration = -1
|
||||
|
||||
"""start training"""
|
||||
for i in range_(init_iteration, init_iteration + num_boost_round):
|
||||
|
@ -192,12 +193,9 @@ def train(params, train_set, num_boost_round=100,
|
|||
begin_iteration=init_iteration,
|
||||
end_iteration=init_iteration + num_boost_round,
|
||||
evaluation_result_list=evaluation_result_list))
|
||||
except callback.EarlyStopException:
|
||||
except callback.EarlyStopException as earlyStopException:
|
||||
booster.best_iteration = earlyStopException.best_iteration + 1
|
||||
break
|
||||
if booster.attr('best_iteration') is not None:
|
||||
booster.best_iteration = int(booster.attr('best_iteration')) + 1
|
||||
else:
|
||||
booster.best_iteration = -1
|
||||
return booster
|
||||
|
||||
|
||||
|
@ -205,6 +203,7 @@ class CVBooster(object):
|
|||
""""Auxiliary data struct to hold all boosters of CV."""
|
||||
def __init__(self):
|
||||
self.boosters = []
|
||||
self.best_iteration = -1
|
||||
|
||||
def append(self, booster):
|
||||
"""add a booster to CVBooster"""
|
||||
|
@ -408,8 +407,9 @@ def cv(params, train_set, num_boost_round=10,
|
|||
begin_iteration=0,
|
||||
end_iteration=num_boost_round,
|
||||
evaluation_result_list=res))
|
||||
except callback.EarlyStopException as e:
|
||||
except callback.EarlyStopException as earlyStopException:
|
||||
cvfolds.best_iteration = earlyStopException.best_iteration + 1
|
||||
for k in results:
|
||||
results[k] = results[k][:e.best_iteration + 1]
|
||||
results[k] = results[k][:cvfolds.best_iteration]
|
||||
break
|
||||
return dict(results)
|
||||
|
|
|
@ -86,6 +86,32 @@ class TestEngine(unittest.TestCase):
|
|||
self.assertLess(ret, 0.2)
|
||||
self.assertAlmostEqual(min(evals_result['eval']['multi_logloss']), ret, places=5)
|
||||
|
||||
def test_early_stopping(self):
|
||||
X_y = load_breast_cancer(True)
|
||||
params = {
|
||||
'objective': 'binary',
|
||||
'metric': 'binary_logloss',
|
||||
'verbose': -1,
|
||||
'seed': 42
|
||||
}
|
||||
X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
|
||||
lgb_train = lgb.Dataset(X_train, y_train)
|
||||
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
|
||||
# no early stopping
|
||||
gbm = lgb.train(params, lgb_train,
|
||||
num_boost_round=10,
|
||||
valid_sets=lgb_eval,
|
||||
verbose_eval=False,
|
||||
early_stopping_rounds=5)
|
||||
self.assertEqual(gbm.best_iteration, -1)
|
||||
# early stopping occurs
|
||||
gbm = lgb.train(params, lgb_train,
|
||||
num_boost_round=100,
|
||||
valid_sets=lgb_eval,
|
||||
verbose_eval=False,
|
||||
early_stopping_rounds=5)
|
||||
self.assertLessEqual(gbm.best_iteration, 100)
|
||||
|
||||
def test_continue_train_and_other(self):
|
||||
params = {
|
||||
'objective': 'regression',
|
||||
|
|
Загрузка…
Ссылка в новой задаче