This commit is contained in:
wxchan 2017-04-15 19:04:35 +08:00 committed by Guolin Ke
Parent 9224a9d125
Commit 45c1c6e8c1
5 changed files with 22 additions and 12 deletions

View file

@@ -1179,6 +1179,7 @@ class Booster(object):
self.__train_data_name = "training"
self.__attr = {}
self.best_iteration = -1
self.best_score = {}
params = {} if params is None else params
if silent:
params["verbose"] = 0

View file

@@ -15,9 +15,10 @@ class EarlyStopException(Exception):
best_iteration : int
The best iteration stopped.
"""
def __init__(self, best_iteration, best_score):
    """Create an EarlyStopException.

    Parameters
    ----------
    best_iteration : int
        The best iteration stopped.
    best_score : list
        Evaluation results at the best iteration
        (list of (dataset_name, eval_name, score, is_higher_better) tuples —
        assumed from the callback that raises this; confirm against caller).
    """
    super(EarlyStopException, self).__init__()
    self.best_iteration = best_iteration
    self.best_score = best_score
# Callback environment used by callbacks
@@ -162,7 +163,7 @@ def early_stopping(stopping_rounds, verbose=True):
"""
best_score = []
best_iter = []
best_msg = []
best_score_list = []
cmp_op = []
def init(env):
@@ -176,8 +177,7 @@ def early_stopping(stopping_rounds, verbose=True):
for eval_ret in env.evaluation_result_list:
best_iter.append(0)
if verbose:
best_msg.append(None)
best_score_list.append(None)
if eval_ret[3]:
best_score.append(float('-inf'))
cmp_op.append(gt)
@@ -189,20 +189,16 @@ def early_stopping(stopping_rounds, verbose=True):
"""internal function"""
if not cmp_op:
init(env)
best_msg_buffer = None
for i in range_(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2]
if cmp_op[i](score, best_score[i]):
best_score[i] = score
best_iter[i] = env.iteration
if verbose:
if not best_msg_buffer:
best_msg_buffer = '[%d]\t%s' % (
env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
best_msg[i] = best_msg_buffer
best_score_list[i] = env.evaluation_result_list
elif env.iteration - best_iter[i] >= stopping_rounds:
if verbose:
print('Early stopping, best iteration is:\n' + best_msg[i])
raise EarlyStopException(best_iter[i])
print('Early stopping, best iteration is:\n[%d]\t%s' % (
best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
raise EarlyStopException(best_iter[i], best_score_list[i])
callback.order = 30
return callback

View file

@@ -195,7 +195,11 @@ def train(params, train_set, num_boost_round=100,
evaluation_result_list=evaluation_result_list))
except callback.EarlyStopException as earlyStopException:
booster.best_iteration = earlyStopException.best_iteration + 1
evaluation_result_list = earlyStopException.best_score
break
booster.best_score = collections.defaultdict(dict)
for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score
return booster

View file

@@ -273,6 +273,7 @@ class LGBMModel(LGBMModelBase):
self._Booster = None
self.evals_result = None
self.best_iteration = -1
self.best_score = {}
if callable(self.objective):
self.fobj = _objective_function_wrapper(self.objective)
else:
@@ -414,6 +415,7 @@ class LGBMModel(LGBMModelBase):
if early_stopping_rounds is not None:
self.best_iteration = self._Booster.best_iteration
self.best_score = self._Booster.best_score
return self
def predict(self, X, raw_score=False, num_iteration=0):

View file

@@ -96,20 +96,27 @@ class TestEngine(unittest.TestCase):
X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
valid_set_name = 'valid_set'
# no early stopping
gbm = lgb.train(params, lgb_train,
num_boost_round=10,
valid_sets=lgb_eval,
valid_names=valid_set_name,
verbose_eval=False,
early_stopping_rounds=5)
self.assertEqual(gbm.best_iteration, -1)
self.assertIn(valid_set_name, gbm.best_score)
self.assertIn('binary_logloss', gbm.best_score[valid_set_name])
# early stopping occurs
gbm = lgb.train(params, lgb_train,
num_boost_round=100,
valid_sets=lgb_eval,
valid_names=valid_set_name,
verbose_eval=False,
early_stopping_rounds=5)
self.assertLessEqual(gbm.best_iteration, 100)
self.assertIn(valid_set_name, gbm.best_score)
self.assertIn('binary_logloss', gbm.best_score[valid_set_name])
def test_continue_train_and_other(self):
params = {