[python-package] add a few type hints in LGBMModel.fit() (#6470)

This commit is contained in:
James Lamb 2024-06-05 07:54:16 -05:00 коммит произвёл GitHub
Родитель 8579d5e34f
Коммит 4401401553
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 44 добавлений и 15 удалений

Просмотреть файл

@ -454,6 +454,30 @@ _lgbmmodel_doc_predict = """
"""
def _extract_evaluation_meta_data(
*,
collection: Optional[Union[Dict[Any, Any], List[Any]]],
name: str,
i: int,
) -> Optional[Any]:
"""Try to extract the ith element of one of the ``eval_*`` inputs."""
if collection is None:
return None
elif isinstance(collection, list):
# It's possible, for example, to pass 3 eval sets through `eval_set`,
# but only 1 init_score through `eval_init_score`.
#
# This if-else accounts for that possiblity.
if len(collection) > i:
return collection[i]
else:
return None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError(f"{name} should be dict or list")
class LGBMModel(_LGBMModelBase):
"""Implementation of the scikit-learn API for LightGBM."""
@ -869,17 +893,6 @@ class LGBMModel(_LGBMModelBase):
valid_sets: List[Dataset] = []
if eval_set is not None:
def _get_meta_data(collection, name, i):
if collection is None:
return None
elif isinstance(collection, list):
return collection[i] if len(collection) > i else None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError(f"{name} should be dict or list")
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
@ -887,8 +900,16 @@ class LGBMModel(_LGBMModelBase):
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _get_meta_data(eval_sample_weight, "eval_sample_weight", i)
valid_class_weight = _get_meta_data(eval_class_weight, "eval_class_weight", i)
valid_weight = _extract_evaluation_meta_data(
collection=eval_sample_weight,
name="eval_sample_weight",
i=i,
)
valid_class_weight = _extract_evaluation_meta_data(
collection=eval_class_weight,
name="eval_class_weight",
i=i,
)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
@ -897,8 +918,16 @@ class LGBMModel(_LGBMModelBase):
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _get_meta_data(eval_init_score, "eval_init_score", i)
valid_group = _get_meta_data(eval_group, "eval_group", i)
valid_init_score = _extract_evaluation_meta_data(
collection=eval_init_score,
name="eval_init_score",
i=i,
)
valid_group = _extract_evaluation_meta_data(
collection=eval_group,
name="eval_group",
i=i,
)
valid_set = Dataset(
data=valid_data[0],
label=valid_data[1],