[RFC][python] deprecate advanced args of `train()` and `cv()` functions and sklearn wrapper (#4574)

* deprecate advanced args of `train()` and `cv()`

* update Dask test

* improve deducing

* address review comments
This commit is contained in:
Nikita Titov 2021-09-12 22:19:03 +03:00 коммит произвёл GitHub
Родитель a08c37f647
Коммит 86bda6f061
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 115 добавлений и 29 удалений

Просмотреть файл

@ -50,19 +50,27 @@ def _format_eval_result(value: list, show_stdv: bool = True) -> str:
def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
"""Create a callback that prints the evaluation results.
"""Create a callback that logs the evaluation results.
By default, standard output resource is used.
Use ``register_logger()`` function to register a custom logger.
Note
----
Requires at least one validation data.
Parameters
----------
period : int, optional (default=1)
The period to print the evaluation results.
The period to log the evaluation results.
The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
show_stdv : bool, optional (default=True)
Whether to show stdv (if provided).
Whether to log stdv (if provided).
Returns
-------
callback : callable
The callback that prints the evaluation results every ``period`` iteration(s).
The callback that logs the evaluation results every ``period`` boosting iteration(s).
"""
def _callback(env: CallbackEnv) -> None:
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
@ -82,6 +90,24 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
Any initial contents of the dictionary will be deleted.
.. rubric:: Example
With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
this dictionary after finishing a model training process will have the following structure:
.. code-block::
{
'train':
{
'logloss': [0.48253, 0.35953, ...]
},
'eval':
{
'logloss': [0.480385, 0.357756, ...]
}
}
Returns
-------
callback : callable
@ -150,11 +176,12 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
Activates early stopping.
The model will train until the validation score stops improving.
Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
Validation score needs to improve at least every ``stopping_rounds`` round(s)
to continue training.
Requires at least one validation data and one metric.
If there's more than one, will check all of them. But the training data is ignored anyway.
To check only the first metric set ``first_metric_only`` to True.
The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.
Parameters
----------
@ -163,7 +190,9 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
first_metric_only : bool, optional (default=False)
Whether to use only the first metric for early stopping.
verbose : bool, optional (default=True)
Whether to print message with early stopping information.
Whether to log message with early stopping information.
By default, standard output resource is used.
Use ``register_logger()`` function to register a custom logger.
Returns
-------

Просмотреть файл

@ -35,7 +35,7 @@ def train(
categorical_feature: Union[List[str], List[int], str] = 'auto',
early_stopping_rounds: Optional[int] = None,
evals_result: Optional[Dict[str, Any]] = None,
verbose_eval: Union[bool, int] = True,
verbose_eval: Union[bool, int, str] = 'warn',
learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None,
keep_training_booster: bool = False,
callbacks: Optional[List[Callable]] = None
@ -121,7 +121,7 @@ def train(
To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
The index of iteration that has the best performance will be saved in the ``best_iteration`` field
if early stopping logic is enabled by setting ``early_stopping_rounds``.
evals_result: dict or None, optional (default=None)
evals_result : dict or None, optional (default=None)
Dictionary used to store all evaluation results of all the items in ``valid_sets``.
This should be initialized outside of your call to ``train()`` and should be empty.
Any initial contents of the dictionary will be deleted.
@ -176,10 +176,13 @@ def train(
num_boost_round = params.pop(alias)
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
params["num_iterations"] = num_boost_round
# show deprecation warning only for early stop argument, setting early stop via global params should still be possible
if early_stopping_rounds is not None and early_stopping_rounds > 0:
_log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'early_stopping()' callback via 'callbacks' argument instead.")
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
early_stopping_rounds = params.pop(alias)
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)
@ -233,6 +236,14 @@ def train(
callbacks = set(callbacks)
# Most of legacy advanced options becomes callbacks
if verbose_eval != "warn":
_log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
else:
if callbacks: # assume user has already specified print_evaluation callback
verbose_eval = False
else:
verbose_eval = True
if verbose_eval is True:
callbacks.add(callback.print_evaluation())
elif isinstance(verbose_eval, int):
@ -242,9 +253,13 @@ def train(
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
if learning_rates is not None:
_log_warning("'learning_rates' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'reset_parameter()' callback via 'callbacks' argument instead.")
callbacks.add(callback.reset_parameter(learning_rate=learning_rates))
if evals_result is not None:
_log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'record_evaluation()' callback via 'callbacks' argument instead.")
callbacks.add(callback.record_evaluation(evals_result))
callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
@ -520,7 +535,6 @@ def cv(params, train_set, num_boost_round=100,
and returns transformed versions of those.
verbose_eval : bool, int, or None, optional (default=None)
Whether to display the progress.
If None, progress will be displayed when np.ndarray is returned.
If True, progress will be displayed at every boosting stage.
If int, progress will be displayed at every given ``verbose_eval`` boosting stage.
show_stdv : bool, optional (default=True)
@ -560,9 +574,11 @@ def cv(params, train_set, num_boost_round=100,
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
num_boost_round = params.pop(alias)
params["num_iterations"] = num_boost_round
if early_stopping_rounds is not None and early_stopping_rounds > 0:
_log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'early_stopping()' callback via 'callbacks' argument instead.")
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
early_stopping_rounds = params.pop(alias)
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)
@ -601,6 +617,9 @@ def cv(params, train_set, num_boost_round=100,
callbacks = set(callbacks)
if early_stopping_rounds is not None and early_stopping_rounds > 0:
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
if verbose_eval is not None:
_log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
if verbose_eval is True:
callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
elif isinstance(verbose_eval, int):

Просмотреть файл

@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Union
import numpy as np
from .basic import Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
from .callback import print_evaluation, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
@ -570,7 +571,7 @@ class LGBMModel(_LGBMModelBase):
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None, verbose=True,
eval_metric=None, early_stopping_rounds=None, verbose='warn',
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is set after definition, using a template."""
@ -587,7 +588,7 @@ class LGBMModel(_LGBMModelBase):
self._fobj = _ObjectiveFunctionWrapper(self._objective)
else:
self._fobj = None
evals_result = {}
params = self.get_params()
# user can set verbose with kwargs, it has higher priority
if self.silent != "warn":
@ -718,18 +719,51 @@ class LGBMModel(_LGBMModelBase):
if isinstance(init_model, LGBMModel):
init_model = init_model.booster_
self._Booster = train(params, train_set,
self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable,
verbose_eval=verbose, feature_name=feature_name,
callbacks=callbacks, init_model=init_model)
if early_stopping_rounds is not None and early_stopping_rounds > 0:
_log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'early_stopping()' callback via 'callbacks' argument instead.")
params['early_stopping_rounds'] = early_stopping_rounds
if callbacks is None:
callbacks = []
else:
callbacks = copy.deepcopy(callbacks)
if verbose != 'warn':
_log_warning("'verbose' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
else:
if callbacks: # assume user has already specified print_evaluation callback
verbose = False
else:
verbose = True
callbacks.append(print_evaluation(int(verbose)))
evals_result = {}
callbacks.append(record_evaluation(evals_result))
self._Booster = train(
params=params,
train_set=train_set,
num_boost_round=self.n_estimators,
valid_sets=valid_sets,
valid_names=eval_names,
fobj=self._fobj,
feval=eval_metrics_callable,
init_model=init_model,
feature_name=feature_name,
callbacks=callbacks
)
if evals_result:
self._evals_result = evals_result
else: # reset after previous call to fit()
self._evals_result = None
if early_stopping_rounds is not None and early_stopping_rounds > 0:
if self._Booster.best_iteration != 0:
self._best_iteration = self._Booster.best_iteration
else: # reset after previous call to fit()
self._best_iteration = None
self._best_score = self._Booster.best_score
@ -791,16 +825,16 @@ class LGBMModel(_LGBMModelBase):
@property
def best_score_(self):
""":obj:`dict` or :obj:`None`: The best score of fitted model."""
""":obj:`dict`: The best score of fitted model."""
if self._n_features is None:
raise LGBMNotFittedError('No best_score found. Need to call fit beforehand.')
return self._best_score
@property
def best_iteration_(self):
""":obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping_rounds`` has been specified."""
""":obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
if self._n_features is None:
raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping_rounds beforehand.')
raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.')
return self._best_iteration
@property
@ -819,7 +853,7 @@ class LGBMModel(_LGBMModelBase):
@property
def evals_result_(self):
""":obj:`dict` or :obj:`None`: The evaluation results if ``early_stopping_rounds`` has been specified."""
""":obj:`dict` or :obj:`None`: The evaluation results if validation sets have been specified."""
if self._n_features is None:
raise LGBMNotFittedError('No results found. Need to call fit with eval_set beforehand.')
return self._evals_result
@ -852,7 +886,7 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, feature_name='auto', categorical_feature='auto',
verbose='warn', feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
super().fit(X, y, sample_weight=sample_weight, init_score=init_score,
@ -878,7 +912,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
early_stopping_rounds=None, verbose='warn',
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
@ -1006,7 +1040,7 @@ class LGBMRanker(LGBMModel):
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose=True,
eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose='warn',
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""

Просмотреть файл

@ -31,10 +31,14 @@ def test_register_logger(tmp_path):
lgb_data = lgb.Dataset(X, y)
eval_records = {}
callbacks = [
lgb.record_evaluation(eval_records),
lgb.print_evaluation(2),
lgb.early_stopping(4)
]
lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']},
lgb_data, num_boost_round=10, feval=dummy_metric,
valid_sets=[lgb_data], evals_result=eval_records,
categorical_feature=[1], early_stopping_rounds=4, verbose_eval=2)
valid_sets=[lgb_data], categorical_feature=[1], callbacks=callbacks)
lgb.plot_metric(eval_records)