зеркало из https://github.com/microsoft/LightGBM.git
[dask] Add type hints in Dask package (#3866)
* add type hints in dask module * starting on asserts * remove unused code * add hints for dtypes * replace accidentally-removed docstrings * revert unrelated change * Update python-package/lightgbm/dask.py * empty commit * fix hints on group * capitalize array * hide hints in signatures * empty commit * sphinx version * Apply suggestions from code review Co-authored-by: Nikita Titov <nekit94-08@mail.ru> * fix hint for MatrixLike * Update python-package/lightgbm/dask.py Co-authored-by: Nikita Titov <nekit94-08@mail.ru> * Apply suggestions from code review Co-authored-by: Nikita Titov <nekit94-08@mail.ru> * update docstring * empty commit Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Родитель
217642ca71
Коммит
ea8e47ea24
|
@ -74,7 +74,7 @@ C_API = os.environ.get('C_API', '').lower().strip() != 'no'
|
|||
RTD = bool(os.environ.get('READTHEDOCS', ''))
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
needs_sphinx = '1.3' # Due to sphinx.ext.napoleon
|
||||
needs_sphinx = '2.1.0' # Due to sphinx.ext.napoleon, autodoc_typehints
|
||||
if needs_sphinx > sphinx.__version__:
|
||||
message = 'This project needs at least Sphinx v%s' % needs_sphinx
|
||||
raise VersionRequirementError(message)
|
||||
|
@ -97,6 +97,9 @@ autodoc_default_options = {
|
|||
"show-inheritance": True,
|
||||
}
|
||||
|
||||
# hide type hints in API docs
|
||||
autodoc_typehints = "none"
|
||||
|
||||
# Generate autosummary pages. Output should be set with: `:toctree: pythonapi/`
|
||||
autosummary_generate = ['Python-API.rst']
|
||||
|
||||
|
|
|
@ -113,7 +113,8 @@ except ImportError:
|
|||
try:
|
||||
from dask import delayed
|
||||
from dask.array import Array as dask_Array
|
||||
from dask.dataframe import _Frame as dask_Frame
|
||||
from dask.dataframe import DataFrame as dask_DataFrame
|
||||
from dask.dataframe import Series as dask_Series
|
||||
from dask.distributed import Client, default_client, get_worker, wait
|
||||
DASK_INSTALLED = True
|
||||
except ImportError:
|
||||
|
@ -129,7 +130,12 @@ except ImportError:
|
|||
|
||||
pass
|
||||
|
||||
class dask_Frame:
|
||||
"""Dummy class for dask.dataframe._Frame."""
|
||||
class dask_DataFrame:
|
||||
"""Dummy class for dask.dataframe.DataFrame."""
|
||||
|
||||
pass
|
||||
|
||||
class dask_Series:
|
||||
"""Dummy class for dask.dataframe.Series."""
|
||||
|
||||
pass
|
||||
|
|
|
@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost.
|
|||
import socket
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Iterable
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
|
@ -18,8 +18,13 @@ import scipy.sparse as ss
|
|||
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
|
||||
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
|
||||
SKLEARN_INSTALLED,
|
||||
DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait)
|
||||
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
|
||||
DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait)
|
||||
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker
|
||||
|
||||
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
|
||||
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
|
||||
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
|
||||
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
|
||||
|
||||
|
||||
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
|
||||
|
@ -102,7 +107,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
|
|||
return worker_ip_to_port
|
||||
|
||||
|
||||
def _concat(seq):
|
||||
def _concat(seq: List[_DaskPart]) -> _DaskPart:
|
||||
if isinstance(seq[0], np.ndarray):
|
||||
return np.concatenate(seq, axis=0)
|
||||
elif isinstance(seq[0], (pd_DataFrame, pd_Series)):
|
||||
|
@ -113,8 +118,15 @@ def _concat(seq):
|
|||
raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))
|
||||
|
||||
|
||||
def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
|
||||
time_out=120, **kwargs):
|
||||
def _train_part(
|
||||
params: Dict[str, Any],
|
||||
model_factory: Type[LGBMModel],
|
||||
list_of_parts: List[Dict[str, _DaskPart]],
|
||||
worker_address_to_port: Dict[str, int],
|
||||
return_model: bool,
|
||||
time_out: int = 120,
|
||||
**kwargs: Any
|
||||
) -> Optional[LGBMModel]:
|
||||
local_worker_address = get_worker().address
|
||||
machine_list = ','.join([
|
||||
'%s:%d' % (urlparse(worker_address).hostname, port)
|
||||
|
@ -158,7 +170,7 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
|
|||
return model if return_model else None
|
||||
|
||||
|
||||
def _split_to_parts(data, is_matrix):
|
||||
def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
|
||||
parts = data.to_delayed()
|
||||
if isinstance(parts, np.ndarray):
|
||||
if is_matrix:
|
||||
|
@ -169,24 +181,33 @@ def _split_to_parts(data, is_matrix):
|
|||
return parts
|
||||
|
||||
|
||||
def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
|
||||
def _train(
|
||||
client: Client,
|
||||
data: _DaskMatrixLike,
|
||||
label: _DaskCollection,
|
||||
params: Dict[str, Any],
|
||||
model_factory: Type[LGBMModel],
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
**kwargs: Any
|
||||
) -> LGBMModel:
|
||||
"""Inner train routine.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : dask.distributed.Client
|
||||
Dask client.
|
||||
data : dask array of shape = [n_samples, n_features]
|
||||
data : dask Array or dask DataFrame of shape = [n_samples, n_features]
|
||||
Input feature matrix.
|
||||
label : dask array of shape = [n_samples]
|
||||
label : dask Array, dask DataFrame or dask Series of shape = [n_samples]
|
||||
The target values (class labels in classification, real numbers in regression).
|
||||
params : dict
|
||||
Parameters passed to constructor of the local underlying model.
|
||||
model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
|
||||
Class of the local underlying model.
|
||||
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
|
||||
sample_weight : dask Array, dask DataFrame or dask Series of shape = [n_samples] or None, optional (default=None)
|
||||
Weights of training data.
|
||||
group : array-like or None, optional (default=None)
|
||||
group : dask Array, dask DataFrame or dask Series of shape = [n_samples] or None, optional (default=None)
|
||||
Group/query data.
|
||||
Only used in the learning-to-rank task.
|
||||
sum(group) = n_samples.
|
||||
|
@ -301,7 +322,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
|
|||
return results[0]
|
||||
|
||||
|
||||
def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
|
||||
def _predict_part(
|
||||
part: _DaskPart,
|
||||
model: LGBMModel,
|
||||
raw_score: bool,
|
||||
pred_proba: bool,
|
||||
pred_leaf: bool,
|
||||
pred_contrib: bool,
|
||||
**kwargs: Any
|
||||
) -> _DaskPart:
|
||||
data = part.values if isinstance(part, pd_DataFrame) else part
|
||||
|
||||
if data.shape[0] == 0:
|
||||
|
@ -332,15 +361,23 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, *
|
|||
return result
|
||||
|
||||
|
||||
def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
|
||||
dtype=np.float32, **kwargs):
|
||||
def _predict(
|
||||
model: LGBMModel,
|
||||
data: _DaskMatrixLike,
|
||||
raw_score: bool = False,
|
||||
pred_proba: bool = False,
|
||||
pred_leaf: bool = False,
|
||||
pred_contrib: bool = False,
|
||||
dtype: _PredictionDtype = np.float32,
|
||||
**kwargs: Any
|
||||
) -> dask_Array:
|
||||
"""Inner predict routine.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
|
||||
Fitted underlying model.
|
||||
data : dask array of shape = [n_samples, n_features]
|
||||
data : dask Array or dask DataFrame of shape = [n_samples, n_features]
|
||||
Input feature matrix.
|
||||
raw_score : bool, optional (default=False)
|
||||
Whether to predict raw scores.
|
||||
|
@ -357,16 +394,16 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
|
|||
|
||||
Returns
|
||||
-------
|
||||
predicted_result : dask array of shape = [n_samples] or shape = [n_samples, n_classes]
|
||||
predicted_result : dask Array of shape = [n_samples] or shape = [n_samples, n_classes]
|
||||
The predicted values.
|
||||
X_leaves : dask array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
|
||||
X_leaves : dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
|
||||
If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
|
||||
X_SHAP_values : dask array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects
|
||||
X_SHAP_values : dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]
|
||||
If ``pred_contrib=True``, the feature contributions for each sample.
|
||||
"""
|
||||
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
|
||||
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
|
||||
if isinstance(data, dask_Frame):
|
||||
if isinstance(data, dask_DataFrame):
|
||||
return data.map_partitions(
|
||||
_predict_part,
|
||||
model=model,
|
||||
|
@ -392,11 +429,21 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
|
|||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data)))
|
||||
raise TypeError('Data must be either dask Array or dask DataFrame. Got %s.' % str(type(data)))
|
||||
|
||||
|
||||
class _DaskLGBMModel:
|
||||
def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs):
|
||||
|
||||
def _fit(
|
||||
self,
|
||||
model_factory: Type[LGBMModel],
|
||||
X: _DaskMatrixLike,
|
||||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
) -> "_DaskLGBMModel":
|
||||
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
|
||||
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
|
||||
if client is None:
|
||||
|
@ -420,13 +467,13 @@ class _DaskLGBMModel:
|
|||
|
||||
return self
|
||||
|
||||
def _to_local(self, model_factory):
|
||||
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
|
||||
model = model_factory(**self.get_params())
|
||||
self._copy_extra_params(self, model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _copy_extra_params(source, dest):
|
||||
def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
|
||||
params = source.get_params()
|
||||
attributes = source.__dict__
|
||||
extra_param_names = set(attributes.keys()).difference(params.keys())
|
||||
|
@ -437,7 +484,14 @@ class _DaskLGBMModel:
|
|||
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
||||
"""Distributed version of lightgbm.LGBMClassifier."""
|
||||
|
||||
def fit(self, X, y, sample_weight=None, client=None, **kwargs):
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskMatrixLike,
|
||||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
) -> "DaskLGBMClassifier":
|
||||
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
|
||||
return self._fit(
|
||||
model_factory=LGBMClassifier,
|
||||
|
@ -455,7 +509,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
+ ' ' * 12 + 'Dask client.\n'
|
||||
+ ' ' * 8 + _init_score + _after_init_score)
|
||||
|
||||
def predict(self, X, **kwargs):
|
||||
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
|
||||
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
|
||||
return _predict(
|
||||
model=self.to_local(),
|
||||
|
@ -466,7 +520,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
|
||||
predict.__doc__ = LGBMClassifier.predict.__doc__
|
||||
|
||||
def predict_proba(self, X, **kwargs):
|
||||
def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
|
||||
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
|
||||
return _predict(
|
||||
model=self.to_local(),
|
||||
|
@ -477,7 +531,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
|
||||
predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__
|
||||
|
||||
def to_local(self):
|
||||
def to_local(self) -> LGBMClassifier:
|
||||
"""Create regular version of lightgbm.LGBMClassifier from the distributed version.
|
||||
|
||||
Returns
|
||||
|
@ -491,7 +545,14 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
||||
"""Distributed version of lightgbm.LGBMRegressor."""
|
||||
|
||||
def fit(self, X, y, sample_weight=None, client=None, **kwargs):
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskMatrixLike,
|
||||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
) -> "DaskLGBMRegressor":
|
||||
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
|
||||
return self._fit(
|
||||
model_factory=LGBMRegressor,
|
||||
|
@ -509,7 +570,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
+ ' ' * 12 + 'Dask client.\n'
|
||||
+ ' ' * 8 + _init_score + _after_init_score)
|
||||
|
||||
def predict(self, X, **kwargs):
|
||||
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
|
||||
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
|
||||
return _predict(
|
||||
model=self.to_local(),
|
||||
|
@ -519,7 +580,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
|
||||
predict.__doc__ = LGBMRegressor.predict.__doc__
|
||||
|
||||
def to_local(self):
|
||||
def to_local(self) -> LGBMRegressor:
|
||||
"""Create regular version of lightgbm.LGBMRegressor from the distributed version.
|
||||
|
||||
Returns
|
||||
|
@ -533,7 +594,16 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
||||
"""Distributed version of lightgbm.LGBMRanker."""
|
||||
|
||||
def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskMatrixLike,
|
||||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
init_score: Optional[_DaskCollection] = None,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
) -> "DaskLGBMRanker":
|
||||
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
|
||||
if init_score is not None:
|
||||
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
|
||||
|
@ -555,13 +625,13 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
|||
+ ' ' * 12 + 'Dask client.\n'
|
||||
+ ' ' * 8 + _eval_set + _after_eval_set)
|
||||
|
||||
def predict(self, X, **kwargs):
|
||||
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
|
||||
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
|
||||
return _predict(self.to_local(), X, **kwargs)
|
||||
|
||||
predict.__doc__ = LGBMRanker.predict.__doc__
|
||||
|
||||
def to_local(self):
|
||||
def to_local(self) -> LGBMRanker:
|
||||
"""Create regular version of lightgbm.LGBMRanker from the distributed version.
|
||||
|
||||
Returns
|
||||
|
|
Загрузка…
Ссылка в новой задаче