[python] make `early_stopping` callback pickleable (#5012)

* Turn `early_stopping` into a Callable class

* Fix

* Lint

* Remove print

* Fix order

* Revert "Lint"

This reverts commit 7ca8b55757.

* Apply suggestion from code review

* Nit

* Lint

* Move callable class outside the func for pickling

* Move _pickle and _unpickle to tests utils

* Add early stopping callback picklability test

* Nit

* Fix

* Lint

* Improve type hint

* Lint

* Lint

* Add cloudpickle to test_windows

* Update tests/python_package_test/test_engine.py

* Fix

* Apply suggestions from code review
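Background on the change (a minimal sketch, not taken from the diff): the previous early_stopping() returned a closure, and locally defined functions cannot be serialized by the standard pickle module, whereas an instance of a module-level callable class can. The names below are purely illustrative:

import pickle

def make_closure_callback(stopping_rounds):      # old pattern (illustrative)
    def _callback(env):
        pass
    _callback.order = 30
    return _callback

class CallableCallback:                           # new pattern (illustrative)
    def __init__(self, stopping_rounds):
        self.order = 30
        self.stopping_rounds = stopping_rounds

    def __call__(self, env):
        pass

# pickle.dumps(make_closure_callback(5))  # AttributeError: Can't pickle local object 'make_closure_callback.<locals>._callback'
restored = pickle.loads(pickle.dumps(CallableCallback(5)))  # works: state is plain attributes on a module-level class
assert restored.stopping_rounds == 5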
Antoni Baum authored on 2022-03-17 05:03:53 +01:00; committed by GitHub
Parent: eb686a7658
Commit: f77e0adf59
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
5 changed files: 193 additions and 170 deletions


@@ -50,7 +50,7 @@ if ($env:TASK -eq "swig") {
Exit 0
}
conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
conda install -q -y -n $env:CONDA_ENV cloudpickle joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
# python-graphviz has to be installed separately to prevent conda from downgrading to pypy
conda install -q -y -n $env:CONDA_ENV libxml2 python-graphviz ; Check-Output $?


@@ -12,14 +12,6 @@ _EvalResultTuple = Union[
]
def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
return curr_score > best_score + delta
def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta
class EarlyStopException(Exception):
"""Exception of early stopping."""
@@ -199,7 +191,136 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
return _callback
def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
class _EarlyStoppingCallback:
"""Internal early stopping callable class."""
def __init__(
self,
stopping_rounds: int,
first_metric_only: bool = False,
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0
) -> None:
self.order = 30
self.before_iteration = False
self.stopping_rounds = stopping_rounds
self.first_metric_only = first_metric_only
self.verbose = verbose
self.min_delta = min_delta
self.enabled = True
self._reset_storages()
def _reset_storages(self) -> None:
self.best_score = []
self.best_iter = []
self.best_score_list = []
self.cmp_op = []
self.first_metric = ''
def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score > best_score + delta
def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta
def _init(self, env: CallbackEnv) -> None:
self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
in _ConfigAliases.get("boosting"))
if not self.enabled:
_log_warning('Early stopping is not available in dart mode')
return
if not env.evaluation_result_list:
raise ValueError('For early stopping, '
'at least one dataset and eval metric is required for evaluation')
if self.stopping_rounds <= 0:
raise ValueError("stopping_rounds should be greater than zero.")
if self.verbose:
_log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
self._reset_storages()
n_metrics = len(set(m[1] for m in env.evaluation_result_list))
n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(self.min_delta, list):
if not all(t >= 0 for t in self.min_delta):
raise ValueError('Values for early stopping min_delta must be non-negative.')
if len(self.min_delta) == 0:
if self.verbose:
_log_info('Disabling min_delta for early stopping.')
deltas = [0.0] * n_datasets * n_metrics
elif len(self.min_delta) == 1:
if self.verbose:
_log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
deltas = self.min_delta * n_datasets * n_metrics
else:
if len(self.min_delta) != n_metrics:
raise ValueError('Must provide a single value for min_delta or as many as metrics.')
if self.first_metric_only and self.verbose:
_log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
deltas = self.min_delta * n_datasets
else:
if self.min_delta < 0:
raise ValueError('Early stopping min_delta must be non-negative.')
if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
_log_info(f'Using {self.min_delta} as min_delta for all metrics.')
deltas = [self.min_delta] * n_datasets * n_metrics
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
for eval_ret, delta in zip(env.evaluation_result_list, deltas):
self.best_iter.append(0)
self.best_score_list.append(None)
if eval_ret[3]: # greater is better
self.best_score.append(float('-inf'))
self.cmp_op.append(partial(self._gt_delta, delta=delta))
else:
self.best_score.append(float('inf'))
self.cmp_op.append(partial(self._lt_delta, delta=delta))
def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
if env.iteration == env.end_iteration - 1:
if self.verbose:
best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
_log_info('Did not meet early stopping. '
f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
def __call__(self, env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration:
self._init(env)
if not self.enabled:
return
for i in range(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2]
if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
self.best_score[i] = score
self.best_iter[i] = env.iteration
self.best_score_list[i] = env.evaluation_result_list
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
or env.evaluation_result_list[i][0] == env.model._train_data_name)):
self._final_iteration_check(env, eval_name_splitted, i)
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
if self.verbose:
eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
_log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
self._final_iteration_check(env, eval_name_splitted, i)
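Note on the state kept by the class (not part of the diff): self.cmp_op holds functools.partial objects wrapping the bound _gt_delta / _lt_delta methods. In Python 3 both bound methods and partial objects pickle by reference, so storing them as instance attributes does not hurt picklability as long as the owning class lives at module level. A standalone illustration with a hypothetical _Demo class:

from functools import partial
import pickle

class _Demo:
    def _gt_delta(self, curr_score, best_score, delta):
        return curr_score > best_score + delta

cmp = partial(_Demo()._gt_delta, delta=0.1)       # same shape as the entries in self.cmp_op
restored = pickle.loads(pickle.dumps(cmp))
assert restored(1.0, 0.85) is True                # 1.0 > 0.85 + 0.1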
def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
"""Create a callback that activates early stopping.
Activates early stopping.
@@ -228,127 +349,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
Returns
-------
callback : callable
callback : _EarlyStoppingCallback
The callback that activates early stopping.
"""
best_score = []
best_iter = []
best_score_list: list = []
cmp_op = []
enabled = True
first_metric = ''
def _init(env: CallbackEnv) -> None:
nonlocal best_score
nonlocal best_iter
nonlocal best_score_list
nonlocal cmp_op
nonlocal enabled
nonlocal first_metric
enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
in _ConfigAliases.get("boosting"))
if not enabled:
_log_warning('Early stopping is not available in dart mode')
return
if not env.evaluation_result_list:
raise ValueError('For early stopping, '
'at least one dataset and eval metric is required for evaluation')
if stopping_rounds <= 0:
raise ValueError("stopping_rounds should be greater than zero.")
if verbose:
_log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")
# reset storages
best_score = []
best_iter = []
best_score_list = []
cmp_op = []
first_metric = ''
n_metrics = len(set(m[1] for m in env.evaluation_result_list))
n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(min_delta, list):
if not all(t >= 0 for t in min_delta):
raise ValueError('Values for early stopping min_delta must be non-negative.')
if len(min_delta) == 0:
if verbose:
_log_info('Disabling min_delta for early stopping.')
deltas = [0.0] * n_datasets * n_metrics
elif len(min_delta) == 1:
if verbose:
_log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
deltas = min_delta * n_datasets * n_metrics
else:
if len(min_delta) != n_metrics:
raise ValueError('Must provide a single value for min_delta or as many as metrics.')
if first_metric_only and verbose:
_log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
deltas = min_delta * n_datasets
else:
if min_delta < 0:
raise ValueError('Early stopping min_delta must be non-negative.')
if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
_log_info(f'Using {min_delta} as min_delta for all metrics.')
deltas = [min_delta] * n_datasets * n_metrics
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
for eval_ret, delta in zip(env.evaluation_result_list, deltas):
best_iter.append(0)
best_score_list.append(None)
if eval_ret[3]: # greater is better
best_score.append(float('-inf'))
cmp_op.append(partial(_gt_delta, delta=delta))
else:
best_score.append(float('inf'))
cmp_op.append(partial(_lt_delta, delta=delta))
def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
nonlocal best_iter
nonlocal best_score_list
if env.iteration == env.end_iteration - 1:
if verbose:
best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
_log_info('Did not meet early stopping. '
f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
if first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
raise EarlyStopException(best_iter[i], best_score_list[i])
def _callback(env: CallbackEnv) -> None:
nonlocal best_score
nonlocal best_iter
nonlocal best_score_list
nonlocal cmp_op
nonlocal enabled
nonlocal first_metric
if env.iteration == env.begin_iteration:
_init(env)
if not enabled:
return
for i in range(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2]
if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
best_score[i] = score
best_iter[i] = env.iteration
best_score_list[i] = env.evaluation_result_list
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if first_metric_only and first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
or env.evaluation_result_list[i][0] == env.model._train_data_name)):
_final_iteration_check(env, eval_name_splitted, i)
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - best_iter[i] >= stopping_rounds:
if verbose:
eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
_log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
if first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
raise EarlyStopException(best_iter[i], best_score_list[i])
_final_iteration_check(env, eval_name_splitted, i)
_callback.order = 30 # type: ignore
return _callback
return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)
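With this change the object returned by early_stopping() survives a pickle round trip before (or after) being handed to lgb.train(). A minimal usage sketch on synthetic data, with illustrative parameter values:

import pickle

import lightgbm as lgb
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.random((200, 5)), rng.random(200)
train_set = lgb.Dataset(X[:150], y[:150])
valid_set = lgb.Dataset(X[150:], y[150:], reference=train_set)

callback = lgb.early_stopping(stopping_rounds=5, min_delta=1e-4)
callback = pickle.loads(pickle.dumps(callback))   # round-trips now; the closure version raised AttributeError here

booster = lgb.train(
    {"objective": "regression", "metric": "l2", "verbosity": -1},
    train_set,
    num_boost_round=100,
    valid_sets=[valid_set],
    callbacks=[callback],
)
print(booster.best_iteration)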


@@ -0,0 +1,22 @@
# coding: utf-8
import pytest
import lightgbm as lgb
from .utils import pickle_obj, unpickle_obj
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path):
callback = lgb.early_stopping(stopping_rounds=5)
tmp_file = tmp_path / "early_stopping.pkl"
pickle_obj(
obj=callback,
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
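cloudpickle is included in the serializer matrix because it can serialize objects that the standard pickle module refuses (for example lambdas and locally defined functions), while joblib largely defers to pickle; a module-level callable class should round-trip under all three. A quick illustration, unrelated to LightGBM:

import cloudpickle

square = lambda x: x * x                          # plain pickle cannot serialize a lambda
assert cloudpickle.loads(cloudpickle.dumps(square))(4) == 16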


@@ -2,7 +2,6 @@
"""Tests for lightgbm.dask module"""
import inspect
import pickle
import random
import socket
from itertools import groupby
@@ -24,10 +23,8 @@ if machine() != 'x86_64':
if not lgb.compat.DASK_INSTALLED:
pytest.skip('Dask is not installed', allow_module_level=True)
import cloudpickle
import dask.array as da
import dask.dataframe as dd
import joblib
import numpy as np
import pandas as pd
import sklearn.utils.estimator_checks as sklearn_checks
@@ -37,7 +34,7 @@ from scipy.sparse import csc_matrix, csr_matrix
from scipy.stats import spearmanr
from sklearn.datasets import make_blobs, make_regression
from .utils import make_ranking
from .utils import make_ranking, pickle_obj, unpickle_obj
tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking']
distributed_training_algorithms = ['data', 'voting']
@@ -234,32 +231,6 @@ def _constant_metric(y_true, y_pred):
return metric_name, value, is_higher_better
def _pickle(obj, filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'wb') as f:
pickle.dump(obj, f)
elif serializer == 'joblib':
joblib.dump(obj, filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'wb') as f:
cloudpickle.dump(obj, f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
def _unpickle(filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'rb') as f:
return pickle.load(f)
elif serializer == 'joblib':
return joblib.load(filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'rb') as f:
return cloudpickle.load(f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
def _objective_least_squares(y_true, y_pred):
grad = y_pred - y_true
hess = np.ones(len(y_true))
@@ -1341,23 +1312,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert getattr(local_model, "client", None) is None
tmp_file = tmp_path / "model-1.pkl"
_pickle(
pickle_obj(
obj=dask_model,
filepath=tmp_file,
serializer=serializer
)
model_from_disk = _unpickle(
model_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
local_tmp_file = tmp_path / "local-model-1.pkl"
_pickle(
pickle_obj(
obj=local_model,
filepath=local_tmp_file,
serializer=serializer
)
local_model_from_disk = _unpickle(
local_model_from_disk = unpickle_obj(
filepath=local_tmp_file,
serializer=serializer
)
@@ -1397,23 +1368,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
local_model.client_
tmp_file2 = tmp_path / "model-2.pkl"
_pickle(
pickle_obj(
obj=dask_model,
filepath=tmp_file2,
serializer=serializer
)
fitted_model_from_disk = _unpickle(
fitted_model_from_disk = unpickle_obj(
filepath=tmp_file2,
serializer=serializer
)
local_tmp_file2 = tmp_path / "local-model-2.pkl"
_pickle(
pickle_obj(
obj=local_model,
filepath=local_tmp_file2,
serializer=serializer
)
local_fitted_model_from_disk = _unpickle(
local_fitted_model_from_disk = unpickle_obj(
filepath=local_tmp_file2,
serializer=serializer
)


@@ -1,6 +1,9 @@
# coding: utf-8
import pickle
from functools import lru_cache
import cloudpickle
import joblib
import numpy as np
import sklearn.datasets
from sklearn.utils import check_random_state
@@ -131,3 +134,29 @@ def sklearn_multiclass_custom_objective(y_true, y_pred):
factor = num_class / (num_class - 1)
hess = factor * prob * (1 - prob)
return grad, hess
def pickle_obj(obj, filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'wb') as f:
pickle.dump(obj, f)
elif serializer == 'joblib':
joblib.dump(obj, filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'wb') as f:
cloudpickle.dump(obj, f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
def unpickle_obj(filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'rb') as f:
return pickle.load(f)
elif serializer == 'joblib':
return joblib.load(filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'rb') as f:
return cloudpickle.load(f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
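A usage sketch for the shared helpers above (hypothetical temporary directory), exercising each serializer the same way the parametrized tests do:

import tempfile
from pathlib import Path

# pickle_obj / unpickle_obj as defined above in tests/python_package_test/utils.py
with tempfile.TemporaryDirectory() as tmp_dir:
    for serializer in ("pickle", "joblib", "cloudpickle"):
        path = Path(tmp_dir) / f"obj-{serializer}.pkl"
        pickle_obj(obj={"stopping_rounds": 5}, filepath=path, serializer=serializer)
        assert unpickle_obj(filepath=path, serializer=serializer) == {"stopping_rounds": 5}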