[python] add type hints for custom objective and metric functions in scikit-learn interface (#4547)

* [python] add type hints for custom objective and metric functions in scikit-learn interface

* update type hints

* remote unnecessary input

* Update python-package/lightgbm/sklearn.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* remove type hint on objective being callable

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
James Lamb 2021-11-15 15:05:42 -05:00 коммит произвёл GitHub
Родитель bfb346c13e
Коммит 843d380d6b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 44 добавлений и 16 удалений

Просмотреть файл

@ -11,7 +11,7 @@ from collections import defaultdict, namedtuple
from copy import deepcopy from copy import deepcopy
from enum import Enum, auto from enum import Enum, auto
from functools import partial from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
@ -21,8 +21,8 @@ from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _lo
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat, from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series, dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait) default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note, from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict) _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series] _DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame] _DaskMatrixLike = Union[dask_Array, dask_DataFrame]
@ -400,7 +400,7 @@ def _train(
eval_class_weight: Optional[List[Union[dict, str]]] = None, eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None, eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None, eval_at: Optional[Iterable[int]] = None,
**kwargs: Any **kwargs: Any
) -> LGBMModel: ) -> LGBMModel:
@ -1029,7 +1029,7 @@ class _DaskLGBMModel:
eval_class_weight: Optional[List[Union[dict, str]]] = None, eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None, eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None, eval_at: Optional[Iterable[int]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
@ -1096,7 +1096,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None, objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
@ -1165,7 +1165,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
eval_sample_weight: Optional[List[_DaskVectorLike]] = None, eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_class_weight: Optional[List[Union[dict, str]]] = None, eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None, eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMClassifier": ) -> "DaskLGBMClassifier":
@ -1281,7 +1281,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None, objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
@ -1348,7 +1348,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
eval_names: Optional[List[str]] = None, eval_names: Optional[List[str]] = None,
eval_sample_weight: Optional[List[_DaskVectorLike]] = None, eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None, eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMRegressor": ) -> "DaskLGBMRegressor":
@ -1446,7 +1446,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None, objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
@ -1516,7 +1516,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
eval_sample_weight: Optional[List[_DaskVectorLike]] = None, eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None, eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5), eval_at: Iterable[int] = (1, 2, 3, 4, 5),
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any

Просмотреть файл

@ -2,7 +2,7 @@
"""Scikit-learn wrapper interface for LightGBM.""" """Scikit-learn wrapper interface for LightGBM."""
import copy import copy
from inspect import signature from inspect import signature
from typing import Callable, Dict, Optional, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
@ -11,14 +11,42 @@ from .callback import log_evaluation, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray, from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase, _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable, _LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
pd_DataFrame) pd_DataFrame, pd_Series)
from .engine import train from .engine import train
_ArrayLike = Union[List, np.ndarray, pd_Series]
_EvalResultType = Tuple[str, float, bool]
_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
]
_LGBM_ScikitCustomEvalFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
]
class _ObjectiveFunctionWrapper: class _ObjectiveFunctionWrapper:
"""Proxy class for objective function.""" """Proxy class for objective function."""
def __init__(self, func): def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
"""Construct a proxy class. """Construct a proxy class.
This class transforms objective function to match objective function with signature ``new_func(preds, dataset)`` This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
@ -107,7 +135,7 @@ class _ObjectiveFunctionWrapper:
class _EvalFunctionWrapper: class _EvalFunctionWrapper:
"""Proxy class for evaluation function.""" """Proxy class for evaluation function."""
def __init__(self, func): def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
"""Construct a proxy class. """Construct a proxy class.
This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)`` This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
@ -358,7 +386,7 @@ class LGBMModel(_LGBMModelBase):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[str, Callable]] = None, objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[Dict, str]] = None, class_weight: Optional[Union[Dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,