зеркало из https://github.com/microsoft/LightGBM.git
[python] add type hints for custom objective and metric functions in scikit-learn interface (#4547)
* [python] add type hints for custom objective and metric functions in scikit-learn interface
* update type hints
* remove unnecessary input
* Update python-package/lightgbm/sklearn.py
  Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* remove type hint on objective being callable
  Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Родитель
bfb346c13e
Коммит
843d380d6b
|
@ -11,7 +11,7 @@ from collections import defaultdict, namedtuple
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -21,8 +21,8 @@ from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _lo
|
||||||
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
|
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
|
||||||
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
|
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
|
||||||
default_client, delayed, pd_DataFrame, pd_Series, wait)
|
default_client, delayed, pd_DataFrame, pd_Series, wait)
|
||||||
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
|
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
|
||||||
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
|
_lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
|
||||||
|
|
||||||
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
|
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
|
||||||
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
|
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
|
||||||
|
@ -400,7 +400,7 @@ def _train(
|
||||||
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
||||||
eval_init_score: Optional[List[_DaskCollection]] = None,
|
eval_init_score: Optional[List[_DaskCollection]] = None,
|
||||||
eval_group: Optional[List[_DaskVectorLike]] = None,
|
eval_group: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||||
eval_at: Optional[Iterable[int]] = None,
|
eval_at: Optional[Iterable[int]] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> LGBMModel:
|
) -> LGBMModel:
|
||||||
|
@ -1029,7 +1029,7 @@ class _DaskLGBMModel:
|
||||||
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
||||||
eval_init_score: Optional[List[_DaskCollection]] = None,
|
eval_init_score: Optional[List[_DaskCollection]] = None,
|
||||||
eval_group: Optional[List[_DaskVectorLike]] = None,
|
eval_group: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||||
eval_at: Optional[Iterable[int]] = None,
|
eval_at: Optional[Iterable[int]] = None,
|
||||||
early_stopping_rounds: Optional[int] = None,
|
early_stopping_rounds: Optional[int] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
|
@ -1096,7 +1096,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
||||||
learning_rate: float = 0.1,
|
learning_rate: float = 0.1,
|
||||||
n_estimators: int = 100,
|
n_estimators: int = 100,
|
||||||
subsample_for_bin: int = 200000,
|
subsample_for_bin: int = 200000,
|
||||||
objective: Optional[Union[Callable, str]] = None,
|
objective: Optional[str] = None,
|
||||||
class_weight: Optional[Union[dict, str]] = None,
|
class_weight: Optional[Union[dict, str]] = None,
|
||||||
min_split_gain: float = 0.,
|
min_split_gain: float = 0.,
|
||||||
min_child_weight: float = 1e-3,
|
min_child_weight: float = 1e-3,
|
||||||
|
@ -1165,7 +1165,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
||||||
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
eval_class_weight: Optional[List[Union[dict, str]]] = None,
|
||||||
eval_init_score: Optional[List[_DaskCollection]] = None,
|
eval_init_score: Optional[List[_DaskCollection]] = None,
|
||||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||||
early_stopping_rounds: Optional[int] = None,
|
early_stopping_rounds: Optional[int] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> "DaskLGBMClassifier":
|
) -> "DaskLGBMClassifier":
|
||||||
|
@ -1281,7 +1281,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
||||||
learning_rate: float = 0.1,
|
learning_rate: float = 0.1,
|
||||||
n_estimators: int = 100,
|
n_estimators: int = 100,
|
||||||
subsample_for_bin: int = 200000,
|
subsample_for_bin: int = 200000,
|
||||||
objective: Optional[Union[Callable, str]] = None,
|
objective: Optional[str] = None,
|
||||||
class_weight: Optional[Union[dict, str]] = None,
|
class_weight: Optional[Union[dict, str]] = None,
|
||||||
min_split_gain: float = 0.,
|
min_split_gain: float = 0.,
|
||||||
min_child_weight: float = 1e-3,
|
min_child_weight: float = 1e-3,
|
||||||
|
@ -1348,7 +1348,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
||||||
eval_names: Optional[List[str]] = None,
|
eval_names: Optional[List[str]] = None,
|
||||||
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_init_score: Optional[List[_DaskVectorLike]] = None,
|
eval_init_score: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||||
early_stopping_rounds: Optional[int] = None,
|
early_stopping_rounds: Optional[int] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> "DaskLGBMRegressor":
|
) -> "DaskLGBMRegressor":
|
||||||
|
@ -1446,7 +1446,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
||||||
learning_rate: float = 0.1,
|
learning_rate: float = 0.1,
|
||||||
n_estimators: int = 100,
|
n_estimators: int = 100,
|
||||||
subsample_for_bin: int = 200000,
|
subsample_for_bin: int = 200000,
|
||||||
objective: Optional[Union[Callable, str]] = None,
|
objective: Optional[str] = None,
|
||||||
class_weight: Optional[Union[dict, str]] = None,
|
class_weight: Optional[Union[dict, str]] = None,
|
||||||
min_split_gain: float = 0.,
|
min_split_gain: float = 0.,
|
||||||
min_child_weight: float = 1e-3,
|
min_child_weight: float = 1e-3,
|
||||||
|
@ -1516,7 +1516,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
||||||
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_init_score: Optional[List[_DaskVectorLike]] = None,
|
eval_init_score: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_group: Optional[List[_DaskVectorLike]] = None,
|
eval_group: Optional[List[_DaskVectorLike]] = None,
|
||||||
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
|
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
|
||||||
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
|
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
|
||||||
early_stopping_rounds: Optional[int] = None,
|
early_stopping_rounds: Optional[int] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
"""Scikit-learn wrapper interface for LightGBM."""
|
"""Scikit-learn wrapper interface for LightGBM."""
|
||||||
import copy
|
import copy
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Callable, Dict, Optional, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -11,14 +11,42 @@ from .callback import log_evaluation, record_evaluation
|
||||||
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
|
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
|
||||||
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
|
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
|
||||||
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
|
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
|
||||||
pd_DataFrame)
|
pd_DataFrame, pd_Series)
|
||||||
from .engine import train
|
from .engine import train
|
||||||
|
|
||||||
|
_ArrayLike = Union[List, np.ndarray, pd_Series]
|
||||||
|
_EvalResultType = Tuple[str, float, bool]
|
||||||
|
|
||||||
|
_LGBM_ScikitCustomObjectiveFunction = Union[
|
||||||
|
Callable[
|
||||||
|
[np.ndarray, np.ndarray],
|
||||||
|
Tuple[_ArrayLike, _ArrayLike]
|
||||||
|
],
|
||||||
|
Callable[
|
||||||
|
[np.ndarray, np.ndarray, np.ndarray],
|
||||||
|
Tuple[_ArrayLike, _ArrayLike]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
_LGBM_ScikitCustomEvalFunction = Union[
|
||||||
|
Callable[
|
||||||
|
[np.ndarray, np.ndarray],
|
||||||
|
Union[_EvalResultType, List[_EvalResultType]]
|
||||||
|
],
|
||||||
|
Callable[
|
||||||
|
[np.ndarray, np.ndarray, np.ndarray],
|
||||||
|
Union[_EvalResultType, List[_EvalResultType]]
|
||||||
|
],
|
||||||
|
Callable[
|
||||||
|
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
|
||||||
|
Union[_EvalResultType, List[_EvalResultType]]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class _ObjectiveFunctionWrapper:
|
class _ObjectiveFunctionWrapper:
|
||||||
"""Proxy class for objective function."""
|
"""Proxy class for objective function."""
|
||||||
|
|
||||||
def __init__(self, func):
|
def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
|
||||||
"""Construct a proxy class.
|
"""Construct a proxy class.
|
||||||
|
|
||||||
This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
|
This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
|
||||||
|
@ -107,7 +135,7 @@ class _ObjectiveFunctionWrapper:
|
||||||
class _EvalFunctionWrapper:
|
class _EvalFunctionWrapper:
|
||||||
"""Proxy class for evaluation function."""
|
"""Proxy class for evaluation function."""
|
||||||
|
|
||||||
def __init__(self, func):
|
def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
|
||||||
"""Construct a proxy class.
|
"""Construct a proxy class.
|
||||||
|
|
||||||
This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
|
This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
|
||||||
|
@ -358,7 +386,7 @@ class LGBMModel(_LGBMModelBase):
|
||||||
learning_rate: float = 0.1,
|
learning_rate: float = 0.1,
|
||||||
n_estimators: int = 100,
|
n_estimators: int = 100,
|
||||||
subsample_for_bin: int = 200000,
|
subsample_for_bin: int = 200000,
|
||||||
objective: Optional[Union[str, Callable]] = None,
|
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
|
||||||
class_weight: Optional[Union[Dict, str]] = None,
|
class_weight: Optional[Union[Dict, str]] = None,
|
||||||
min_split_gain: float = 0.,
|
min_split_gain: float = 0.,
|
||||||
min_child_weight: float = 1e-3,
|
min_child_weight: float = 1e-3,
|
||||||
|
|
Загрузка…
Ссылка в новой задаче