зеркало из https://github.com/microsoft/LightGBM.git
[python] make `log_evaluation` callback pickleable (#5101)
* make `log_evaluation` callback pickleable * make callback tests stricter
This commit is contained in:
Родитель
417c732cc0
Коммит
8b33e776cc
|
@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
|
|||
raise ValueError("Wrong metric value")
|
||||
|
||||
|
||||
def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
|
||||
class _LogEvaluationCallback:
|
||||
"""Internal log evaluation callable class."""
|
||||
|
||||
def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
|
||||
self.order = 10
|
||||
self.before_iteration = False
|
||||
|
||||
self.period = period
|
||||
self.show_stdv = show_stdv
|
||||
|
||||
def __call__(self, env: CallbackEnv) -> None:
|
||||
if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
|
||||
result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
|
||||
_log_info(f'[{env.iteration + 1}]\t{result}')
|
||||
|
||||
|
||||
def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
|
||||
"""Create a callback that logs the evaluation results.
|
||||
|
||||
By default, standard output resource is used.
|
||||
|
@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
|
|||
|
||||
Returns
|
||||
-------
|
||||
callback : callable
|
||||
callback : _LogEvaluationCallback
|
||||
The callback that logs the evaluation results every ``period`` boosting iteration(s).
|
||||
"""
|
||||
def _callback(env: CallbackEnv) -> None:
|
||||
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
|
||||
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
|
||||
_log_info(f'[{env.iteration + 1}]\t{result}')
|
||||
_callback.order = 10 # type: ignore
|
||||
return _callback
|
||||
return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
|
||||
|
||||
|
||||
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
|
||||
|
|
|
@ -8,7 +8,8 @@ from .utils import pickle_obj, unpickle_obj
|
|||
|
||||
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
|
||||
def test_early_stopping_callback_is_picklable(serializer, tmp_path):
|
||||
callback = lgb.early_stopping(stopping_rounds=5)
|
||||
rounds = 5
|
||||
callback = lgb.early_stopping(stopping_rounds=rounds)
|
||||
tmp_file = tmp_path / "early_stopping.pkl"
|
||||
pickle_obj(
|
||||
obj=callback,
|
||||
|
@ -20,3 +21,22 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path):
|
|||
serializer=serializer
|
||||
)
|
||||
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
|
||||
assert callback.stopping_rounds == rounds
|
||||
|
||||
|
||||
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
|
||||
def test_log_evaluation_callback_is_picklable(serializer, tmp_path):
|
||||
periods = 42
|
||||
callback = lgb.log_evaluation(period=periods)
|
||||
tmp_file = tmp_path / "log_evaluation.pkl"
|
||||
pickle_obj(
|
||||
obj=callback,
|
||||
filepath=tmp_file,
|
||||
serializer=serializer
|
||||
)
|
||||
callback_from_disk = unpickle_obj(
|
||||
filepath=tmp_file,
|
||||
serializer=serializer
|
||||
)
|
||||
assert callback.period == callback_from_disk.period
|
||||
assert callback.period == periods
|
||||
|
|
Загрузка…
Ссылка в новой задаче