[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) (#3883)

* starting on Dask client

* more docs stuff

* fix pickling

* just copy docstrings

* fit docs

* switch test order

* linting

* use client kwarg

* remove inner set_params()

* add type hints

* fix type hints

* remove commented code

* reorder

* fix tests, add client_ property

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* fix tests

* linting

* simplify

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
James Lamb 2021-02-02 23:48:59 -06:00 коммит произвёл GitHub
Родитель 56fc036def
Коммит c3ac77b570
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 601 добавлений и 42 удалений

Просмотреть файл

@ -95,7 +95,7 @@ if [[ $TASK == "swig" ]]; then
exit 0
fi
conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
# graphviz must come from conda-forge to avoid this on some linux distros:
# https://github.com/conda-forge/graphviz-feedstock/issues/18

Просмотреть файл

@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost.
import socket
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
from urllib.parse import urlparse
import numpy as np
@ -17,7 +17,7 @@ import scipy.sparse as ss
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
SKLEARN_INSTALLED,
SKLEARN_INSTALLED, LGBMNotFittedError,
DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker
@ -27,6 +27,25 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
def _get_dask_client(client: Optional[Client]) -> Client:
"""Choose a Dask client to use.
Parameters
----------
client : dask.distributed.Client or None
Dask client.
Returns
-------
client : dask.distributed.Client
A Dask client.
"""
if client is None:
return default_client()
else:
return client
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port.
@ -434,6 +453,29 @@ def _predict(
class _DaskLGBMModel:
@property
def client_(self) -> Client:
"""Dask client.
This property can be passed in the constructor or updated
with ``model.set_params(client=client)``.
"""
if not getattr(self, "fitted_", False):
raise LGBMNotFittedError('Cannot access property client_ before calling fit().')
return _get_dask_client(client=self.client)
def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None)
out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None})
self._client = client
self.client = client
return out
def _fit(
self,
model_factory: Type[LGBMModel],
@ -441,18 +483,16 @@ class _DaskLGBMModel:
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "_DaskLGBMModel":
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if client is None:
client = default_client()
params = self.get_params(True)
params.pop("client", None)
model = _train(
client=client,
client=_get_dask_client(self.client),
data=X,
label=y,
params=params,
@ -468,8 +508,11 @@ class _DaskLGBMModel:
return self
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
model = model_factory(**self.get_params())
params = self.get_params()
params.pop("client", None)
model = model_factory(**params)
self._copy_extra_params(self, model)
model._other_params.pop("client", None)
return model
@staticmethod
@ -478,18 +521,82 @@ class _DaskLGBMModel:
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names:
setattr(dest, name, attributes[name])
if name != "_client":
setattr(dest, name, attributes[name])
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit(
self,
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
@ -498,16 +605,10 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
X=X,
y=y,
sample_weight=sample_weight,
client=client,
**kwargs
)
_base_doc = LGBMClassifier.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)
fit.__doc__ = LGBMClassifier.fit.__doc__
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
@ -545,6 +646,70 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRegressor."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMRegressor.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit(
self,
X: _DaskMatrixLike,
@ -559,16 +724,10 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
X=X,
y=y,
sample_weight=sample_weight,
client=client,
**kwargs
)
_base_doc = LGBMRegressor.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)
fit.__doc__ = LGBMRegressor.fit.__doc__
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
@ -594,6 +753,70 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRanker."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMRanker.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit(
self,
X: _DaskMatrixLike,
@ -601,7 +824,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
@ -614,16 +836,10 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
y=y,
sample_weight=sample_weight,
group=group,
client=client,
**kwargs
)
_base_doc = LGBMRanker.fit.__doc__
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
fit.__doc__ = (_before_eval_set
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _eval_set + _after_eval_set)
fit.__doc__ = LGBMRanker.fit.__doc__
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""

Просмотреть файл

@ -1,6 +1,9 @@
# coding: utf-8
"""Tests for lightgbm.dask module"""
import inspect
import joblib
import pickle
import socket
from itertools import groupby
from os import getenv
@ -13,13 +16,14 @@ if not platform.startswith('linux'):
if not lgb.compat.DASK_INSTALLED:
pytest.skip('Dask is not installed', allow_module_level=True)
import cloudpickle
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from dask.array.utils import assert_eq
from dask.distributed import wait
from dask.distributed import default_client, Client, LocalCluster, wait
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
from scipy.sparse import csr_matrix
from sklearn.datasets import make_blobs, make_regression
@ -137,6 +141,32 @@ def _accuracy_score(dy_true, dy_pred):
return da.average(dy_true == dy_pred).compute()
def _pickle(obj, filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'wb') as f:
pickle.dump(obj, f)
elif serializer == 'joblib':
joblib.dump(obj, filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'wb') as f:
cloudpickle.dump(obj, f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
def _unpickle(filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'rb') as f:
return pickle.load(f)
elif serializer == 'joblib':
return joblib.load(filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'rb') as f:
return cloudpickle.load(f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier(output, centers, client, listen_port):
@ -151,11 +181,12 @@ def test_classifier(output, centers, client, listen_port):
"num_leaves": 10
}
dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5,
local_listen_port=listen_port,
**params
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
p1 = dask_classifier.predict(dX)
p1_proba = dask_classifier.predict_proba(dX).compute()
p1_local = dask_classifier.to_local().predict(X)
@ -193,12 +224,13 @@ def test_classifier_pred_contrib(output, centers, client, listen_port):
"num_leaves": 10
}
dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5,
local_listen_port=listen_port,
tree_learner='data',
**params
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()
local_classifier = lgb.LGBMClassifier(**params)
@ -241,6 +273,7 @@ def test_training_does_not_fail_on_port_conflicts(client):
s.bind(('127.0.0.1', 12400))
dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5,
local_listen_port=12400,
n_estimators=5,
@ -251,7 +284,6 @@ def test_training_does_not_fail_on_port_conflicts(client):
X=dX,
y=dy,
sample_weight=dw,
client=client
)
assert dask_classifier.booster_
@ -270,12 +302,13 @@ def test_regressor(output, client, listen_port):
"num_leaves": 10
}
dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5,
local_listen_port=listen_port,
tree='data',
**params
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
p1 = dask_regressor.predict(dX)
if output != 'dataframe':
s1 = _r2_score(dy, p1)
@ -313,12 +346,13 @@ def test_regressor_pred_contrib(output, client, listen_port):
"num_leaves": 10
}
dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5,
local_listen_port=listen_port,
tree_learner='data',
**params
)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute()
local_regressor = lgb.LGBMRegressor(**params)
@ -353,11 +387,12 @@ def test_regressor_quantile(output, client, listen_port, alpha):
"num_leaves": 10
}
dask_regressor = lgb.DaskLGBMRegressor(
client=client,
local_listen_port=listen_port,
tree_learner_type='data_parallel',
**params
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
p1 = dask_regressor.predict(dX).compute()
q1 = np.count_nonzero(y < p1) / y.shape[0]
@ -400,12 +435,13 @@ def test_ranker(output, client, listen_port, group):
"min_child_samples": 1
}
dask_ranker = lgb.DaskLGBMRanker(
client=client,
time_out=5,
local_listen_port=listen_port,
tree_learner_type='data_parallel',
**params
)
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg)
rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute()
rnkvec_dask_local = dask_ranker.to_local().predict(X)
@ -424,6 +460,288 @@ def test_ranker(output, client, listen_port, group):
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
def test_training_works_if_client_not_provided_or_set_after_construction(task, listen_port, client):
if task == 'ranking':
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
output='array',
group=None
)
model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output='array',
)
dg = None
if task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor
params = {
"time_out": 5,
"local_listen_port": listen_port,
"n_estimators": 1,
"num_leaves": 2
}
# should be able to use the class without specifying a client
dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
local_model = dask_model.to_local()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
assert dask_model._client == client
assert dask_model.client == client
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
@pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle'])
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
@pytest.mark.parametrize('set_client', [True, False])
def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, tmp_path):
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1:
with Client(cluster1) as client1:
# data on cluster1
if task == 'ranking':
X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data(
output='array',
group=None
)
else:
X_1, _, _, dX_1, dy_1, _ = _create_data(
objective=task,
output='array',
)
dg_1 = None
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2:
with Client(cluster2) as client2:
# create identical data on cluster2
if task == 'ranking':
X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data(
output='array',
group=None
)
else:
X_2, _, _, dX_2, dy_2, _ = _create_data(
objective=task,
output='array',
)
dg_2 = None
if task == 'ranking':
model_factory = lgb.DaskLGBMRanker
elif task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor
params = {
"time_out": 5,
"local_listen_port": listen_port,
"n_estimators": 1,
"num_leaves": 2
}
# at this point, the result of default_client() is client2 since it was the most recently
# created. So setting client to client1 here to test that you can select a non-default client
assert default_client() == client2
if set_client:
params.update({"client": client1})
# unfitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
assert "client" not in local_model.get_params()
assert getattr(local_model, "client", None) is None
tmp_file = str(tmp_path / "model-1.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file,
serializer=serializer
)
model_from_disk = _unpickle(
filepath=tmp_file,
serializer=serializer
)
local_tmp_file = str(tmp_path / "local-model-1.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file,
serializer=serializer
)
local_model_from_disk = _unpickle(
filepath=local_tmp_file,
serializer=serializer
)
assert model_from_disk._client is None
assert model_from_disk.client is None
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
# client will always be None after unpickling
if set_client:
from_disk_params = model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert model_from_disk.get_params() == dask_model.get_params()
assert local_model_from_disk.get_params() == local_model.get_params()
# fitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
if set_client:
dask_model.fit(dX_1, dy_1, group=dg_1)
else:
dask_model.fit(dX_2, dy_2, group=dg_2)
local_model = dask_model.to_local()
assert "client" not in local_model.get_params()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
tmp_file2 = str(tmp_path / "model-2.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file2,
serializer=serializer
)
fitted_model_from_disk = _unpickle(
filepath=tmp_file2,
serializer=serializer
)
local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file2,
serializer=serializer
)
local_fitted_model_from_disk = _unpickle(
filepath=local_tmp_file2,
serializer=serializer
)
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk._client is None
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2
# client will always be None after unpickling
if set_client:
from_disk_params = fitted_model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert fitted_model_from_disk.get_params() == dask_model.get_params()
assert local_fitted_model_from_disk.get_params() == local_model.get_params()
if set_client:
preds_orig = dask_model.predict(dX_1).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
preds_orig_local = local_model.predict(X_1)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
else:
preds_orig = dask_model.predict(dX_2).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
preds_orig_local = local_model.predict(X_2)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
assert_eq(preds_orig, preds_loaded_model)
assert_eq(preds_orig_local, preds_loaded_model_local)
def test_find_open_port_works():
worker_ip = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@ -451,6 +769,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
X = da.random.random((1e3, 10))
y = da.random.random((1e3, 1))
dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5,
local_listen_port=1234,
tree_learner='some-nonsense-value',
@ -458,7 +777,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
num_leaves=2
)
with pytest.warns(UserWarning, match='Parameter tree_learner set to some-nonsense-value'):
dask_regressor = dask_regressor.fit(X, y, client=client)
dask_regressor = dask_regressor.fit(X, y)
assert dask_regressor.fitted_
@ -470,6 +789,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
y = da.random.random((1e3, 1))
for tree_learner in ['feature_parallel', 'voting']:
dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5,
local_listen_port=1234,
tree_learner=tree_learner,
@ -477,7 +797,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
num_leaves=2
)
with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner):
dask_regressor = dask_regressor.fit(X, y, client=client)
dask_regressor = dask_regressor.fit(X, y)
assert dask_regressor.fitted_
assert dask_regressor.get_params()['tree_learner'] == tree_learner
@ -501,3 +821,26 @@ def test_errors(c, s, a, b):
model_factory=lgb.LGBMClassifier
)
assert 'foo' in str(info.value)
@pytest.mark.parametrize(
"classes",
[
(lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
(lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
(lgb.DaskLGBMRanker, lgb.LGBMRanker)
]
)
def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes):
dask_spec = inspect.getfullargspec(classes[0])
sklearn_spec = inspect.getfullargspec(classes[1])
assert dask_spec.varargs == sklearn_spec.varargs
assert dask_spec.varkw == sklearn_spec.varkw
assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
# "client" should be the only different, and the final argument
assert dask_spec.args[:-1] == sklearn_spec.args
assert dask_spec.defaults[:-1] == sklearn_spec.defaults
assert dask_spec.args[-1] == 'client'
assert dask_spec.defaults[-1] is None