Mirror of https://github.com/microsoft/LightGBM.git
* initial changes
* initial version
* better handling of cases
* warn only with positive threshold
* remove early_stopping_threshold from high-level functions
* remove remaining early_stopping_threshold
* update test to use callback
* better handling of cases
* rename threshold to min_delta, enhance parameter description, update tests
* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* reduce num_boost_round in tests
* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* trigger ci

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: Nikita Titov <nekit94-12@hotmail.com>
Parent: 0a4d190828
Commit: 99e0a4bd7b
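As a quick usage illustration of the change in this commit (not part of the diff itself; the toy dataset below is assumed purely for illustration), the new ``min_delta`` argument is passed through the ``early_stopping`` callback:

import numpy as np
import lightgbm as lgb

# Toy binary classification data, assumed only for this sketch.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 10))
y = rng.integers(0, 2, size=500)
train_ds = lgb.Dataset(X[:400], y[:400])
valid_ds = lgb.Dataset(X[400:], y[400:], reference=train_ds)

bst = lgb.train(
    {'objective': 'binary', 'metric': 'binary_logloss', 'verbose': -1},
    train_ds,
    num_boost_round=100,
    valid_sets=[valid_ds],
    # stop unless binary_logloss improves by at least 0.01 within 10 rounds
    callbacks=[lgb.callback.early_stopping(10, min_delta=0.01)],
)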
@@ -1,12 +1,20 @@
 # coding: utf-8
 """Callbacks library."""
 import collections
-from operator import gt, lt
+from functools import partial
 from typing import Any, Callable, Dict, List, Union

 from .basic import _ConfigAliases, _log_info, _log_warning


+def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
+    return curr_score > best_score + delta
+
+
+def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
+    return curr_score < best_score - delta
+
+
 class EarlyStopException(Exception):
     """Exception of early stopping."""
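As an aside (not part of the diff), a minimal sketch of the comparison semantics these helpers implement: a score only counts as an improvement when it beats the current best by strictly more than ``delta``.

def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score > best_score + delta

print(_gt_delta(0.815, 0.80, 0.01))  # True: a gain of 0.015 clears min_delta=0.01
print(_gt_delta(0.805, 0.80, 0.01))  # False: a gain of 0.005 does not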
@@ -181,11 +189,11 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
     return _callback


-def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
+def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
     """Create a callback that activates early stopping.

     Activates early stopping.
-    The model will train until the validation score stops improving.
+    The model will train until the validation score doesn't improve by at least ``min_delta``.
     Validation score needs to improve at least every ``stopping_rounds`` round(s)
     to continue training.
     Requires at least one validation data and one metric.
@@ -203,6 +211,10 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
         Whether to log message with early stopping information.
         By default, standard output resource is used.
         Use ``register_logger()`` function to register a custom logger.
+    min_delta : float or list of float, optional (default=0.0)
+        Minimum improvement in score to keep training.
+        If float, this single value is used for all metrics.
+        If list, its length should match the total number of metrics.

     Returns
     -------
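A short sketch of the list form described above (metric names assumed for illustration, not part of the diff): each entry applies to one metric, in the order the metrics are reported.

import lightgbm as lgb

# binary_logloss must improve by at least 0.01 and auc by at least 0.001
# within `stopping_rounds` rounds, otherwise training stops.
params = {'objective': 'binary', 'metric': ['binary_logloss', 'auc'], 'verbose': -1}
stopper = lgb.callback.early_stopping(stopping_rounds=10, min_delta=[0.01, 0.001])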
@@ -229,17 +241,43 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
         if verbose:
             _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

+        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
+        n_datasets = len(env.evaluation_result_list) // n_metrics
+        if isinstance(min_delta, list):
+            if not all(t >= 0 for t in min_delta):
+                raise ValueError('Values for early stopping min_delta must be non-negative.')
+            if len(min_delta) == 0:
+                if verbose:
+                    _log_info('Disabling min_delta for early stopping.')
+                deltas = [0.0] * n_datasets * n_metrics
+            elif len(min_delta) == 1:
+                if verbose:
+                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
+                deltas = min_delta * n_datasets * n_metrics
+            else:
+                if len(min_delta) != n_metrics:
+                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
+                if first_metric_only and verbose:
+                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
+                deltas = min_delta * n_datasets
+        else:
+            if min_delta < 0:
+                raise ValueError('Early stopping min_delta must be non-negative.')
+            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
+                _log_info(f'Using {min_delta} as min_delta for all metrics.')
+            deltas = [min_delta] * n_datasets * n_metrics
+
         # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
         first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
-        for eval_ret in env.evaluation_result_list:
+        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
             best_iter.append(0)
             best_score_list.append(None)
-            if eval_ret[3]:
+            if eval_ret[3]:  # greater is better
                 best_score.append(float('-inf'))
-                cmp_op.append(gt)
+                cmp_op.append(partial(_gt_delta, delta=delta))
             else:
                 best_score.append(float('inf'))
-                cmp_op.append(lt)
+                cmp_op.append(partial(_lt_delta, delta=delta))

     def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
         if env.iteration == env.end_iteration - 1:
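To make the broadcasting above concrete, a standalone sketch (dataset and metric counts assumed, not part of the diff) of how a single-element ``min_delta`` expands to one delta per (dataset, metric) pair and is bound into a comparison with ``functools.partial``:

from functools import partial

def _gt_delta(curr_score, best_score, delta):  # same helper as in the diff above
    return curr_score > best_score + delta

n_datasets, n_metrics = 2, 2                  # e.g. 'training' and 'valid' sets, two metrics each
min_delta = [0.001]                           # single value reused for every metric
deltas = min_delta * n_datasets * n_metrics   # -> [0.001, 0.001, 0.001, 0.001]

cmp_ops = [partial(_gt_delta, delta=d) for d in deltas]
print(cmp_ops[0](0.752, 0.75))    # True: gain of 0.002 exceeds the 0.001 threshold
print(cmp_ops[0](0.7505, 0.75))   # False: gain of 0.0005 does not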
@@ -643,6 +643,81 @@ def test_early_stopping():
     assert 'binary_logloss' in gbm.best_score[valid_set_name]


+@pytest.mark.parametrize('first_only', [True, False])
+@pytest.mark.parametrize('single_metric', [True, False])
+@pytest.mark.parametrize('greater_is_better', [True, False])
+def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
+    if single_metric and not first_only:
+        pytest.skip("first_metric_only doesn't affect single metric.")
+    metric2min_delta = {
+        'auc': 0.001,
+        'binary_logloss': 0.01,
+        'average_precision': 0.001,
+        'mape': 0.01,
+    }
+    if single_metric:
+        if greater_is_better:
+            metric = 'auc'
+        else:
+            metric = 'binary_logloss'
+    else:
+        if first_only:
+            if greater_is_better:
+                metric = ['auc', 'binary_logloss']
+            else:
+                metric = ['binary_logloss', 'auc']
+        else:
+            if greater_is_better:
+                metric = ['auc', 'average_precision']
+            else:
+                metric = ['binary_logloss', 'mape']
+
+    X, y = load_breast_cancer(return_X_y=True)
+    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
+    train_ds = lgb.Dataset(X_train, y_train)
+    valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)
+
+    params = {'objective': 'binary', 'metric': metric, 'verbose': -1}
+    if isinstance(metric, str):
+        min_delta = metric2min_delta[metric]
+    elif first_only:
+        min_delta = metric2min_delta[metric[0]]
+    else:
+        min_delta = [metric2min_delta[m] for m in metric]
+    train_kwargs = dict(
+        params=params,
+        train_set=train_ds,
+        num_boost_round=50,
+        valid_sets=[train_ds, valid_ds],
+        valid_names=['training', 'valid'],
+    )
+
+    # regular early stopping
+    train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0)]
+    evals_result = {}
+    bst = lgb.train(evals_result=evals_result, **train_kwargs)
+    scores = np.vstack(list(evals_result['valid'].values())).T
+
+    # positive min_delta
+    train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0, min_delta=min_delta)]
+    delta_result = {}
+    delta_bst = lgb.train(evals_result=delta_result, **train_kwargs)
+    delta_scores = np.vstack(list(delta_result['valid'].values())).T
+
+    if first_only:
+        scores = scores[:, 0]
+        delta_scores = delta_scores[:, 0]
+
+    assert delta_bst.num_trees() < bst.num_trees()
+    np.testing.assert_allclose(scores[:len(delta_scores)], delta_scores)
+    last_score = delta_scores[-1]
+    best_score = delta_scores[delta_bst.num_trees() - 1]
+    if greater_is_better:
+        assert np.less_equal(last_score, best_score + min_delta).any()
+    else:
+        assert np.greater_equal(last_score, best_score - min_delta).any()
+
+
 def test_continue_train():
     X, y = load_boston(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)