Mirror of https://github.com/microsoft/LightGBM.git
[python] make `early_stopping` callback pickleable (#5012)
* Turn `early_stopping` into a Callable class
* Fix
* Lint
* Remove print
* Fix order
* Revert "Lint"
This reverts commit 7ca8b55757.
* Apply suggestion from code review
* Nit
* Lint
* Move callable class outside the func for pickling
* Move _pickle and _unpickle to tests utils
* Add early stopping callback picklability test
* Nit
* Fix
* Lint
* Improve type hint
* Lint
* Lint
* Add cloudpickle to test_windows
* Update tests/python_package_test/test_engine.py
* Fix
* Apply suggestions from code review
This commit is contained in:
Parent: eb686a7658
Commit: f77e0adf59
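
Why this matters: the previous `early_stopping` implementation returned a closure defined inside the function, which the standard `pickle` module cannot serialize; an instance of a module-level callable class can. A minimal sketch of the round-trip the new test exercises, assuming only the public `lgb.early_stopping` API and the attribute names visible in the diff below:

import pickle

import lightgbm as lgb

# The callable-class callback survives a pickle round-trip with its configuration intact.
callback = lgb.early_stopping(stopping_rounds=10, first_metric_only=True)
restored = pickle.loads(pickle.dumps(callback))

assert restored.stopping_rounds == 10
assert restored.first_metric_only is True
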
@@ -50,7 +50,7 @@ if ($env:TASK -eq "swig") {
   Exit 0
 }

-conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
+conda install -q -y -n $env:CONDA_ENV cloudpickle joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
 # python-graphviz has to be installed separately to prevent conda from downgrading to pypy
 conda install -q -y -n $env:CONDA_ENV libxml2 python-graphviz ; Check-Output $?

@@ -12,14 +12,6 @@ _EvalResultTuple = Union[
 ]


-def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
-    return curr_score > best_score + delta
-
-
-def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
-    return curr_score < best_score - delta
-
-
 class EarlyStopException(Exception):
     """Exception of early stopping."""

@@ -199,7 +191,136 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
     return _callback


-def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
+class _EarlyStoppingCallback:
+    """Internal early stopping callable class."""
+
+    def __init__(
+        self,
+        stopping_rounds: int,
+        first_metric_only: bool = False,
+        verbose: bool = True,
+        min_delta: Union[float, List[float]] = 0.0
+    ) -> None:
+        self.order = 30
+        self.before_iteration = False
+
+        self.stopping_rounds = stopping_rounds
+        self.first_metric_only = first_metric_only
+        self.verbose = verbose
+        self.min_delta = min_delta
+
+        self.enabled = True
+        self._reset_storages()
+
+    def _reset_storages(self) -> None:
+        self.best_score = []
+        self.best_iter = []
+        self.best_score_list = []
+        self.cmp_op = []
+        self.first_metric = ''
+
+    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
+        return curr_score > best_score + delta
+
+    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
+        return curr_score < best_score - delta
+
+    def _init(self, env: CallbackEnv) -> None:
+        self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
+                               in _ConfigAliases.get("boosting"))
+        if not self.enabled:
+            _log_warning('Early stopping is not available in dart mode')
+            return
+        if not env.evaluation_result_list:
+            raise ValueError('For early stopping, '
+                             'at least one dataset and eval metric is required for evaluation')
+
+        if self.stopping_rounds <= 0:
+            raise ValueError("stopping_rounds should be greater than zero.")
+
+        if self.verbose:
+            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
+
+        self._reset_storages()
+
+        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
+        n_datasets = len(env.evaluation_result_list) // n_metrics
+        if isinstance(self.min_delta, list):
+            if not all(t >= 0 for t in self.min_delta):
+                raise ValueError('Values for early stopping min_delta must be non-negative.')
+            if len(self.min_delta) == 0:
+                if self.verbose:
+                    _log_info('Disabling min_delta for early stopping.')
+                deltas = [0.0] * n_datasets * n_metrics
+            elif len(self.min_delta) == 1:
+                if self.verbose:
+                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
+                deltas = self.min_delta * n_datasets * n_metrics
+            else:
+                if len(self.min_delta) != n_metrics:
+                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
+                if self.first_metric_only and self.verbose:
+                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
+                deltas = self.min_delta * n_datasets
+        else:
+            if self.min_delta < 0:
+                raise ValueError('Early stopping min_delta must be non-negative.')
+            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
+                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
+            deltas = [self.min_delta] * n_datasets * n_metrics
+
+        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
+        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
+        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
+            self.best_iter.append(0)
+            self.best_score_list.append(None)
+            if eval_ret[3]:  # greater is better
+                self.best_score.append(float('-inf'))
+                self.cmp_op.append(partial(self._gt_delta, delta=delta))
+            else:
+                self.best_score.append(float('inf'))
+                self.cmp_op.append(partial(self._lt_delta, delta=delta))
+
+    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
+        if env.iteration == env.end_iteration - 1:
+            if self.verbose:
+                best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
+                _log_info('Did not meet early stopping. '
+                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
+                if self.first_metric_only:
+                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
+            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+
+    def __call__(self, env: CallbackEnv) -> None:
+        if env.iteration == env.begin_iteration:
+            self._init(env)
+        if not self.enabled:
+            return
+        for i in range(len(env.evaluation_result_list)):
+            score = env.evaluation_result_list[i][2]
+            if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
+                self.best_score[i] = score
+                self.best_iter[i] = env.iteration
+                self.best_score_list[i] = env.evaluation_result_list
+            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
+            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
+            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
+                continue  # use only the first metric for early stopping
+            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
+                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
+                self._final_iteration_check(env, eval_name_splitted, i)
+                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
+            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
+                if self.verbose:
+                    eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
+                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
+                    if self.first_metric_only:
+                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
+                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+            self._final_iteration_check(env, eval_name_splitted, i)
+
+
+def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
     """Create a callback that activates early stopping.

     Activates early stopping.

@@ -228,127 +349,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos

     Returns
     -------
-    callback : callable
+    callback : _EarlyStoppingCallback
         The callback that activates early stopping.
     """
-    best_score = []
-    best_iter = []
-    best_score_list: list = []
-    cmp_op = []
-    enabled = True
-    first_metric = ''
-
-    def _init(env: CallbackEnv) -> None:
-        nonlocal best_score
-        nonlocal best_iter
-        nonlocal best_score_list
-        nonlocal cmp_op
-        nonlocal enabled
-        nonlocal first_metric
-        enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-                          in _ConfigAliases.get("boosting"))
-        if not enabled:
-            _log_warning('Early stopping is not available in dart mode')
-            return
-        if not env.evaluation_result_list:
-            raise ValueError('For early stopping, '
-                             'at least one dataset and eval metric is required for evaluation')
-
-        if stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
-
-        if verbose:
-            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")
-
-        # reset storages
-        best_score = []
-        best_iter = []
-        best_score_list = []
-        cmp_op = []
-        first_metric = ''
-
-        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
-        n_datasets = len(env.evaluation_result_list) // n_metrics
-        if isinstance(min_delta, list):
-            if not all(t >= 0 for t in min_delta):
-                raise ValueError('Values for early stopping min_delta must be non-negative.')
-            if len(min_delta) == 0:
-                if verbose:
-                    _log_info('Disabling min_delta for early stopping.')
-                deltas = [0.0] * n_datasets * n_metrics
-            elif len(min_delta) == 1:
-                if verbose:
-                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
-                deltas = min_delta * n_datasets * n_metrics
-            else:
-                if len(min_delta) != n_metrics:
-                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
-                if first_metric_only and verbose:
-                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
-                deltas = min_delta * n_datasets
-        else:
-            if min_delta < 0:
-                raise ValueError('Early stopping min_delta must be non-negative.')
-            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
-                _log_info(f'Using {min_delta} as min_delta for all metrics.')
-            deltas = [min_delta] * n_datasets * n_metrics
-
-        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-        first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
-        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
-            best_iter.append(0)
-            best_score_list.append(None)
-            if eval_ret[3]:  # greater is better
-                best_score.append(float('-inf'))
-                cmp_op.append(partial(_gt_delta, delta=delta))
-            else:
-                best_score.append(float('inf'))
-                cmp_op.append(partial(_lt_delta, delta=delta))
-
-    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
-        nonlocal best_iter
-        nonlocal best_score_list
-        if env.iteration == env.end_iteration - 1:
-            if verbose:
-                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
-                _log_info('Did not meet early stopping. '
-                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
-                if first_metric_only:
-                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-            raise EarlyStopException(best_iter[i], best_score_list[i])
-
-    def _callback(env: CallbackEnv) -> None:
-        nonlocal best_score
-        nonlocal best_iter
-        nonlocal best_score_list
-        nonlocal cmp_op
-        nonlocal enabled
-        nonlocal first_metric
-        if env.iteration == env.begin_iteration:
-            _init(env)
-        if not enabled:
-            return
-        for i in range(len(env.evaluation_result_list)):
-            score = env.evaluation_result_list[i][2]
-            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
-                best_score[i] = score
-                best_iter[i] = env.iteration
-                best_score_list[i] = env.evaluation_result_list
-            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
-            if first_metric_only and first_metric != eval_name_splitted[-1]:
-                continue  # use only the first metric for early stopping
-            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
-                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
-                _final_iteration_check(env, eval_name_splitted, i)
-                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
-            elif env.iteration - best_iter[i] >= stopping_rounds:
-                if verbose:
-                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
-                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
-                    if first_metric_only:
-                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-                raise EarlyStopException(best_iter[i], best_score_list[i])
-            _final_iteration_check(env, eval_name_splitted, i)
-    _callback.order = 30  # type: ignore
-    return _callback
+    return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)

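The non-trivial logic carried over into `_EarlyStoppingCallback._init` is the `min_delta` handling: a scalar or list threshold is expanded into one delta per evaluation entry and bound into a comparison function with `functools.partial`. A standalone illustration of that expansion, using hypothetical example values rather than anything from the diff:

from functools import partial

def _gt_delta(curr_score, best_score, delta):
    # "higher is better" metrics only count as improved when they beat the best score by more than delta
    return curr_score > best_score + delta

n_datasets, n_metrics = 2, 3                 # e.g. train/valid sets evaluated with three metrics
min_delta = [0.5]                            # a single-element list applies to every metric
deltas = min_delta * n_datasets * n_metrics  # -> [0.5] * 6, one per evaluation entry

cmp = partial(_gt_delta, delta=deltas[0])
assert not cmp(curr_score=1.4, best_score=1.0)  # improvement of 0.4 does not clear min_delta
assert cmp(curr_score=1.6, best_score=1.0)      # improvement of 0.6 does
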
@@ -0,0 +1,22 @@
+# coding: utf-8
+import pytest
+
+import lightgbm as lgb
+
+from .utils import pickle_obj, unpickle_obj
+
+
+@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
+def test_early_stopping_callback_is_picklable(serializer, tmp_path):
+    callback = lgb.early_stopping(stopping_rounds=5)
+    tmp_file = tmp_path / "early_stopping.pkl"
+    pickle_obj(
+        obj=callback,
+        filepath=tmp_file,
+        serializer=serializer
+    )
+    callback_from_disk = unpickle_obj(
+        filepath=tmp_file,
+        serializer=serializer
+    )
+    assert callback.stopping_rounds == callback_from_disk.stopping_rounds

@@ -2,7 +2,6 @@
 """Tests for lightgbm.dask module"""

 import inspect
-import pickle
 import random
 import socket
 from itertools import groupby

@@ -24,10 +23,8 @@ if machine() != 'x86_64':
 if not lgb.compat.DASK_INSTALLED:
     pytest.skip('Dask is not installed', allow_module_level=True)

-import cloudpickle
 import dask.array as da
 import dask.dataframe as dd
-import joblib
 import numpy as np
 import pandas as pd
 import sklearn.utils.estimator_checks as sklearn_checks

@@ -37,7 +34,7 @@ from scipy.sparse import csc_matrix, csr_matrix
 from scipy.stats import spearmanr
 from sklearn.datasets import make_blobs, make_regression

-from .utils import make_ranking
+from .utils import make_ranking, pickle_obj, unpickle_obj

 tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking']
 distributed_training_algorithms = ['data', 'voting']

@@ -234,32 +231,6 @@ def _constant_metric(y_true, y_pred):
     return metric_name, value, is_higher_better


-def _pickle(obj, filepath, serializer):
-    if serializer == 'pickle':
-        with open(filepath, 'wb') as f:
-            pickle.dump(obj, f)
-    elif serializer == 'joblib':
-        joblib.dump(obj, filepath)
-    elif serializer == 'cloudpickle':
-        with open(filepath, 'wb') as f:
-            cloudpickle.dump(obj, f)
-    else:
-        raise ValueError(f'Unrecognized serializer type: {serializer}')
-
-
-def _unpickle(filepath, serializer):
-    if serializer == 'pickle':
-        with open(filepath, 'rb') as f:
-            return pickle.load(f)
-    elif serializer == 'joblib':
-        return joblib.load(filepath)
-    elif serializer == 'cloudpickle':
-        with open(filepath, 'rb') as f:
-            return cloudpickle.load(f)
-    else:
-        raise ValueError(f'Unrecognized serializer type: {serializer}')
-
-
 def _objective_least_squares(y_true, y_pred):
     grad = y_pred - y_true
     hess = np.ones(len(y_true))

@@ -1341,23 +1312,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
     assert getattr(local_model, "client", None) is None

     tmp_file = tmp_path / "model-1.pkl"
-    _pickle(
+    pickle_obj(
         obj=dask_model,
         filepath=tmp_file,
         serializer=serializer
     )
-    model_from_disk = _unpickle(
+    model_from_disk = unpickle_obj(
         filepath=tmp_file,
         serializer=serializer
     )

     local_tmp_file = tmp_path / "local-model-1.pkl"
-    _pickle(
+    pickle_obj(
         obj=local_model,
         filepath=local_tmp_file,
         serializer=serializer
     )
-    local_model_from_disk = _unpickle(
+    local_model_from_disk = unpickle_obj(
         filepath=local_tmp_file,
         serializer=serializer
     )

@@ -1397,23 +1368,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
         local_model.client_

     tmp_file2 = tmp_path / "model-2.pkl"
-    _pickle(
+    pickle_obj(
         obj=dask_model,
         filepath=tmp_file2,
         serializer=serializer
     )
-    fitted_model_from_disk = _unpickle(
+    fitted_model_from_disk = unpickle_obj(
         filepath=tmp_file2,
         serializer=serializer
     )

     local_tmp_file2 = tmp_path / "local-model-2.pkl"
-    _pickle(
+    pickle_obj(
         obj=local_model,
         filepath=local_tmp_file2,
         serializer=serializer
     )
-    local_fitted_model_from_disk = _unpickle(
+    local_fitted_model_from_disk = unpickle_obj(
         filepath=local_tmp_file2,
         serializer=serializer
     )

@@ -1,6 +1,9 @@
 # coding: utf-8
+import pickle
 from functools import lru_cache

+import cloudpickle
+import joblib
 import numpy as np
 import sklearn.datasets
 from sklearn.utils import check_random_state

@@ -131,3 +134,29 @@ def sklearn_multiclass_custom_objective(y_true, y_pred):
     factor = num_class / (num_class - 1)
     hess = factor * prob * (1 - prob)
     return grad, hess
+
+
+def pickle_obj(obj, filepath, serializer):
+    if serializer == 'pickle':
+        with open(filepath, 'wb') as f:
+            pickle.dump(obj, f)
+    elif serializer == 'joblib':
+        joblib.dump(obj, filepath)
+    elif serializer == 'cloudpickle':
+        with open(filepath, 'wb') as f:
+            cloudpickle.dump(obj, f)
+    else:
+        raise ValueError(f'Unrecognized serializer type: {serializer}')
+
+
+def unpickle_obj(filepath, serializer):
+    if serializer == 'pickle':
+        with open(filepath, 'rb') as f:
+            return pickle.load(f)
+    elif serializer == 'joblib':
+        return joblib.load(filepath)
+    elif serializer == 'cloudpickle':
+        with open(filepath, 'rb') as f:
+            return cloudpickle.load(f)
+    else:
+        raise ValueError(f'Unrecognized serializer type: {serializer}')