зеркало из https://github.com/microsoft/LightGBM.git
add best score (#413)
This commit is contained in:
Родитель
9224a9d125
Коммит
45c1c6e8c1
|
@@ -1179,6 +1179,7 @@ class Booster(object):
|
|||
self.__train_data_name = "training"
|
||||
self.__attr = {}
|
||||
self.best_iteration = -1
|
||||
+ self.best_score = {}
|
||||
params = {} if params is None else params
|
||||
if silent:
|
||||
params["verbose"] = 0
|
||||
|
|
|
@@ -15,9 +15,10 @@ class EarlyStopException(Exception):
|
|||
best_iteration : int
|
||||
The best iteration stopped.
|
||||
"""
|
||||
- def __init__(self, best_iteration):
|
||||
+ def __init__(self, best_iteration, best_score):
|
||||
super(EarlyStopException, self).__init__()
|
||||
self.best_iteration = best_iteration
|
||||
+ self.best_score = best_score
|
||||
|
||||
|
||||
# Callback environment used by callbacks
|
||||
|
@@ -162,7 +163,7 @@ def early_stopping(stopping_rounds, verbose=True):
|
|||
"""
|
||||
best_score = []
|
||||
best_iter = []
|
||||
- best_msg = []
|
||||
+ best_score_list = []
|
||||
cmp_op = []
|
||||
|
||||
def init(env):
|
||||
|
@@ -176,8 +177,7 @@ def early_stopping(stopping_rounds, verbose=True):
|
|||
|
||||
for eval_ret in env.evaluation_result_list:
|
||||
best_iter.append(0)
|
||||
- if verbose:
|
||||
- best_msg.append(None)
|
||||
+ best_score_list.append(None)
|
||||
if eval_ret[3]:
|
||||
best_score.append(float('-inf'))
|
||||
cmp_op.append(gt)
|
||||
|
@@ -189,20 +189,16 @@ def early_stopping(stopping_rounds, verbose=True):
|
|||
"""internal function"""
|
||||
if not cmp_op:
|
||||
init(env)
|
||||
- best_msg_buffer = None
|
||||
for i in range_(len(env.evaluation_result_list)):
|
||||
score = env.evaluation_result_list[i][2]
|
||||
if cmp_op[i](score, best_score[i]):
|
||||
best_score[i] = score
|
||||
best_iter[i] = env.iteration
|
||||
- if verbose:
|
||||
- if not best_msg_buffer:
|
||||
- best_msg_buffer = '[%d]\t%s' % (
|
||||
- env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
|
||||
- best_msg[i] = best_msg_buffer
|
||||
+ best_score_list[i] = env.evaluation_result_list
|
||||
elif env.iteration - best_iter[i] >= stopping_rounds:
|
||||
if verbose:
|
||||
- print('Early stopping, best iteration is:\n' + best_msg[i])
|
||||
- raise EarlyStopException(best_iter[i])
|
||||
+ print('Early stopping, best iteration is:\n[%d]\t%s' % (
|
||||
+ best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
|
||||
+ raise EarlyStopException(best_iter[i], best_score_list[i])
|
||||
callback.order = 30
|
||||
return callback
|
||||
|
|
|
@@ -195,7 +195,11 @@ def train(params, train_set, num_boost_round=100,
|
|||
evaluation_result_list=evaluation_result_list))
|
||||
except callback.EarlyStopException as earlyStopException:
|
||||
booster.best_iteration = earlyStopException.best_iteration + 1
|
||||
+ evaluation_result_list = earlyStopException.best_score
|
||||
break
|
||||
+ booster.best_score = collections.defaultdict(dict)
|
||||
+ for dataset_name, eval_name, score, _ in evaluation_result_list:
|
||||
+ booster.best_score[dataset_name][eval_name] = score
|
||||
return booster
|
||||
|
||||
|
||||
|
|
|
@@ -273,6 +273,7 @@ class LGBMModel(LGBMModelBase):
|
|||
self._Booster = None
|
||||
self.evals_result = None
|
||||
self.best_iteration = -1
|
||||
+ self.best_score = {}
|
||||
if callable(self.objective):
|
||||
self.fobj = _objective_function_wrapper(self.objective)
|
||||
else:
|
||||
|
@@ -414,6 +415,7 @@ class LGBMModel(LGBMModelBase):
|
|||
|
||||
if early_stopping_rounds is not None:
|
||||
self.best_iteration = self._Booster.best_iteration
|
||||
+ self.best_score = self._Booster.best_score
|
||||
return self
|
||||
|
||||
def predict(self, X, raw_score=False, num_iteration=0):
|
||||
|
|
|
@@ -96,20 +96,27 @@ class TestEngine(unittest.TestCase):
|
|||
X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
|
||||
lgb_train = lgb.Dataset(X_train, y_train)
|
||||
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
|
||||
+ valid_set_name = 'valid_set'
|
||||
# no early stopping
|
||||
gbm = lgb.train(params, lgb_train,
|
||||
num_boost_round=10,
|
||||
valid_sets=lgb_eval,
|
||||
+ valid_names=valid_set_name,
|
||||
verbose_eval=False,
|
||||
early_stopping_rounds=5)
|
||||
self.assertEqual(gbm.best_iteration, -1)
|
||||
+ self.assertIn(valid_set_name, gbm.best_score)
|
||||
+ self.assertIn('binary_logloss', gbm.best_score[valid_set_name])
|
||||
# early stopping occurs
|
||||
gbm = lgb.train(params, lgb_train,
|
||||
num_boost_round=100,
|
||||
valid_sets=lgb_eval,
|
||||
+ valid_names=valid_set_name,
|
||||
verbose_eval=False,
|
||||
early_stopping_rounds=5)
|
||||
self.assertLessEqual(gbm.best_iteration, 100)
|
||||
+ self.assertIn(valid_set_name, gbm.best_score)
|
||||
+ self.assertIn('binary_logloss', gbm.best_score[valid_set_name])
|
||||
|
||||
def test_continue_train_and_other(self):
|
||||
params = {
|
||||
|
|
Загрузка…
Ссылка в новой задаче