[python] make `log_evaluation` callback pickleable (#5101)

* make `log_evaluation` callback pickleable

* make callback tests stricter
This commit is contained in:
Nikita Titov 2022-03-30 21:52:46 +03:00 коммит произвёл GitHub
Родитель 417c732cc0
Коммит 8b33e776cc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 40 добавлений и 9 удалений

Просмотреть файл

@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
raise ValueError("Wrong metric value")
def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
class _LogEvaluationCallback:
"""Internal log evaluation callable class."""
def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
self.order = 10
self.before_iteration = False
self.period = period
self.show_stdv = show_stdv
def __call__(self, env: CallbackEnv) -> None:
if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')
def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
"""Create a callback that logs the evaluation results.
By default, standard output resource is used.
@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
Returns
-------
callback : callable
callback : _LogEvaluationCallback
The callback that logs the evaluation results every ``period`` boosting iteration(s).
"""
def _callback(env: CallbackEnv) -> None:
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')
_callback.order = 10 # type: ignore
return _callback
return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:

Просмотреть файл

@ -8,7 +8,8 @@ from .utils import pickle_obj, unpickle_obj
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path):
callback = lgb.early_stopping(stopping_rounds=5)
rounds = 5
callback = lgb.early_stopping(stopping_rounds=rounds)
tmp_file = tmp_path / "early_stopping.pkl"
pickle_obj(
obj=callback,
@ -20,3 +21,22 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path):
serializer=serializer
)
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
assert callback.stopping_rounds == rounds
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_log_evaluation_callback_is_picklable(serializer, tmp_path):
periods = 42
callback = lgb.log_evaluation(period=periods)
tmp_file = tmp_path / "log_evaluation.pkl"
pickle_obj(
obj=callback,
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
assert callback.period == callback_from_disk.period
assert callback.period == periods