зеркало из https://github.com/microsoft/LightGBM.git
[python] add type hints for custom objective and metric functions in scikit-learn interface (#4547)
* [python] add type hints for custom objective and metric functions in scikit-learn interface * update type hints * remote unnecessary input * Update python-package/lightgbm/sklearn.py Co-authored-by: Nikita Titov <nekit94-08@mail.ru> * remove type hint on objective being callable Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Родитель
bfb346c13e
Коммит
843d380d6b
|
@ -11,7 +11,7 @@ from collections import defaultdict, namedtuple
|
|||
from copy import deepcopy
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
|
@ -21,8 +21,8 @@ from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _lo
|
|||
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
|
||||
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
|
||||
default_client, delayed, pd_DataFrame, pd_Series, wait)
|
||||
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
|
||||
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
|
||||
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
|
||||
_lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
|
||||
|
||||
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
|
||||
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
|
||||
|
@ -400,7 +400,7 @@ def _train(
|
|||
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
||||
eval_init_score: Optional[List[_DaskCollection]] = None,
|
||||
eval_group: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
||||
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||
eval_at: Optional[Iterable[int]] = None,
|
||||
**kwargs: Any
|
||||
) -> LGBMModel:
|
||||
|
@ -1029,7 +1029,7 @@ class _DaskLGBMModel:
|
|||
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
||||
eval_init_score: Optional[List[_DaskCollection]] = None,
|
||||
eval_group: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
||||
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||
eval_at: Optional[Iterable[int]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
**kwargs: Any
|
||||
|
@ -1096,7 +1096,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
learning_rate: float = 0.1,
|
||||
n_estimators: int = 100,
|
||||
subsample_for_bin: int = 200000,
|
||||
objective: Optional[Union[Callable, str]] = None,
|
||||
objective: Optional[str] = None,
|
||||
class_weight: Optional[Union[dict, str]] = None,
|
||||
min_split_gain: float = 0.,
|
||||
min_child_weight: float = 1e-3,
|
||||
|
@ -1165,7 +1165,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
||||
eval_init_score: Optional[List[_DaskCollection]] = None,
|
||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
||||
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
**kwargs: Any
|
||||
) -> "DaskLGBMClassifier":
|
||||
|
@ -1281,7 +1281,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
learning_rate: float = 0.1,
|
||||
n_estimators: int = 100,
|
||||
subsample_for_bin: int = 200000,
|
||||
objective: Optional[Union[Callable, str]] = None,
|
||||
objective: Optional[str] = None,
|
||||
class_weight: Optional[Union[dict, str]] = None,
|
||||
min_split_gain: float = 0.,
|
||||
min_child_weight: float = 1e-3,
|
||||
|
@ -1348,7 +1348,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
eval_names: Optional[List[str]] = None,
|
||||
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_init_score: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
||||
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
**kwargs: Any
|
||||
) -> "DaskLGBMRegressor":
|
||||
|
@ -1446,7 +1446,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
|||
learning_rate: float = 0.1,
|
||||
n_estimators: int = 100,
|
||||
subsample_for_bin: int = 200000,
|
||||
objective: Optional[Union[Callable, str]] = None,
|
||||
objective: Optional[str] = None,
|
||||
class_weight: Optional[Union[dict, str]] = None,
|
||||
min_split_gain: float = 0.,
|
||||
min_child_weight: float = 1e-3,
|
||||
|
@ -1516,7 +1516,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
|||
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_init_score: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_group: Optional[List[_DaskVectorLike]] = None,
|
||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
||||
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
**kwargs: Any
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
"""Scikit-learn wrapper interface for LightGBM."""
|
||||
import copy
|
||||
from inspect import signature
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -11,14 +11,42 @@ from .callback import log_evaluation, record_evaluation
|
|||
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
|
||||
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
|
||||
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
|
||||
pd_DataFrame)
|
||||
pd_DataFrame, pd_Series)
|
||||
from .engine import train
|
||||
|
||||
_ArrayLike = Union[List, np.ndarray, pd_Series]
|
||||
_EvalResultType = Tuple[str, float, bool]
|
||||
|
||||
_LGBM_ScikitCustomObjectiveFunction = Union[
|
||||
Callable[
|
||||
[np.ndarray, np.ndarray],
|
||||
Tuple[_ArrayLike, _ArrayLike]
|
||||
],
|
||||
Callable[
|
||||
[np.ndarray, np.ndarray, np.ndarray],
|
||||
Tuple[_ArrayLike, _ArrayLike]
|
||||
],
|
||||
]
|
||||
_LGBM_ScikitCustomEvalFunction = Union[
|
||||
Callable[
|
||||
[np.ndarray, np.ndarray],
|
||||
Union[_EvalResultType, List[_EvalResultType]]
|
||||
],
|
||||
Callable[
|
||||
[np.ndarray, np.ndarray, np.ndarray],
|
||||
Union[_EvalResultType, List[_EvalResultType]]
|
||||
],
|
||||
Callable[
|
||||
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
|
||||
Union[_EvalResultType, List[_EvalResultType]]
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
class _ObjectiveFunctionWrapper:
|
||||
"""Proxy class for objective function."""
|
||||
|
||||
def __init__(self, func):
|
||||
def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
|
||||
"""Construct a proxy class.
|
||||
|
||||
This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
|
||||
|
@ -107,7 +135,7 @@ class _ObjectiveFunctionWrapper:
|
|||
class _EvalFunctionWrapper:
|
||||
"""Proxy class for evaluation function."""
|
||||
|
||||
def __init__(self, func):
|
||||
def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
|
||||
"""Construct a proxy class.
|
||||
|
||||
This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
|
||||
|
@ -358,7 +386,7 @@ class LGBMModel(_LGBMModelBase):
|
|||
learning_rate: float = 0.1,
|
||||
n_estimators: int = 100,
|
||||
subsample_for_bin: int = 200000,
|
||||
objective: Optional[Union[str, Callable]] = None,
|
||||
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
|
||||
class_weight: Optional[Union[Dict, str]] = None,
|
||||
min_split_gain: float = 0.,
|
||||
min_child_weight: float = 1e-3,
|
||||
|
|
Загрузка…
Ссылка в новой задаче