зеркало из https://github.com/microsoft/LightGBM.git
[python-package] use dataclass for CallbackEnv (#6048)
This commit is contained in:
Родитель
5fe84f8f3b
Коммит
4ea170f30a
|
@ -7,6 +7,7 @@
|
|||
#
|
||||
echo "installing lightgbm's dependencies"
|
||||
pip install \
|
||||
'dataclasses' \
|
||||
'numpy==1.12.0' \
|
||||
'pandas==0.24.0' \
|
||||
'scikit-learn==0.18.2' \
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
# coding: utf-8
|
||||
"""Callbacks library."""
|
||||
import collections
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
|
||||
from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .engine import CVBooster
|
||||
|
||||
__all__ = [
|
||||
'early_stopping',
|
||||
|
@ -43,14 +47,14 @@ class EarlyStopException(Exception):
|
|||
|
||||
|
||||
# Callback environment used by callbacks
|
||||
CallbackEnv = collections.namedtuple(
|
||||
"CallbackEnv",
|
||||
["model",
|
||||
"params",
|
||||
"iteration",
|
||||
"begin_iteration",
|
||||
"end_iteration",
|
||||
"evaluation_result_list"])
|
||||
@dataclass
|
||||
class CallbackEnv:
|
||||
model: Union[Booster, "CVBooster"]
|
||||
params: Dict[str, Any]
|
||||
iteration: int
|
||||
begin_iteration: int
|
||||
end_iteration: int
|
||||
evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]
|
||||
|
||||
|
||||
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
|
||||
|
@ -126,7 +130,7 @@ class _RecordEvaluationCallback:
|
|||
data_name, eval_name = item[:2]
|
||||
else: # cv
|
||||
data_name, eval_name = item[1].split()
|
||||
self.eval_result.setdefault(data_name, collections.OrderedDict())
|
||||
self.eval_result.setdefault(data_name, OrderedDict())
|
||||
if len(item) == 4:
|
||||
self.eval_result[data_name].setdefault(eval_name, [])
|
||||
else:
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# coding: utf-8
|
||||
"""Library with training routines of LightGBM."""
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
from collections import OrderedDict, defaultdict
|
||||
from operator import attrgetter
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
@ -293,7 +293,7 @@ def train(
|
|||
booster.best_iteration = earlyStopException.best_iteration + 1
|
||||
evaluation_result_list = earlyStopException.best_score
|
||||
break
|
||||
booster.best_score = collections.defaultdict(collections.OrderedDict)
|
||||
booster.best_score = defaultdict(OrderedDict)
|
||||
for dataset_name, eval_name, score, _ in evaluation_result_list:
|
||||
booster.best_score[dataset_name][eval_name] = score
|
||||
if not keep_training_booster:
|
||||
|
@ -526,7 +526,7 @@ def _agg_cv_result(
|
|||
raw_results: List[List[Tuple[str, str, float, bool]]]
|
||||
) -> List[Tuple[str, str, float, bool, float]]:
|
||||
"""Aggregate cross-validation results."""
|
||||
cvmap: Dict[str, List[float]] = collections.OrderedDict()
|
||||
cvmap: Dict[str, List[float]] = OrderedDict()
|
||||
metric_type: Dict[str, bool] = {}
|
||||
for one_result in raw_results:
|
||||
for one_line in one_result:
|
||||
|
@ -717,7 +717,7 @@ def cv(
|
|||
.set_feature_name(feature_name) \
|
||||
.set_categorical_feature(categorical_feature)
|
||||
|
||||
results = collections.defaultdict(list)
|
||||
results = defaultdict(list)
|
||||
cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
|
||||
params=params, seed=seed, fpreproc=fpreproc,
|
||||
stratified=stratified, shuffle=shuffle,
|
||||
|
|
|
@ -18,6 +18,7 @@ classifiers = [
|
|||
"Topic :: Scientific/Engineering :: Artificial Intelligence"
|
||||
]
|
||||
dependencies = [
|
||||
"dataclasses ; python_version < '3.7'",
|
||||
"numpy",
|
||||
"scipy"
|
||||
]
|
||||
|
|
Загрузка…
Ссылка в новой задаче