зеркало из https://github.com/microsoft/LightGBM.git
* starting on Dask client * more docs stuff * fix pickling * just copy docstrings * fit docs * switch test order * linting * use client kwarg * remove inner set_params() * add type hints * fix type hints * remove commented code * reorder * fix tests, add client_ property * Apply suggestions from code review Co-authored-by: Nikita Titov <nekit94-08@mail.ru> * fix tests * linting * simplify Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Родитель
56fc036def
Коммит
c3ac77b570
|
@ -95,7 +95,7 @@ if [[ $TASK == "swig" ]]; then
|
|||
exit 0
|
||||
fi
|
||||
|
||||
conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
|
||||
conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
|
||||
|
||||
# graphviz must come from conda-forge to avoid this on some linux distros:
|
||||
# https://github.com/conda-forge/graphviz-feedstock/issues/18
|
||||
|
|
|
@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost.
|
|||
import socket
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
|
@ -17,7 +17,7 @@ import scipy.sparse as ss
|
|||
|
||||
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
|
||||
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
|
||||
SKLEARN_INSTALLED,
|
||||
SKLEARN_INSTALLED, LGBMNotFittedError,
|
||||
DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait)
|
||||
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker
|
||||
|
||||
|
@ -27,6 +27,25 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
|
|||
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
|
||||
|
||||
|
||||
def _get_dask_client(client: Optional[Client]) -> Client:
|
||||
"""Choose a Dask client to use.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : dask.distributed.Client or None
|
||||
Dask client.
|
||||
|
||||
Returns
|
||||
-------
|
||||
client : dask.distributed.Client
|
||||
A Dask client.
|
||||
"""
|
||||
if client is None:
|
||||
return default_client()
|
||||
else:
|
||||
return client
|
||||
|
||||
|
||||
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
|
||||
"""Find an open port.
|
||||
|
||||
|
@ -434,6 +453,29 @@ def _predict(
|
|||
|
||||
class _DaskLGBMModel:
|
||||
|
||||
@property
|
||||
def client_(self) -> Client:
|
||||
"""Dask client.
|
||||
|
||||
This property can be passed in the constructor or updated
|
||||
with ``model.set_params(client=client)``.
|
||||
"""
|
||||
if not getattr(self, "fitted_", False):
|
||||
raise LGBMNotFittedError('Cannot access property client_ before calling fit().')
|
||||
|
||||
return _get_dask_client(client=self.client)
|
||||
|
||||
def _lgb_getstate(self) -> Dict[Any, Any]:
|
||||
"""Remove un-picklable attributes before serialization."""
|
||||
client = self.__dict__.pop("client", None)
|
||||
self.__dict__.pop("_client", None)
|
||||
self._other_params.pop("client", None)
|
||||
out = deepcopy(self.__dict__)
|
||||
out.update({"_client": None, "client": None})
|
||||
self._client = client
|
||||
self.client = client
|
||||
return out
|
||||
|
||||
def _fit(
|
||||
self,
|
||||
model_factory: Type[LGBMModel],
|
||||
|
@ -441,18 +483,16 @@ class _DaskLGBMModel:
|
|||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
) -> "_DaskLGBMModel":
|
||||
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
|
||||
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
|
||||
if client is None:
|
||||
client = default_client()
|
||||
|
||||
params = self.get_params(True)
|
||||
params.pop("client", None)
|
||||
|
||||
model = _train(
|
||||
client=client,
|
||||
client=_get_dask_client(self.client),
|
||||
data=X,
|
||||
label=y,
|
||||
params=params,
|
||||
|
@ -468,8 +508,11 @@ class _DaskLGBMModel:
|
|||
return self
|
||||
|
||||
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
|
||||
model = model_factory(**self.get_params())
|
||||
params = self.get_params()
|
||||
params.pop("client", None)
|
||||
model = model_factory(**params)
|
||||
self._copy_extra_params(self, model)
|
||||
model._other_params.pop("client", None)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
|
@ -478,18 +521,82 @@ class _DaskLGBMModel:
|
|||
attributes = source.__dict__
|
||||
extra_param_names = set(attributes.keys()).difference(params.keys())
|
||||
for name in extra_param_names:
|
||||
setattr(dest, name, attributes[name])
|
||||
if name != "_client":
|
||||
setattr(dest, name, attributes[name])
|
||||
|
||||
|
||||
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
||||
"""Distributed version of lightgbm.LGBMClassifier."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
boosting_type: str = 'gbdt',
|
||||
num_leaves: int = 31,
|
||||
max_depth: int = -1,
|
||||
learning_rate: float = 0.1,
|
||||
n_estimators: int = 100,
|
||||
subsample_for_bin: int = 200000,
|
||||
objective: Optional[Union[Callable, str]] = None,
|
||||
class_weight: Optional[Union[dict, str]] = None,
|
||||
min_split_gain: float = 0.,
|
||||
min_child_weight: float = 1e-3,
|
||||
min_child_samples: int = 20,
|
||||
subsample: float = 1.,
|
||||
subsample_freq: int = 0,
|
||||
colsample_bytree: float = 1.,
|
||||
reg_alpha: float = 0.,
|
||||
reg_lambda: float = 0.,
|
||||
random_state: Optional[Union[int, np.random.RandomState]] = None,
|
||||
n_jobs: int = -1,
|
||||
silent: bool = True,
|
||||
importance_type: str = 'split',
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
|
||||
self._client = client
|
||||
self.client = client
|
||||
super().__init__(
|
||||
boosting_type=boosting_type,
|
||||
num_leaves=num_leaves,
|
||||
max_depth=max_depth,
|
||||
learning_rate=learning_rate,
|
||||
n_estimators=n_estimators,
|
||||
subsample_for_bin=subsample_for_bin,
|
||||
objective=objective,
|
||||
class_weight=class_weight,
|
||||
min_split_gain=min_split_gain,
|
||||
min_child_weight=min_child_weight,
|
||||
min_child_samples=min_child_samples,
|
||||
subsample=subsample,
|
||||
subsample_freq=subsample_freq,
|
||||
colsample_bytree=colsample_bytree,
|
||||
reg_alpha=reg_alpha,
|
||||
reg_lambda=reg_lambda,
|
||||
random_state=random_state,
|
||||
n_jobs=n_jobs,
|
||||
silent=silent,
|
||||
importance_type=importance_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
_base_doc = LGBMClassifier.__init__.__doc__
|
||||
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
|
||||
__init__.__doc__ = (
|
||||
_before_kwargs
|
||||
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
|
||||
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
|
||||
+ ' ' * 8 + _kwargs + _after_kwargs
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[Any, Any]:
|
||||
return self._lgb_getstate()
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskMatrixLike,
|
||||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
) -> "DaskLGBMClassifier":
|
||||
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
|
||||
|
@ -498,16 +605,10 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
X=X,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
client=client,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
_base_doc = LGBMClassifier.fit.__doc__
|
||||
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
|
||||
fit.__doc__ = (_before_init_score
|
||||
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
|
||||
+ ' ' * 12 + 'Dask client.\n'
|
||||
+ ' ' * 8 + _init_score + _after_init_score)
|
||||
fit.__doc__ = LGBMClassifier.fit.__doc__
|
||||
|
||||
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
|
||||
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
|
||||
|
@ -545,6 +646,70 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
||||
"""Distributed version of lightgbm.LGBMRegressor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
boosting_type: str = 'gbdt',
|
||||
num_leaves: int = 31,
|
||||
max_depth: int = -1,
|
||||
learning_rate: float = 0.1,
|
||||
n_estimators: int = 100,
|
||||
subsample_for_bin: int = 200000,
|
||||
objective: Optional[Union[Callable, str]] = None,
|
||||
class_weight: Optional[Union[dict, str]] = None,
|
||||
min_split_gain: float = 0.,
|
||||
min_child_weight: float = 1e-3,
|
||||
min_child_samples: int = 20,
|
||||
subsample: float = 1.,
|
||||
subsample_freq: int = 0,
|
||||
colsample_bytree: float = 1.,
|
||||
reg_alpha: float = 0.,
|
||||
reg_lambda: float = 0.,
|
||||
random_state: Optional[Union[int, np.random.RandomState]] = None,
|
||||
n_jobs: int = -1,
|
||||
silent: bool = True,
|
||||
importance_type: str = 'split',
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
|
||||
self._client = client
|
||||
self.client = client
|
||||
super().__init__(
|
||||
boosting_type=boosting_type,
|
||||
num_leaves=num_leaves,
|
||||
max_depth=max_depth,
|
||||
learning_rate=learning_rate,
|
||||
n_estimators=n_estimators,
|
||||
subsample_for_bin=subsample_for_bin,
|
||||
objective=objective,
|
||||
class_weight=class_weight,
|
||||
min_split_gain=min_split_gain,
|
||||
min_child_weight=min_child_weight,
|
||||
min_child_samples=min_child_samples,
|
||||
subsample=subsample,
|
||||
subsample_freq=subsample_freq,
|
||||
colsample_bytree=colsample_bytree,
|
||||
reg_alpha=reg_alpha,
|
||||
reg_lambda=reg_lambda,
|
||||
random_state=random_state,
|
||||
n_jobs=n_jobs,
|
||||
silent=silent,
|
||||
importance_type=importance_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
_base_doc = LGBMRegressor.__init__.__doc__
|
||||
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
|
||||
__init__.__doc__ = (
|
||||
_before_kwargs
|
||||
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
|
||||
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
|
||||
+ ' ' * 8 + _kwargs + _after_kwargs
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[Any, Any]:
|
||||
return self._lgb_getstate()
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskMatrixLike,
|
||||
|
@ -559,16 +724,10 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
X=X,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
client=client,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
_base_doc = LGBMRegressor.fit.__doc__
|
||||
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
|
||||
fit.__doc__ = (_before_init_score
|
||||
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
|
||||
+ ' ' * 12 + 'Dask client.\n'
|
||||
+ ' ' * 8 + _init_score + _after_init_score)
|
||||
fit.__doc__ = LGBMRegressor.fit.__doc__
|
||||
|
||||
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
|
||||
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
|
||||
|
@ -594,6 +753,70 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
||||
"""Distributed version of lightgbm.LGBMRanker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
boosting_type: str = 'gbdt',
|
||||
num_leaves: int = 31,
|
||||
max_depth: int = -1,
|
||||
learning_rate: float = 0.1,
|
||||
n_estimators: int = 100,
|
||||
subsample_for_bin: int = 200000,
|
||||
objective: Optional[Union[Callable, str]] = None,
|
||||
class_weight: Optional[Union[dict, str]] = None,
|
||||
min_split_gain: float = 0.,
|
||||
min_child_weight: float = 1e-3,
|
||||
min_child_samples: int = 20,
|
||||
subsample: float = 1.,
|
||||
subsample_freq: int = 0,
|
||||
colsample_bytree: float = 1.,
|
||||
reg_alpha: float = 0.,
|
||||
reg_lambda: float = 0.,
|
||||
random_state: Optional[Union[int, np.random.RandomState]] = None,
|
||||
n_jobs: int = -1,
|
||||
silent: bool = True,
|
||||
importance_type: str = 'split',
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
|
||||
self._client = client
|
||||
self.client = client
|
||||
super().__init__(
|
||||
boosting_type=boosting_type,
|
||||
num_leaves=num_leaves,
|
||||
max_depth=max_depth,
|
||||
learning_rate=learning_rate,
|
||||
n_estimators=n_estimators,
|
||||
subsample_for_bin=subsample_for_bin,
|
||||
objective=objective,
|
||||
class_weight=class_weight,
|
||||
min_split_gain=min_split_gain,
|
||||
min_child_weight=min_child_weight,
|
||||
min_child_samples=min_child_samples,
|
||||
subsample=subsample,
|
||||
subsample_freq=subsample_freq,
|
||||
colsample_bytree=colsample_bytree,
|
||||
reg_alpha=reg_alpha,
|
||||
reg_lambda=reg_lambda,
|
||||
random_state=random_state,
|
||||
n_jobs=n_jobs,
|
||||
silent=silent,
|
||||
importance_type=importance_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
_base_doc = LGBMRanker.__init__.__doc__
|
||||
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
|
||||
__init__.__doc__ = (
|
||||
_before_kwargs
|
||||
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
|
||||
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
|
||||
+ ' ' * 8 + _kwargs + _after_kwargs
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[Any, Any]:
|
||||
return self._lgb_getstate()
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskMatrixLike,
|
||||
|
@ -601,7 +824,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
|||
sample_weight: Optional[_DaskCollection] = None,
|
||||
init_score: Optional[_DaskCollection] = None,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
client: Optional[Client] = None,
|
||||
**kwargs: Any
|
||||
) -> "DaskLGBMRanker":
|
||||
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
|
||||
|
@ -614,16 +836,10 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
|||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
group=group,
|
||||
client=client,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
_base_doc = LGBMRanker.fit.__doc__
|
||||
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
|
||||
fit.__doc__ = (_before_eval_set
|
||||
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
|
||||
+ ' ' * 12 + 'Dask client.\n'
|
||||
+ ' ' * 8 + _eval_set + _after_eval_set)
|
||||
fit.__doc__ = LGBMRanker.fit.__doc__
|
||||
|
||||
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
|
||||
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
# coding: utf-8
|
||||
"""Tests for lightgbm.dask module"""
|
||||
|
||||
import inspect
|
||||
import joblib
|
||||
import pickle
|
||||
import socket
|
||||
from itertools import groupby
|
||||
from os import getenv
|
||||
|
@ -13,13 +16,14 @@ if not platform.startswith('linux'):
|
|||
if not lgb.compat.DASK_INSTALLED:
|
||||
pytest.skip('Dask is not installed', allow_module_level=True)
|
||||
|
||||
import cloudpickle
|
||||
import dask.array as da
|
||||
import dask.dataframe as dd
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import spearmanr
|
||||
from dask.array.utils import assert_eq
|
||||
from dask.distributed import wait
|
||||
from dask.distributed import default_client, Client, LocalCluster, wait
|
||||
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
|
||||
from scipy.sparse import csr_matrix
|
||||
from sklearn.datasets import make_blobs, make_regression
|
||||
|
@ -137,6 +141,32 @@ def _accuracy_score(dy_true, dy_pred):
|
|||
return da.average(dy_true == dy_pred).compute()
|
||||
|
||||
|
||||
def _pickle(obj, filepath, serializer):
|
||||
if serializer == 'pickle':
|
||||
with open(filepath, 'wb') as f:
|
||||
pickle.dump(obj, f)
|
||||
elif serializer == 'joblib':
|
||||
joblib.dump(obj, filepath)
|
||||
elif serializer == 'cloudpickle':
|
||||
with open(filepath, 'wb') as f:
|
||||
cloudpickle.dump(obj, f)
|
||||
else:
|
||||
raise ValueError(f'Unrecognized serializer type: {serializer}')
|
||||
|
||||
|
||||
def _unpickle(filepath, serializer):
|
||||
if serializer == 'pickle':
|
||||
with open(filepath, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
elif serializer == 'joblib':
|
||||
return joblib.load(filepath)
|
||||
elif serializer == 'cloudpickle':
|
||||
with open(filepath, 'rb') as f:
|
||||
return cloudpickle.load(f)
|
||||
else:
|
||||
raise ValueError(f'Unrecognized serializer type: {serializer}')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('output', data_output)
|
||||
@pytest.mark.parametrize('centers', data_centers)
|
||||
def test_classifier(output, centers, client, listen_port):
|
||||
|
@ -151,11 +181,12 @@ def test_classifier(output, centers, client, listen_port):
|
|||
"num_leaves": 10
|
||||
}
|
||||
dask_classifier = lgb.DaskLGBMClassifier(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=listen_port,
|
||||
**params
|
||||
)
|
||||
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
|
||||
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
|
||||
p1 = dask_classifier.predict(dX)
|
||||
p1_proba = dask_classifier.predict_proba(dX).compute()
|
||||
p1_local = dask_classifier.to_local().predict(X)
|
||||
|
@ -193,12 +224,13 @@ def test_classifier_pred_contrib(output, centers, client, listen_port):
|
|||
"num_leaves": 10
|
||||
}
|
||||
dask_classifier = lgb.DaskLGBMClassifier(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=listen_port,
|
||||
tree_learner='data',
|
||||
**params
|
||||
)
|
||||
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
|
||||
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
|
||||
preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()
|
||||
|
||||
local_classifier = lgb.LGBMClassifier(**params)
|
||||
|
@ -241,6 +273,7 @@ def test_training_does_not_fail_on_port_conflicts(client):
|
|||
s.bind(('127.0.0.1', 12400))
|
||||
|
||||
dask_classifier = lgb.DaskLGBMClassifier(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=12400,
|
||||
n_estimators=5,
|
||||
|
@ -251,7 +284,6 @@ def test_training_does_not_fail_on_port_conflicts(client):
|
|||
X=dX,
|
||||
y=dy,
|
||||
sample_weight=dw,
|
||||
client=client
|
||||
)
|
||||
assert dask_classifier.booster_
|
||||
|
||||
|
@ -270,12 +302,13 @@ def test_regressor(output, client, listen_port):
|
|||
"num_leaves": 10
|
||||
}
|
||||
dask_regressor = lgb.DaskLGBMRegressor(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=listen_port,
|
||||
tree='data',
|
||||
**params
|
||||
)
|
||||
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
|
||||
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
|
||||
p1 = dask_regressor.predict(dX)
|
||||
if output != 'dataframe':
|
||||
s1 = _r2_score(dy, p1)
|
||||
|
@ -313,12 +346,13 @@ def test_regressor_pred_contrib(output, client, listen_port):
|
|||
"num_leaves": 10
|
||||
}
|
||||
dask_regressor = lgb.DaskLGBMRegressor(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=listen_port,
|
||||
tree_learner='data',
|
||||
**params
|
||||
)
|
||||
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
|
||||
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
|
||||
preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute()
|
||||
|
||||
local_regressor = lgb.LGBMRegressor(**params)
|
||||
|
@ -353,11 +387,12 @@ def test_regressor_quantile(output, client, listen_port, alpha):
|
|||
"num_leaves": 10
|
||||
}
|
||||
dask_regressor = lgb.DaskLGBMRegressor(
|
||||
client=client,
|
||||
local_listen_port=listen_port,
|
||||
tree_learner_type='data_parallel',
|
||||
**params
|
||||
)
|
||||
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
|
||||
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
|
||||
p1 = dask_regressor.predict(dX).compute()
|
||||
q1 = np.count_nonzero(y < p1) / y.shape[0]
|
||||
|
||||
|
@ -400,12 +435,13 @@ def test_ranker(output, client, listen_port, group):
|
|||
"min_child_samples": 1
|
||||
}
|
||||
dask_ranker = lgb.DaskLGBMRanker(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=listen_port,
|
||||
tree_learner_type='data_parallel',
|
||||
**params
|
||||
)
|
||||
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
|
||||
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg)
|
||||
rnkvec_dask = dask_ranker.predict(dX)
|
||||
rnkvec_dask = rnkvec_dask.compute()
|
||||
rnkvec_dask_local = dask_ranker.to_local().predict(X)
|
||||
|
@ -424,6 +460,288 @@ def test_ranker(output, client, listen_port, group):
|
|||
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
|
||||
def test_training_works_if_client_not_provided_or_set_after_construction(task, listen_port, client):
|
||||
if task == 'ranking':
|
||||
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
|
||||
output='array',
|
||||
group=None
|
||||
)
|
||||
model_factory = lgb.DaskLGBMRanker
|
||||
else:
|
||||
_, _, _, dX, dy, _ = _create_data(
|
||||
objective=task,
|
||||
output='array',
|
||||
)
|
||||
dg = None
|
||||
if task == 'classification':
|
||||
model_factory = lgb.DaskLGBMClassifier
|
||||
elif task == 'regression':
|
||||
model_factory = lgb.DaskLGBMRegressor
|
||||
|
||||
params = {
|
||||
"time_out": 5,
|
||||
"local_listen_port": listen_port,
|
||||
"n_estimators": 1,
|
||||
"num_leaves": 2
|
||||
}
|
||||
|
||||
# should be able to use the class without specifying a client
|
||||
dask_model = model_factory(**params)
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
dask_model.client_
|
||||
|
||||
dask_model.fit(dX, dy, group=dg)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
assert dask_model.client_ == client
|
||||
|
||||
preds = dask_model.predict(dX)
|
||||
assert isinstance(preds, da.Array)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
assert dask_model.client_ == client
|
||||
|
||||
local_model = dask_model.to_local()
|
||||
with pytest.raises(AttributeError):
|
||||
local_model._client
|
||||
local_model.client
|
||||
local_model.client_
|
||||
|
||||
# should be able to set client after construction
|
||||
dask_model = model_factory(**params)
|
||||
dask_model.set_params(client=client)
|
||||
assert dask_model._client == client
|
||||
assert dask_model.client == client
|
||||
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
dask_model.client_
|
||||
|
||||
dask_model.fit(dX, dy, group=dg)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client == client
|
||||
assert dask_model.client == client
|
||||
assert dask_model.client_ == client
|
||||
|
||||
preds = dask_model.predict(dX)
|
||||
assert isinstance(preds, da.Array)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client == client
|
||||
assert dask_model.client == client
|
||||
assert dask_model.client_ == client
|
||||
|
||||
local_model = dask_model.to_local()
|
||||
assert getattr(local_model, "_client", None) is None
|
||||
with pytest.raises(AttributeError):
|
||||
local_model._client
|
||||
local_model.client
|
||||
local_model.client_
|
||||
|
||||
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle'])
|
||||
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
|
||||
@pytest.mark.parametrize('set_client', [True, False])
|
||||
def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, tmp_path):
|
||||
|
||||
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1:
|
||||
with Client(cluster1) as client1:
|
||||
|
||||
# data on cluster1
|
||||
if task == 'ranking':
|
||||
X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data(
|
||||
output='array',
|
||||
group=None
|
||||
)
|
||||
else:
|
||||
X_1, _, _, dX_1, dy_1, _ = _create_data(
|
||||
objective=task,
|
||||
output='array',
|
||||
)
|
||||
dg_1 = None
|
||||
|
||||
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2:
|
||||
with Client(cluster2) as client2:
|
||||
|
||||
# create identical data on cluster2
|
||||
if task == 'ranking':
|
||||
X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data(
|
||||
output='array',
|
||||
group=None
|
||||
)
|
||||
else:
|
||||
X_2, _, _, dX_2, dy_2, _ = _create_data(
|
||||
objective=task,
|
||||
output='array',
|
||||
)
|
||||
dg_2 = None
|
||||
|
||||
if task == 'ranking':
|
||||
model_factory = lgb.DaskLGBMRanker
|
||||
elif task == 'classification':
|
||||
model_factory = lgb.DaskLGBMClassifier
|
||||
elif task == 'regression':
|
||||
model_factory = lgb.DaskLGBMRegressor
|
||||
|
||||
params = {
|
||||
"time_out": 5,
|
||||
"local_listen_port": listen_port,
|
||||
"n_estimators": 1,
|
||||
"num_leaves": 2
|
||||
}
|
||||
|
||||
# at this point, the result of default_client() is client2 since it was the most recently
|
||||
# created. So setting client to client1 here to test that you can select a non-default client
|
||||
assert default_client() == client2
|
||||
if set_client:
|
||||
params.update({"client": client1})
|
||||
|
||||
# unfitted model should survive pickling round trip, and pickling
|
||||
# shouldn't have side effects on the model object
|
||||
dask_model = model_factory(**params)
|
||||
local_model = dask_model.to_local()
|
||||
if set_client:
|
||||
assert dask_model._client == client1
|
||||
assert dask_model.client == client1
|
||||
else:
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
dask_model.client_
|
||||
|
||||
assert "client" not in local_model.get_params()
|
||||
assert getattr(local_model, "client", None) is None
|
||||
|
||||
tmp_file = str(tmp_path / "model-1.pkl")
|
||||
_pickle(
|
||||
obj=dask_model,
|
||||
filepath=tmp_file,
|
||||
serializer=serializer
|
||||
)
|
||||
model_from_disk = _unpickle(
|
||||
filepath=tmp_file,
|
||||
serializer=serializer
|
||||
)
|
||||
|
||||
local_tmp_file = str(tmp_path / "local-model-1.pkl")
|
||||
_pickle(
|
||||
obj=local_model,
|
||||
filepath=local_tmp_file,
|
||||
serializer=serializer
|
||||
)
|
||||
local_model_from_disk = _unpickle(
|
||||
filepath=local_tmp_file,
|
||||
serializer=serializer
|
||||
)
|
||||
|
||||
assert model_from_disk._client is None
|
||||
assert model_from_disk.client is None
|
||||
|
||||
if set_client:
|
||||
assert dask_model._client == client1
|
||||
assert dask_model.client == client1
|
||||
else:
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
dask_model.client_
|
||||
|
||||
# client will always be None after unpickling
|
||||
if set_client:
|
||||
from_disk_params = model_from_disk.get_params()
|
||||
from_disk_params.pop("client", None)
|
||||
dask_params = dask_model.get_params()
|
||||
dask_params.pop("client", None)
|
||||
assert from_disk_params == dask_params
|
||||
else:
|
||||
assert model_from_disk.get_params() == dask_model.get_params()
|
||||
assert local_model_from_disk.get_params() == local_model.get_params()
|
||||
|
||||
# fitted model should survive pickling round trip, and pickling
|
||||
# shouldn't have side effects on the model object
|
||||
if set_client:
|
||||
dask_model.fit(dX_1, dy_1, group=dg_1)
|
||||
else:
|
||||
dask_model.fit(dX_2, dy_2, group=dg_2)
|
||||
local_model = dask_model.to_local()
|
||||
|
||||
assert "client" not in local_model.get_params()
|
||||
with pytest.raises(AttributeError):
|
||||
local_model._client
|
||||
local_model.client
|
||||
local_model.client_
|
||||
|
||||
tmp_file2 = str(tmp_path / "model-2.pkl")
|
||||
_pickle(
|
||||
obj=dask_model,
|
||||
filepath=tmp_file2,
|
||||
serializer=serializer
|
||||
)
|
||||
fitted_model_from_disk = _unpickle(
|
||||
filepath=tmp_file2,
|
||||
serializer=serializer
|
||||
)
|
||||
|
||||
local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
|
||||
_pickle(
|
||||
obj=local_model,
|
||||
filepath=local_tmp_file2,
|
||||
serializer=serializer
|
||||
)
|
||||
local_fitted_model_from_disk = _unpickle(
|
||||
filepath=local_tmp_file2,
|
||||
serializer=serializer
|
||||
)
|
||||
|
||||
if set_client:
|
||||
assert dask_model._client == client1
|
||||
assert dask_model.client == client1
|
||||
assert dask_model.client_ == client1
|
||||
else:
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
assert dask_model.client_ == default_client()
|
||||
assert dask_model.client_ == client2
|
||||
|
||||
assert isinstance(fitted_model_from_disk, model_factory)
|
||||
assert fitted_model_from_disk._client is None
|
||||
assert fitted_model_from_disk.client is None
|
||||
assert fitted_model_from_disk.client_ == default_client()
|
||||
assert fitted_model_from_disk.client_ == client2
|
||||
|
||||
# client will always be None after unpickling
|
||||
if set_client:
|
||||
from_disk_params = fitted_model_from_disk.get_params()
|
||||
from_disk_params.pop("client", None)
|
||||
dask_params = dask_model.get_params()
|
||||
dask_params.pop("client", None)
|
||||
assert from_disk_params == dask_params
|
||||
else:
|
||||
assert fitted_model_from_disk.get_params() == dask_model.get_params()
|
||||
assert local_fitted_model_from_disk.get_params() == local_model.get_params()
|
||||
|
||||
if set_client:
|
||||
preds_orig = dask_model.predict(dX_1).compute()
|
||||
preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
|
||||
preds_orig_local = local_model.predict(X_1)
|
||||
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
|
||||
else:
|
||||
preds_orig = dask_model.predict(dX_2).compute()
|
||||
preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
|
||||
preds_orig_local = local_model.predict(X_2)
|
||||
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
|
||||
|
||||
assert_eq(preds_orig, preds_loaded_model)
|
||||
assert_eq(preds_orig_local, preds_loaded_model_local)
|
||||
|
||||
|
||||
def test_find_open_port_works():
|
||||
worker_ip = '127.0.0.1'
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
|
@ -451,6 +769,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
|
|||
X = da.random.random((1e3, 10))
|
||||
y = da.random.random((1e3, 1))
|
||||
dask_regressor = lgb.DaskLGBMRegressor(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=1234,
|
||||
tree_learner='some-nonsense-value',
|
||||
|
@ -458,7 +777,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
|
|||
num_leaves=2
|
||||
)
|
||||
with pytest.warns(UserWarning, match='Parameter tree_learner set to some-nonsense-value'):
|
||||
dask_regressor = dask_regressor.fit(X, y, client=client)
|
||||
dask_regressor = dask_regressor.fit(X, y)
|
||||
|
||||
assert dask_regressor.fitted_
|
||||
|
||||
|
@ -470,6 +789,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
|
|||
y = da.random.random((1e3, 1))
|
||||
for tree_learner in ['feature_parallel', 'voting']:
|
||||
dask_regressor = lgb.DaskLGBMRegressor(
|
||||
client=client,
|
||||
time_out=5,
|
||||
local_listen_port=1234,
|
||||
tree_learner=tree_learner,
|
||||
|
@ -477,7 +797,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
|
|||
num_leaves=2
|
||||
)
|
||||
with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner):
|
||||
dask_regressor = dask_regressor.fit(X, y, client=client)
|
||||
dask_regressor = dask_regressor.fit(X, y)
|
||||
|
||||
assert dask_regressor.fitted_
|
||||
assert dask_regressor.get_params()['tree_learner'] == tree_learner
|
||||
|
@ -501,3 +821,26 @@ def test_errors(c, s, a, b):
|
|||
model_factory=lgb.LGBMClassifier
|
||||
)
|
||||
assert 'foo' in str(info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"classes",
|
||||
[
|
||||
(lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
|
||||
(lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
|
||||
(lgb.DaskLGBMRanker, lgb.LGBMRanker)
|
||||
]
|
||||
)
|
||||
def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes):
|
||||
dask_spec = inspect.getfullargspec(classes[0])
|
||||
sklearn_spec = inspect.getfullargspec(classes[1])
|
||||
assert dask_spec.varargs == sklearn_spec.varargs
|
||||
assert dask_spec.varkw == sklearn_spec.varkw
|
||||
assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
|
||||
assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
|
||||
|
||||
# "client" should be the only different, and the final argument
|
||||
assert dask_spec.args[:-1] == sklearn_spec.args
|
||||
assert dask_spec.defaults[:-1] == sklearn_spec.defaults
|
||||
assert dask_spec.args[-1] == 'client'
|
||||
assert dask_spec.defaults[-1] is None
|
||||
|
|
Загрузка…
Ссылка в новой задаче