[python-package] use dataclass for CallbackEnv (#6048)

James Lamb 2023-08-21 12:05:37 -05:00 committed by GitHub
Parent 5fe84f8f3b
Commit 4ea170f30a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files: 22 additions and 16 deletions

View file

@@ -7,6 +7,7 @@
 #
 echo "installing lightgbm's dependencies"
 pip install \
+    'dataclasses' \
     'numpy==1.12.0' \
     'pandas==0.24.0' \
     'scikit-learn==0.18.2' \
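The extra 'dataclasses' pin only matters on Python versions older than 3.7, where the module is not in the standard library and pip pulls the PyPI backport; on 3.7+ the same import comes from the stdlib. A minimal sketch of the drop-in behaviour the backport is expected to provide (the Point class is illustrative):

# Runs unchanged on Python 3.6 (PyPI 'dataclasses' backport) and on 3.7+ (stdlib).
from dataclasses import dataclass, field

@dataclass
class Point:
    x: int = 0
    y: int = 0
    tags: list = field(default_factory=list)

print(Point(1, 2))  # Point(x=1, y=2, tags=[])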

View file

@@ -1,10 +1,14 @@
 # coding: utf-8
 """Callbacks library."""
-import collections
+from collections import OrderedDict
+from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

-from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
+from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
+
+if TYPE_CHECKING:
+    from .engine import CVBooster

 __all__ = [
     'early_stopping',
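The TYPE_CHECKING guard lets callback.py annotate CallbackEnv.model with CVBooster without importing lightgbm.engine at runtime, which would otherwise risk a circular import since engine imports callback. A small sketch of the pattern; the _describe helper is hypothetical and only illustrates how the string annotation stays usable:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # seen by static type checkers only, never executed,
    # so no runtime import of lightgbm.engine is triggered
    from .engine import CVBooster


def _describe(model: "CVBooster") -> str:  # hypothetical helper, for illustration
    # CVBooster stores one Booster per fold in its .boosters list
    return f"CV ensemble with {len(model.boosters)} boosters"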
@@ -43,14 +47,14 @@ class EarlyStopException(Exception):


 # Callback environment used by callbacks
-CallbackEnv = collections.namedtuple(
-    "CallbackEnv",
-    ["model",
-     "params",
-     "iteration",
-     "begin_iteration",
-     "end_iteration",
-     "evaluation_result_list"])
+@dataclass
+class CallbackEnv:
+    model: Union[Booster, "CVBooster"]
+    params: Dict[str, Any]
+    iteration: int
+    begin_iteration: int
+    end_iteration: int
+    evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]


 def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
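Because the dataclass keeps the namedtuple's field names, existing user callbacks continue to work by attribute access: LightGBM calls each callback once per boosting round with a CallbackEnv. A minimal sketch of such a callback (the log_iteration name and printed format are illustrative, not part of the package):

def log_iteration(env) -> None:
    # with train(), each entry of evaluation_result_list is
    # (dataset_name, eval_name, value, is_higher_better)
    if env.evaluation_result_list:
        first = env.evaluation_result_list[0]
        print(f"iteration {env.iteration + 1}/{env.end_iteration}: "
              f"{first[0]} {first[1]} = {first[2]:.6f}")

# used like the built-in callbacks, e.g.
# booster = lgb.train(params, train_set, valid_sets=[valid_set], callbacks=[log_iteration])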
@@ -126,7 +130,7 @@ class _RecordEvaluationCallback:
                 data_name, eval_name = item[:2]
             else:  # cv
                 data_name, eval_name = item[1].split()
-            self.eval_result.setdefault(data_name, collections.OrderedDict())
+            self.eval_result.setdefault(data_name, OrderedDict())
             if len(item) == 4:
                 self.eval_result[data_name].setdefault(eval_name, [])
             else:
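_RecordEvaluationCallback is the class behind lightgbm.record_evaluation(); the OrderedDict created here is nested inside the dict the caller passes in. A self-contained usage sketch on random data (the "train" name and the l2 metric follow LightGBM's defaults for a regression objective):

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
train_set = lgb.Dataset(rng.random((200, 5)), label=rng.random(200))

eval_result = {}  # filled in place as {data_name: OrderedDict(eval_name -> [one value per round])}
booster = lgb.train(
    {"objective": "regression", "verbosity": -1},
    train_set,
    num_boost_round=10,
    valid_sets=[train_set],
    valid_names=["train"],
    callbacks=[lgb.record_evaluation(eval_result)],
)
print(eval_result["train"]["l2"])  # 10 recorded values, one per boosting round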

View file

@@ -1,8 +1,8 @@
 # coding: utf-8
 """Library with training routines of LightGBM."""
-import collections
 import copy
 import json
+from collections import OrderedDict, defaultdict
 from operator import attrgetter
 from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -293,7 +293,7 @@ def train(
             booster.best_iteration = earlyStopException.best_iteration + 1
             evaluation_result_list = earlyStopException.best_score
             break
-    booster.best_score = collections.defaultdict(collections.OrderedDict)
+    booster.best_score = defaultdict(OrderedDict)
     for dataset_name, eval_name, score, _ in evaluation_result_list:
         booster.best_score[dataset_name][eval_name] = score
     if not keep_training_booster:
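booster.best_score keeps the same shape after this change, a nested mapping from dataset name to metric name to best value, so code that indexes it is unaffected. A sketch of reading it after early stopping (the valid_0 name and l2 metric are LightGBM defaults for this setup):

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X, y = rng.random((300, 5)), rng.random(300)
train_set = lgb.Dataset(X[:200], label=y[:200])
valid_set = lgb.Dataset(X[200:], label=y[200:], reference=train_set)

booster = lgb.train(
    {"objective": "regression", "verbosity": -1},
    train_set,
    num_boost_round=100,
    valid_sets=[valid_set],
    callbacks=[lgb.early_stopping(stopping_rounds=5, verbose=False)],
)
# best_score is the defaultdict(OrderedDict) built above: {dataset: {metric: value}}
print(booster.best_iteration, dict(booster.best_score["valid_0"]))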
@@ -526,7 +526,7 @@ def _agg_cv_result(
     raw_results: List[List[Tuple[str, str, float, bool]]]
 ) -> List[Tuple[str, str, float, bool, float]]:
     """Aggregate cross-validation results."""
-    cvmap: Dict[str, List[float]] = collections.OrderedDict()
+    cvmap: Dict[str, List[float]] = OrderedDict()
     metric_type: Dict[str, bool] = {}
     for one_result in raw_results:
         for one_line in one_result:
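For readers following the cv() path: cvmap groups per-fold scores under a '<dataset> <metric>' key so they can be reduced to a mean and a standard deviation. A standalone sketch of that aggregation idea, not the exact implementation:

import numpy as np
from collections import OrderedDict

# one inner list per fold; tuples are (dataset_name, eval_name, value, is_higher_better)
raw_results = [
    [("valid", "l2", 0.30, False)],
    [("valid", "l2", 0.34, False)],
    [("valid", "l2", 0.32, False)],
]
cvmap = OrderedDict()
metric_type = {}
for one_result in raw_results:
    for data_name, eval_name, value, is_higher_better in one_result:
        key = f"{data_name} {eval_name}"
        metric_type[key] = is_higher_better
        cvmap.setdefault(key, []).append(value)
aggregated = [
    ("cv_agg", key, float(np.mean(values)), metric_type[key], float(np.std(values)))
    for key, values in cvmap.items()
]
print(aggregated)  # [('cv_agg', 'valid l2', 0.32, False, 0.016...)]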
@@ -717,7 +717,7 @@ def cv(
         .set_feature_name(feature_name) \
         .set_categorical_feature(categorical_feature)

-    results = collections.defaultdict(list)
+    results = defaultdict(list)
     cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
                             params=params, seed=seed, fpreproc=fpreproc,
                             stratified=stratified, shuffle=shuffle,
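The results = defaultdict(list) mapping is what cv() ultimately hands back to the caller as a plain dict of per-round lists. A short usage sketch; the exact key names depend on the dataset and metric, roughly '<dataset> <metric>-mean' and '...-stdv':

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
train_set = lgb.Dataset(rng.random((200, 5)), label=rng.random(200))

cv_results = lgb.cv(
    {"objective": "regression", "verbosity": -1},
    train_set,
    num_boost_round=10,
    nfold=3,
    stratified=False,  # stratified folds only make sense for classification
)
# every value is a list with one entry per boosting round
print(list(cv_results.keys()))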

View file

@@ -18,6 +18,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Artificial Intelligence"
 ]
 dependencies = [
+    "dataclasses ; python_version < '3.7'",
     "numpy",
     "scipy"
 ]
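The new dependency string uses a PEP 508 environment marker, so pip installs the dataclasses backport only on interpreters older than 3.7. A sketch of how such a marker evaluates, using the third-party packaging library (a tooling assumption, not part of this diff):

from packaging.requirements import Requirement

# mirrors the pyproject.toml entry above
req = Requirement("dataclasses ; python_version < '3.7'")
print(req.name)                                         # dataclasses
print(req.marker.evaluate({"python_version": "3.6"}))   # True  -> backport is installed
print(req.marker.evaluate({"python_version": "3.11"}))  # False -> dependency is skipped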