Mirror of https://github.com/microsoft/LightGBM.git

[python-package] add a few type hints in LGBMModel.fit() (#6470)

Parent: 8579d5e34f
Commit: 4401401553
@@ -454,6 +454,30 @@ _lgbmmodel_doc_predict = """
     """
 
 
+def _extract_evaluation_meta_data(
+    *,
+    collection: Optional[Union[Dict[Any, Any], List[Any]]],
+    name: str,
+    i: int,
+) -> Optional[Any]:
+    """Try to extract the ith element of one of the ``eval_*`` inputs."""
+    if collection is None:
+        return None
+    elif isinstance(collection, list):
+        # It's possible, for example, to pass 3 eval sets through `eval_set`,
+        # but only 1 init_score through `eval_init_score`.
+        #
+        # This if-else accounts for that possibility.
+        if len(collection) > i:
+            return collection[i]
+        else:
+            return None
+    elif isinstance(collection, dict):
+        return collection.get(i, None)
+    else:
+        raise TypeError(f"{name} should be dict or list")
+
+
 class LGBMModel(_LGBMModelBase):
     """Implementation of the scikit-learn API for LightGBM."""
 
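A minimal sketch (not part of the commit) of how the new helper behaves once it is in scope; the argument values below are illustrative only:

    # list input: positions past the end of the list fall back to None
    _extract_evaluation_meta_data(collection=[0.5, 0.25], name="eval_init_score", i=0)  # -> 0.5
    _extract_evaluation_meta_data(collection=[0.5, 0.25], name="eval_init_score", i=2)  # -> None

    # dict input: keyed by eval-set position, missing keys also fall back to None
    _extract_evaluation_meta_data(collection={1: 0.25}, name="eval_init_score", i=1)  # -> 0.25
    _extract_evaluation_meta_data(collection={1: 0.25}, name="eval_init_score", i=0)  # -> None

    # any other type raises TypeError, using the supplied name in the message
    _extract_evaluation_meta_data(collection=0.25, name="eval_init_score", i=0)  # raises TypeError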
@@ -869,17 +893,6 @@ class LGBMModel(_LGBMModelBase):
 
         valid_sets: List[Dataset] = []
         if eval_set is not None:
-
-            def _get_meta_data(collection, name, i):
-                if collection is None:
-                    return None
-                elif isinstance(collection, list):
-                    return collection[i] if len(collection) > i else None
-                elif isinstance(collection, dict):
-                    return collection.get(i, None)
-                else:
-                    raise TypeError(f"{name} should be dict or list")
-
             if isinstance(eval_set, tuple):
                 eval_set = [eval_set]
             for i, valid_data in enumerate(eval_set):
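The retained context above shows that a bare (X, y) tuple passed as eval_set is wrapped into a one-element list before the loop. A caller-side sketch, with placeholder array names that are not from the commit:

    # both shapes reach the same `for i, valid_data in enumerate(eval_set)` loop
    model.fit(X_train, y_train, eval_set=(X_valid, y_valid))
    model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])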
@@ -887,8 +900,16 @@ class LGBMModel(_LGBMModelBase):
                 if valid_data[0] is X and valid_data[1] is y:
                     valid_set = train_set
                 else:
-                    valid_weight = _get_meta_data(eval_sample_weight, "eval_sample_weight", i)
-                    valid_class_weight = _get_meta_data(eval_class_weight, "eval_class_weight", i)
+                    valid_weight = _extract_evaluation_meta_data(
+                        collection=eval_sample_weight,
+                        name="eval_sample_weight",
+                        i=i,
+                    )
+                    valid_class_weight = _extract_evaluation_meta_data(
+                        collection=eval_class_weight,
+                        name="eval_class_weight",
+                        i=i,
+                    )
                     if valid_class_weight is not None:
                         if isinstance(valid_class_weight, dict) and self._class_map is not None:
                             valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
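Assuming that self._class_map (set elsewhere in fit(), not shown in this diff) maps the user-facing class labels to their encoded values, the dict re-keying in the hunk above does roughly the following; the labels and weights here are made up:

    _class_map = {"cat": 0, "dog": 1}              # assumed shape of the mapping
    valid_class_weight = {"cat": 1.0, "dog": 5.0}  # weights keyed by original labels
    valid_class_weight = {_class_map[k]: v for k, v in valid_class_weight.items()}
    # -> {0: 1.0, 1: 5.0}, i.e. weights keyed by encoded labels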
@@ -897,8 +918,16 @@ class LGBMModel(_LGBMModelBase):
                             valid_weight = valid_class_sample_weight
                         else:
                             valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
-                    valid_init_score = _get_meta_data(eval_init_score, "eval_init_score", i)
-                    valid_group = _get_meta_data(eval_group, "eval_group", i)
+                    valid_init_score = _extract_evaluation_meta_data(
+                        collection=eval_init_score,
+                        name="eval_init_score",
+                        i=i,
+                    )
+                    valid_group = _extract_evaluation_meta_data(
+                        collection=eval_group,
+                        name="eval_group",
+                        i=i,
+                    )
                     valid_set = Dataset(
                         data=valid_data[0],
                         label=valid_data[1],
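Putting the pieces together, the comment inside the new helper ("3 eval sets ... but only 1 init_score") corresponds to calls like the following sketch; it assumes nothing about this commit beyond the public fit() signature, and the data is random filler:

    import numpy as np
    import lightgbm as lgb

    rng = np.random.default_rng(0)
    X_tr, y_tr = rng.normal(size=(100, 5)), rng.integers(0, 2, size=100)
    eval_sets = [(rng.normal(size=(20, 5)), rng.integers(0, 2, size=20)) for _ in range(3)]

    model = lgb.LGBMClassifier(n_estimators=5)
    model.fit(
        X_tr,
        y_tr,
        eval_set=eval_sets,
        # only one init_score entry: positions i=1 and i=2 resolve to None inside fit()
        eval_init_score=[np.zeros(20)],
    )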