Mirror of https://github.com/microsoft/LightGBM.git
[python] Bug fix for first_metric_only on early stopping. (#2209)
* Bug fix for first_metric_only if the first metric is a train metric.
* Update bug fix for feval issue.
* Disable feval for first_metric_only.
* Additional test items.
* Fix wrong assertEqual settings & formatting.
* Change dataset of test.
* Fix random seed for test.
* Modify assumed test result due to different sklearn version between CI and local.
* Remove f-string.
* Use a variable for the assumed test result in the test.
* Fix flake8 error.
* Modifying in accordance with review comments.
* Modifying for pylint.
* simplified tests
* Deleting error criteria `if eval_metric is None`.
* Delete test items of classification.
* Simplifying if condition.
* Applying first_metric_only for sklearn wrapper.
* Modifying test_sklearn for conforming to Python 2.x.
* Fix flake8 error.
* Additional fix for sklearn and add tests.
* Bug fix and add test cases.
* some refactoring
* fixed lint
* fixed lint
* Fix duplicated metric scores to pass the test.
* Fix the case when first_metric_only is not in params.
* Converting metric aliases.
* Add comment.
* Modify comment for pylint.
* Modify comment for pydocstyle.
* Using split test set for two eval_set.
* added test case for metric aliases and length checks
* minor style fixes
* fixed rmse name and alias position
* Fix the case metric=[].
* Fix using env.model._train_data_name.
* Fix wrong test condition.
* Move initial process to _init() func.
* Modify test setting for test_sklearn & training data matching in callback.py.
* test_sklearn.py -> a test case for training was wrong, so fixed it.
* callback.py -> the if-statement condition for detecting the test dataset was wrong, so fixed it.
* Support composite-name metrics.
* Remove metric check process & reduce redundant test cases. Since #2273 fixed not only the order of metrics in cpp, the metric check process in callback.py is removed.
* Revised according to the points raised in review.
* increased code readability
* Fix the issue of the order of validation sets.
* Changing to OrderedDict from defaultdict for score result.
* added missing check in cv function for first_metric_only and feval co-occurrence
* keep order only for metrics but not for datasets in best_score
* move OrderedDict initialization to init phase
* fixed minor printing issues
* move first metric detection to init phase so the split can be performed without checks
* split only once during callback
* removed excess code
* fixed typo in variable name and squashed ifs
* use setdefault
* hotfix
* fixed failing test
* refined tests
* refined sklearn test
* Making "feval" effective on early stopping.
* allow feval and first_metric_only for cv
* removed unused code
* added tests for feval
* fixed printing
* add note about whitespaces in feval name
* Modifying final iteration process in case the valid set is the training data.
This commit is contained in:
Parent: 939a02482b
Commit: 84754399a6
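A minimal usage sketch of the behaviour this patch settles (the synthetic data and parameter values below are illustrative, not taken from the patch): with `first_metric_only=True` in `params`, early stopping tracks only the first metric in `metric`, while the remaining metrics and any `feval` results are still evaluated and recorded, and metrics computed on the training data no longer trigger early stopping.

import numpy as np
import lightgbm as lgb

# Illustrative random regression data (not from the patch).
rng = np.random.RandomState(42)
X, y = rng.rand(500, 10), rng.rand(500)
train_set = lgb.Dataset(X[:400], y[:400])
valid_set = lgb.Dataset(X[400:], y[400:], reference=train_set)

def constant_metric(preds, dataset):
    # Custom metric: (name without whitespaces, value, is_higher_better).
    return 'constant_metric', 0.0, False

params = {
    'objective': 'regression',
    'metric': ['l1', 'l2'],     # 'l1' is the first metric
    'first_metric_only': True,  # early stopping watches 'l1' only
    'verbose': -1,
}
gbm = lgb.train(params, train_set,
                num_boost_round=100,
                valid_sets=[train_set, valid_set],  # train metrics are reported but never stop training
                feval=constant_metric,              # reported, but does not drive stopping here
                early_stopping_rounds=5,
                verbose_eval=False)
print(gbm.best_iteration)
print(gbm.best_score)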
python-package/lightgbm/basic.py
@@ -1631,7 +1631,7 @@ class Booster(object):
         self.handle = None
         self.network = False
         self.__need_reload_eval_info = True
-        self.__train_data_name = "training"
+        self._train_data_name = "training"
         self.__attr = {}
         self.__set_objective_to_none = False
         self.best_iteration = -1
@@ -1820,7 +1820,7 @@ class Booster(object):
         self : Booster
             Booster with set training Dataset name.
         """
-        self.__train_data_name = name
+        self._train_data_name = name
         return self

     def add_valid(self, data, name):
@@ -2047,7 +2047,7 @@ class Booster(object):
                 eval_data : Dataset
                     The evaluation dataset.
                 eval_name : string
-                    The name of evaluation function.
+                    The name of evaluation function (without whitespaces).
                 eval_result : float
                     The eval result.
                 is_higher_better : bool
@@ -2093,7 +2093,7 @@ class Booster(object):
                 train_data : Dataset
                     The training dataset.
                 eval_name : string
-                    The name of evaluation function.
+                    The name of evaluation function (without whitespaces).
                 eval_result : float
                     The eval result.
                 is_higher_better : bool
@@ -2107,7 +2107,7 @@ class Booster(object):
         result : list
             List with evaluation results.
         """
-        return self.__inner_eval(self.__train_data_name, 0, feval)
+        return self.__inner_eval(self._train_data_name, 0, feval)

     def eval_valid(self, feval=None):
         """Evaluate for validation data.
@@ -2124,7 +2124,7 @@ class Booster(object):
                 valid_data : Dataset
                     The validation dataset.
                 eval_name : string
-                    The name of evaluation function.
+                    The name of evaluation function (without whitespaces).
                 eval_result : float
                     The eval result.
                 is_higher_better : bool
python-package/lightgbm/callback.py
@@ -89,12 +89,13 @@ def record_evaluation(eval_result):
         The callback that records the evaluation history into the passed dictionary.
     """
     if not isinstance(eval_result, dict):
-        raise TypeError('Eval_result should be a dictionary')
+        raise TypeError('eval_result should be a dictionary')
     eval_result.clear()

     def _init(env):
-        for data_name, _, _, _ in env.evaluation_result_list:
-            eval_result.setdefault(data_name, collections.defaultdict(list))
+        for data_name, eval_name, _, _ in env.evaluation_result_list:
+            eval_result.setdefault(data_name, collections.OrderedDict())
+            eval_result[data_name].setdefault(eval_name, [])

     def _callback(env):
         if not eval_result:
@@ -132,7 +133,7 @@ def reset_parameter(**kwargs):
             if key in ['num_class', 'num_classes',
                        'boosting', 'boost', 'boosting_type',
                        'metric', 'metrics', 'metric_types']:
-                raise RuntimeError("cannot reset {} during training".format(repr(key)))
+                raise RuntimeError("Cannot reset {} during training".format(repr(key)))
             if isinstance(value, list):
                 if len(value) != env.end_iteration - env.begin_iteration:
                     raise ValueError("Length of list {} has to equal to 'num_boost_round'."
@@ -182,6 +183,7 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
     best_score_list = []
     cmp_op = []
     enabled = [True]
+    first_metric = ['']

     def _init(env):
         enabled[0] = not any((boost_alias in env.params
@@ -196,9 +198,11 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
                              'at least one dataset and eval metric is required for evaluation')

         if verbose:
-            msg = "Training until validation scores don't improve for {} rounds."
+            msg = "Training until validation scores don't improve for {} rounds"
             print(msg.format(stopping_rounds))

+        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
+        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
         for eval_ret in env.evaluation_result_list:
             best_iter.append(0)
             best_score_list.append(None)
@@ -209,6 +213,15 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
                 best_score.append(float('inf'))
                 cmp_op.append(lt)

+    def _final_iteration_check(env, eval_name_splitted, i):
+        if env.iteration == env.end_iteration - 1:
+            if verbose:
+                print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
+                    best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
+                if first_metric_only:
+                    print("Evaluated only: {}".format(eval_name_splitted[-1]))
+            raise EarlyStopException(best_iter[i], best_score_list[i])
+
     def _callback(env):
         if not cmp_op:
             _init(env)
@@ -220,17 +233,21 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
                 best_score[i] = score
                 best_iter[i] = env.iteration
                 best_score_list[i] = env.evaluation_result_list
+            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
+            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
+            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
+                continue  # use only the first metric for early stopping
+            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
+                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
+                _final_iteration_check(env, eval_name_splitted, i)
+                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - best_iter[i] >= stopping_rounds:
                 if verbose:
                     print('Early stopping, best iteration is:\n[%d]\t%s' % (
                         best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
+                    if first_metric_only:
+                        print("Evaluated only: {}".format(eval_name_splitted[-1]))
                 raise EarlyStopException(best_iter[i], best_score_list[i])
-            if env.iteration == env.end_iteration - 1:
-                if verbose:
-                    print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
-                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
-                raise EarlyStopException(best_iter[i], best_score_list[i])
-            if first_metric_only:  # the only first metric is used for early stopping
-                break
+            _final_iteration_check(env, eval_name_splitted, i)
     _callback.order = 30
     return _callback
python-package/lightgbm/engine.py
@@ -65,7 +65,7 @@ def train(params, train_set, num_boost_round=100,
                 train_data : Dataset
                     The training dataset.
                 eval_name : string
-                    The name of evaluation function.
+                    The name of evaluation function (without whitespaces).
                 eval_result : float
                     The eval result.
                 is_higher_better : bool
@@ -266,7 +266,7 @@ def train(params, train_set, num_boost_round=100,
             booster.best_iteration = earlyStopException.best_iteration + 1
             evaluation_result_list = earlyStopException.best_score
             break
-    booster.best_score = collections.defaultdict(dict)
+    booster.best_score = collections.defaultdict(collections.OrderedDict)
     for dataset_name, eval_name, score, _ in evaluation_result_list:
         booster.best_score[dataset_name][eval_name] = score
     if not keep_training_booster:
@@ -356,7 +356,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi

 def _agg_cv_result(raw_results, eval_train_metric=False):
     """Aggregate cross-validation results."""
-    cvmap = collections.defaultdict(list)
+    cvmap = collections.OrderedDict()
     metric_type = {}
     for one_result in raw_results:
         for one_line in one_result:
@@ -365,6 +365,7 @@ def _agg_cv_result(raw_results, eval_train_metric=False):
             else:
                 key = one_line[1]
             metric_type[key] = one_line[3]
+            cvmap.setdefault(key, [])
             cvmap[key].append(one_line[2])
     return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]

@@ -429,7 +430,7 @@ def cv(params, train_set, num_boost_round=100,
                 train_data : Dataset
                     The training dataset.
                 eval_name : string
-                    The name of evaluation function.
+                    The name of evaluation function (without whitespaces).
                 eval_result : float
                     The eval result.
                 is_higher_better : bool
python-package/lightgbm/sklearn.py
@@ -121,7 +121,7 @@ class _EvalFunctionWrapper(object):
             group : array-like
                 Group/query data, used for ranking task.
             eval_name : string
-                The name of evaluation function.
+                The name of evaluation function (without whitespaces).
             eval_result : float
                 The eval result.
             is_higher_better : bool
@@ -147,7 +147,7 @@ class _EvalFunctionWrapper(object):
         Returns
         -------
         eval_name : string
-            The name of evaluation function.
+            The name of evaluation function (without whitespaces).
         eval_result : float
             The eval result.
         is_higher_better : bool
@@ -464,7 +464,7 @@ class LGBMModel(_LGBMModelBase):
             group : array-like
                 Group/query data, used for ranking task.
             eval_name : string
-                The name of evaluation function.
+                The name of evaluation function (without whitespaces).
             eval_result : float
                 The eval result.
             is_higher_better : bool
@@ -524,7 +524,8 @@ class LGBMModel(_LGBMModelBase):
         # concatenate metric from params (or default if not provided in params) and eval_metric
         original_metric = [original_metric] if isinstance(original_metric, (string_type, type(None))) else original_metric
         eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric
-        params['metric'] = set(original_metric + eval_metric)
+        params['metric'] = [e for e in eval_metric if e not in original_metric] + original_metric
+        params['metric'] = [metric for metric in params['metric'] if metric is not None]

         if not isinstance(X, (DataFrame, DataTable)):
             _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
tests/python_package_test/test_engine.py
@@ -494,10 +494,8 @@ class TestEngine(unittest.TestCase):
                         evals_result=evals_result,
                         init_model=init_gbm)
         ret = mean_absolute_error(y_test, gbm.predict(X_test))
-        self.assertLess(ret, 3.5)
+        self.assertLess(ret, 2.5)
         self.assertAlmostEqual(evals_result['valid_0']['l1'][-1], ret, places=5)
-        for l1, mae in zip(evals_result['valid_0']['l1'], evals_result['valid_0']['mae']):
-            self.assertAlmostEqual(l1, mae, places=5)

     def test_continue_train_multiclass(self):
         X, y = load_iris(True)
@@ -1545,17 +1543,6 @@ class TestEngine(unittest.TestCase):
         self.assertRaises(lgb.basic.LightGBMError, gbm.get_split_value_histogram, 2)

     def test_early_stopping_for_only_first_metric(self):
-        X, y = load_boston(True)
-        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
-        params = {
-            'objective': 'regression',
-            'metric': 'None',
-            'verbose': -1
-        }
-        lgb_train = lgb.Dataset(X_train, y_train)
-        lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
-
-        decreasing_generator = itertools.count(0, -1)

         def decreasing_metric(preds, train_data):
             return ('decreasing_metric', next(decreasing_generator), False)
@@ -1563,27 +1550,117 @@
         def constant_metric(preds, train_data):
             return ('constant_metric', 0.0, False)

-        # test that all metrics are checked (default behaviour)
-        gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval],
-                        feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
-                                                         constant_metric(preds, train_data)],
-                        early_stopping_rounds=5, verbose_eval=False)
-        self.assertEqual(gbm.best_iteration, 1)
+        def metrics_combination_train_regression(valid_sets, metric_list, assumed_iteration,
+                                                 first_metric_only, feval=None):
+            params = {
+                'objective': 'regression',
+                'learning_rate': 1.1,
+                'num_leaves': 10,
+                'metric': metric_list,
+                'verbose': -1,
+                'seed': 123
+            }
+            gbm = lgb.train(dict(params, first_metric_only=first_metric_only), lgb_train,
+                            num_boost_round=25, valid_sets=valid_sets, feval=feval,
+                            early_stopping_rounds=5, verbose_eval=False)
+            self.assertEqual(assumed_iteration, gbm.best_iteration)

-        # test that only the first metric is checked
-        gbm = lgb.train(dict(params, first_metric_only=True), lgb_train,
-                        num_boost_round=20, valid_sets=[lgb_eval],
-                        feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
-                                                         constant_metric(preds, train_data)],
-                        early_stopping_rounds=5, verbose_eval=False)
-        self.assertEqual(gbm.best_iteration, 20)
-        # ... change the order of metrics
-        gbm = lgb.train(dict(params, first_metric_only=True), lgb_train,
-                        num_boost_round=20, valid_sets=[lgb_eval],
-                        feval=lambda preds, train_data: [constant_metric(preds, train_data),
-                                                         decreasing_metric(preds, train_data)],
-                        early_stopping_rounds=5, verbose_eval=False)
-        self.assertEqual(gbm.best_iteration, 1)
+        def metrics_combination_cv_regression(metric_list, assumed_iteration,
+                                              first_metric_only, eval_train_metric, feval=None):
+            params = {
+                'objective': 'regression',
+                'learning_rate': 0.9,
+                'num_leaves': 10,
+                'metric': metric_list,
+                'verbose': -1,
+                'seed': 123,
+                'gpu_use_dp': True
+            }
+            ret = lgb.cv(dict(params, first_metric_only=first_metric_only),
+                         train_set=lgb_train, num_boost_round=25,
+                         stratified=False, feval=feval,
+                         early_stopping_rounds=5, verbose_eval=False,
+                         eval_train_metric=eval_train_metric)
+            self.assertEqual(assumed_iteration, len(ret[list(ret.keys())[0]]))
+
+        decreasing_generator = itertools.count(0, -1)
+        X, y = load_boston(True)
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+        X_test1, X_test2, y_test1, y_test2 = train_test_split(X_test, y_test, test_size=0.5, random_state=73)
+        lgb_train = lgb.Dataset(X_train, y_train)
+        lgb_valid1 = lgb.Dataset(X_test1, y_test1, reference=lgb_train)
+        lgb_valid2 = lgb.Dataset(X_test2, y_test2, reference=lgb_train)
+
+        iter_valid1_l1 = 3
+        iter_valid1_l2 = 14
+        iter_valid2_l1 = 2
+        iter_valid2_l2 = 15
+        self.assertEqual(len(set([iter_valid1_l1, iter_valid1_l2, iter_valid2_l1, iter_valid2_l2])), 4)
+        iter_min_l1 = min([iter_valid1_l1, iter_valid2_l1])
+        iter_min_l2 = min([iter_valid1_l2, iter_valid2_l2])
+        iter_min = min([iter_min_l1, iter_min_l2])
+        iter_min_valid1 = min([iter_valid1_l1, iter_valid1_l2])
+
+        iter_cv_l1 = 3
+        iter_cv_l2 = 17
+        self.assertEqual(len(set([iter_cv_l1, iter_cv_l2])), 2)
+        iter_cv_min = min([iter_cv_l1, iter_cv_l2])
+
+        # test for lgb.train
+        metrics_combination_train_regression(lgb_valid1, [], iter_valid1_l2, False)
+        metrics_combination_train_regression(lgb_valid1, [], iter_valid1_l2, True)
+        metrics_combination_train_regression(lgb_valid1, None, iter_valid1_l2, False)
+        metrics_combination_train_regression(lgb_valid1, None, iter_valid1_l2, True)
+        metrics_combination_train_regression(lgb_valid1, 'l2', iter_valid1_l2, True)
+        metrics_combination_train_regression(lgb_valid1, 'l1', iter_valid1_l1, True)
+        metrics_combination_train_regression(lgb_valid1, ['l2', 'l1'], iter_valid1_l2, True)
+        metrics_combination_train_regression(lgb_valid1, ['l1', 'l2'], iter_valid1_l1, True)
+        metrics_combination_train_regression(lgb_valid1, ['l2', 'l1'], iter_min_valid1, False)
+        metrics_combination_train_regression(lgb_valid1, ['l1', 'l2'], iter_min_valid1, False)
+
+        # test feval for lgb.train
+        metrics_combination_train_regression(lgb_valid1, 'None', 1, False,
+                                             feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
+                                                                              constant_metric(preds, train_data)])
+        metrics_combination_train_regression(lgb_valid1, 'None', 25, True,
+                                             feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
+                                                                              constant_metric(preds, train_data)])
+        metrics_combination_train_regression(lgb_valid1, 'None', 1, True,
+                                             feval=lambda preds, train_data: [constant_metric(preds, train_data),
+                                                                              decreasing_metric(preds, train_data)])
+
+        # test with two valid data for lgb.train
+        metrics_combination_train_regression([lgb_valid1, lgb_valid2], ['l2', 'l1'], iter_min_l2, True)
+        metrics_combination_train_regression([lgb_valid2, lgb_valid1], ['l2', 'l1'], iter_min_l2, True)
+        metrics_combination_train_regression([lgb_valid1, lgb_valid2], ['l1', 'l2'], iter_min_l1, True)
+        metrics_combination_train_regression([lgb_valid2, lgb_valid1], ['l1', 'l2'], iter_min_l1, True)
+
+        # test for lgb.cv
+        metrics_combination_cv_regression(None, iter_cv_l2, True, False)
+        metrics_combination_cv_regression('l2', iter_cv_l2, True, False)
+        metrics_combination_cv_regression('l1', iter_cv_l1, True, False)
+        metrics_combination_cv_regression(['l2', 'l1'], iter_cv_l2, True, False)
+        metrics_combination_cv_regression(['l1', 'l2'], iter_cv_l1, True, False)
+        metrics_combination_cv_regression(['l2', 'l1'], iter_cv_min, False, False)
+        metrics_combination_cv_regression(['l1', 'l2'], iter_cv_min, False, False)
+        metrics_combination_cv_regression(None, iter_cv_l2, True, True)
+        metrics_combination_cv_regression('l2', iter_cv_l2, True, True)
+        metrics_combination_cv_regression('l1', iter_cv_l1, True, True)
+        metrics_combination_cv_regression(['l2', 'l1'], iter_cv_l2, True, True)
+        metrics_combination_cv_regression(['l1', 'l2'], iter_cv_l1, True, True)
+        metrics_combination_cv_regression(['l2', 'l1'], iter_cv_min, False, True)
+        metrics_combination_cv_regression(['l1', 'l2'], iter_cv_min, False, True)
+
+        # test feval for lgb.cv
+        metrics_combination_cv_regression('None', 1, False, False,
+                                          feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
+                                                                           constant_metric(preds, train_data)])
+        metrics_combination_cv_regression('None', 25, True, False,
+                                          feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
+                                                                           constant_metric(preds, train_data)])
+        metrics_combination_cv_regression('None', 1, True, False,
+                                          feval=lambda preds, train_data: [constant_metric(preds, train_data),
+                                                                           decreasing_metric(preds, train_data)])

     def test_node_level_subcol(self):
         X, y = load_breast_cancer(True)
tests/python_package_test/test_sklearn.py
@@ -638,6 +638,113 @@ class TestSklearn(unittest.TestCase):
         gbm = lgb.LGBMRegressor(**params).fit(**params_fit)
         np.testing.assert_allclose(gbm.evals_result_['training']['l2'], np.nan)

+    def test_first_metric_only(self):
+
+        def decreasing_metric(y_true, y_pred):
+            return ('decreasing_metric', next(decreasing_generator), False)
+
+        def constant_metric(y_true, y_pred):
+            return ('constant_metric', 0.0, False)
+
+        def fit_and_check(eval_set_names, metric_names, assumed_iteration, first_metric_only):
+            params['first_metric_only'] = first_metric_only
+            gbm = lgb.LGBMRegressor(**params).fit(**params_fit)
+            self.assertEqual(len(gbm.evals_result_), len(eval_set_names))
+            for eval_set_name in eval_set_names:
+                self.assertIn(eval_set_name, gbm.evals_result_)
+                self.assertEqual(len(gbm.evals_result_[eval_set_name]), len(metric_names))
+                for metric_name in metric_names:
+                    self.assertIn(metric_name, gbm.evals_result_[eval_set_name])
+
+                    actual = len(gbm.evals_result_[eval_set_name][metric_name])
+                    expected = assumed_iteration + (params_fit['early_stopping_rounds']
+                                                    if eval_set_name != 'training'
+                                                    and assumed_iteration != gbm.n_estimators else 0)
+                    self.assertEqual(expected, actual)
+                    self.assertEqual(assumed_iteration if eval_set_name != 'training' else params['n_estimators'],
+                                     gbm.best_iteration_)
+
+        decreasing_generator = itertools.count(0, -1)
+        X, y = load_boston(True)
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+        X_test1, X_test2, y_test1, y_test2 = train_test_split(X_test, y_test, test_size=0.5, random_state=72)
+        params = {'n_estimators': 30,
+                  'learning_rate': 0.8,
+                  'num_leaves': 15,
+                  'verbose': -1,
+                  'seed': 123}
+        params_fit = {'X': X_train,
+                      'y': y_train,
+                      'early_stopping_rounds': 5,
+                      'verbose': False}
+
+        iter_valid1_l1 = 3
+        iter_valid1_l2 = 18
+        iter_valid2_l1 = 11
+        iter_valid2_l2 = 7
+        self.assertEqual(len(set([iter_valid1_l1, iter_valid1_l2, iter_valid2_l1, iter_valid2_l2])), 4)
+        iter_min_l1 = min([iter_valid1_l1, iter_valid2_l1])
+        iter_min_l2 = min([iter_valid1_l2, iter_valid2_l2])
+        iter_min = min([iter_min_l1, iter_min_l2])
+        iter_min_valid1 = min([iter_valid1_l1, iter_valid1_l2])
+
+        # training data as eval_set
+        params_fit['eval_set'] = (X_train, y_train)
+        fit_and_check(['training'], ['l2'], 30, False)
+        fit_and_check(['training'], ['l2'], 30, True)
+
+        # feval
+        params['metric'] = 'None'
+        params_fit['eval_metric'] = lambda preds, train_data: [decreasing_metric(preds, train_data),
+                                                               constant_metric(preds, train_data)]
+        params_fit['eval_set'] = (X_test1, y_test1)
+        fit_and_check(['valid_0'], ['decreasing_metric', 'constant_metric'], 1, False)
+        fit_and_check(['valid_0'], ['decreasing_metric', 'constant_metric'], 30, True)
+        params_fit['eval_metric'] = lambda preds, train_data: [constant_metric(preds, train_data),
+                                                               decreasing_metric(preds, train_data)]
+        fit_and_check(['valid_0'], ['decreasing_metric', 'constant_metric'], 1, True)
+
+        # single eval_set
+        params.pop('metric')
+        params_fit.pop('eval_metric')
+        fit_and_check(['valid_0'], ['l2'], iter_valid1_l2, False)
+        fit_and_check(['valid_0'], ['l2'], iter_valid1_l2, True)
+
+        params_fit['eval_metric'] = "l2"
+        fit_and_check(['valid_0'], ['l2'], iter_valid1_l2, False)
+        fit_and_check(['valid_0'], ['l2'], iter_valid1_l2, True)
+
+        params_fit['eval_metric'] = "l1"
+        fit_and_check(['valid_0'], ['l1', 'l2'], iter_min_valid1, False)
+        fit_and_check(['valid_0'], ['l1', 'l2'], iter_valid1_l1, True)
+
+        params_fit['eval_metric'] = ["l1", "l2"]
+        fit_and_check(['valid_0'], ['l1', 'l2'], iter_min_valid1, False)
+        fit_and_check(['valid_0'], ['l1', 'l2'], iter_valid1_l1, True)
+
+        params_fit['eval_metric'] = ["l2", "l1"]
+        fit_and_check(['valid_0'], ['l1', 'l2'], iter_min_valid1, False)
+        fit_and_check(['valid_0'], ['l1', 'l2'], iter_valid1_l2, True)
+
+        params_fit['eval_metric'] = ["l2", "regression", "mse"]  # test aliases
+        fit_and_check(['valid_0'], ['l2'], iter_valid1_l2, False)
+        fit_and_check(['valid_0'], ['l2'], iter_valid1_l2, True)
+
+        # two eval_set
+        params_fit['eval_set'] = [(X_test1, y_test1), (X_test2, y_test2)]
+        params_fit['eval_metric'] = ["l1", "l2"]
+        fit_and_check(['valid_0', 'valid_1'], ['l1', 'l2'], iter_min_l1, True)
+        params_fit['eval_metric'] = ["l2", "l1"]
+        fit_and_check(['valid_0', 'valid_1'], ['l1', 'l2'], iter_min_l2, True)
+
+        params_fit['eval_set'] = [(X_test2, y_test2), (X_test1, y_test1)]
+        params_fit['eval_metric'] = ["l1", "l2"]
+        fit_and_check(['valid_0', 'valid_1'], ['l1', 'l2'], iter_min, False)
+        fit_and_check(['valid_0', 'valid_1'], ['l1', 'l2'], iter_min_l1, True)
+        params_fit['eval_metric'] = ["l2", "l1"]
+        fit_and_check(['valid_0', 'valid_1'], ['l1', 'l2'], iter_min, False)
+        fit_and_check(['valid_0', 'valid_1'], ['l1', 'l2'], iter_min_l2, True)
+
     def test_class_weight(self):
         X, y = load_digits(10, True)
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
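The sklearn-wrapper path exercised in the test above can be driven the same way; a small sketch under the same caveat (hypothetical data, with `first_metric_only` passed as an extra estimator keyword just as the test does):

import numpy as np
import lightgbm as lgb

# Illustrative random regression data (not from the patch).
rng = np.random.RandomState(0)
X, y = rng.rand(400, 8), rng.rand(400)
X_train, X_valid, y_train, y_valid = X[:300], X[300:], y[:300], y[300:]

# Per the tests above, the first metric listed in eval_metric drives early
# stopping when first_metric_only=True; the other metrics are still recorded.
gbm = lgb.LGBMRegressor(n_estimators=100, first_metric_only=True)
gbm.fit(X_train, y_train,
        eval_set=[(X_valid, y_valid)],
        eval_metric=['l1', 'l2'],
        early_stopping_rounds=5,
        verbose=False)
print(gbm.best_iteration_)
print(list(gbm.evals_result_['valid_0'].keys()))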