зеркало из https://github.com/microsoft/LightGBM.git
[dask] remove unused private _client attribute (#3904)
* Update test_dask.py
* Update dask.py
* Update .vsts-ci.yml
* Revert "Update .vsts-ci.yml"
This reverts commit 98422be5b5
.
This commit is contained in:
Родитель
08c68c917b
Коммит
b1e000c045
|
@ -468,11 +468,9 @@ class _DaskLGBMModel:
|
|||
def _lgb_getstate(self) -> Dict[Any, Any]:
|
||||
"""Remove un-picklable attributes before serialization."""
|
||||
client = self.__dict__.pop("client", None)
|
||||
self.__dict__.pop("_client", None)
|
||||
self._other_params.pop("client", None)
|
||||
out = deepcopy(self.__dict__)
|
||||
out.update({"_client": None, "client": None})
|
||||
self._client = client
|
||||
out.update({"client": None})
|
||||
self.client = client
|
||||
return out
|
||||
|
||||
|
@ -521,8 +519,7 @@ class _DaskLGBMModel:
|
|||
attributes = source.__dict__
|
||||
extra_param_names = set(attributes.keys()).difference(params.keys())
|
||||
for name in extra_param_names:
|
||||
if name != "_client":
|
||||
setattr(dest, name, attributes[name])
|
||||
setattr(dest, name, attributes[name])
|
||||
|
||||
|
||||
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
||||
|
@ -554,7 +551,6 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
|
|||
**kwargs: Any
|
||||
):
|
||||
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
|
||||
self._client = client
|
||||
self.client = client
|
||||
super().__init__(
|
||||
boosting_type=boosting_type,
|
||||
|
@ -672,7 +668,6 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
|
|||
**kwargs: Any
|
||||
):
|
||||
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
|
||||
self._client = client
|
||||
self.client = client
|
||||
super().__init__(
|
||||
boosting_type=boosting_type,
|
||||
|
@ -779,7 +774,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
|
|||
**kwargs: Any
|
||||
):
|
||||
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
|
||||
self._client = client
|
||||
self.client = client
|
||||
super().__init__(
|
||||
boosting_type=boosting_type,
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
"""Tests for lightgbm.dask module"""
|
||||
|
||||
import inspect
|
||||
import joblib
|
||||
import pickle
|
||||
import socket
|
||||
from itertools import groupby
|
||||
|
@ -19,6 +18,7 @@ if not lgb.compat.DASK_INSTALLED:
|
|||
import cloudpickle
|
||||
import dask.array as da
|
||||
import dask.dataframe as dd
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import spearmanr
|
||||
|
@ -488,34 +488,29 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
|
|||
|
||||
# should be able to use the class without specifying a client
|
||||
dask_model = model_factory(**params)
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
dask_model.client_
|
||||
|
||||
dask_model.fit(dX, dy, group=dg)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
assert dask_model.client_ == client
|
||||
|
||||
preds = dask_model.predict(dX)
|
||||
assert isinstance(preds, da.Array)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
assert dask_model.client_ == client
|
||||
|
||||
local_model = dask_model.to_local()
|
||||
with pytest.raises(AttributeError):
|
||||
local_model._client
|
||||
local_model.client
|
||||
local_model.client_
|
||||
|
||||
# should be able to set client after construction
|
||||
dask_model = model_factory(**params)
|
||||
dask_model.set_params(client=client)
|
||||
assert dask_model._client == client
|
||||
assert dask_model.client == client
|
||||
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
|
@ -523,21 +518,17 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
|
|||
|
||||
dask_model.fit(dX, dy, group=dg)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client == client
|
||||
assert dask_model.client == client
|
||||
assert dask_model.client_ == client
|
||||
|
||||
preds = dask_model.predict(dX)
|
||||
assert isinstance(preds, da.Array)
|
||||
assert dask_model.fitted_
|
||||
assert dask_model._client == client
|
||||
assert dask_model.client == client
|
||||
assert dask_model.client_ == client
|
||||
|
||||
local_model = dask_model.to_local()
|
||||
assert getattr(local_model, "_client", None) is None
|
||||
with pytest.raises(AttributeError):
|
||||
local_model._client
|
||||
local_model.client
|
||||
local_model.client_
|
||||
|
||||
|
@ -606,10 +597,8 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
|
|||
dask_model = model_factory(**params)
|
||||
local_model = dask_model.to_local()
|
||||
if set_client:
|
||||
assert dask_model._client == client1
|
||||
assert dask_model.client == client1
|
||||
else:
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
|
@ -640,14 +629,11 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
|
|||
serializer=serializer
|
||||
)
|
||||
|
||||
assert model_from_disk._client is None
|
||||
assert model_from_disk.client is None
|
||||
|
||||
if set_client:
|
||||
assert dask_model._client == client1
|
||||
assert dask_model.client == client1
|
||||
else:
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
|
||||
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
|
||||
|
@ -674,7 +660,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
|
|||
|
||||
assert "client" not in local_model.get_params()
|
||||
with pytest.raises(AttributeError):
|
||||
local_model._client
|
||||
local_model.client
|
||||
local_model.client_
|
||||
|
||||
|
@ -701,17 +686,14 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
|
|||
)
|
||||
|
||||
if set_client:
|
||||
assert dask_model._client == client1
|
||||
assert dask_model.client == client1
|
||||
assert dask_model.client_ == client1
|
||||
else:
|
||||
assert dask_model._client is None
|
||||
assert dask_model.client is None
|
||||
assert dask_model.client_ == default_client()
|
||||
assert dask_model.client_ == client2
|
||||
|
||||
assert isinstance(fitted_model_from_disk, model_factory)
|
||||
assert fitted_model_from_disk._client is None
|
||||
assert fitted_model_from_disk.client is None
|
||||
assert fitted_model_from_disk.client_ == default_client()
|
||||
assert fitted_model_from_disk.client_ == client2
|
||||
|
|
Загрузка…
Ссылка в новой задаче