[dask] remove unused private _client attribute (#3904)

* Update test_dask.py

* Update dask.py

* Update .vsts-ci.yml

* Revert "Update .vsts-ci.yml"

This reverts commit 98422be5b5.
This commit is contained in:
Nikita Titov 2021-02-03 19:44:08 +03:00 коммит произвёл GitHub
Родитель 08c68c917b
Коммит b1e000c045
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 3 добавлений и 27 удалений

Просмотреть файл

@ -468,11 +468,9 @@ class _DaskLGBMModel:
def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None)
out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None})
self._client = client
out.update({"client": None})
self.client = client
return out
@ -521,8 +519,7 @@ class _DaskLGBMModel:
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names:
if name != "_client":
setattr(dest, name, attributes[name])
setattr(dest, name, attributes[name])
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
@ -554,7 +551,6 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
@ -672,7 +668,6 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
@ -779,7 +774,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,

Просмотреть файл

@ -2,7 +2,6 @@
"""Tests for lightgbm.dask module"""
import inspect
import joblib
import pickle
import socket
from itertools import groupby
@ -19,6 +18,7 @@ if not lgb.compat.DASK_INSTALLED:
import cloudpickle
import dask.array as da
import dask.dataframe as dd
import joblib
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
@ -488,34 +488,29 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
# should be able to use the class without specifying a client
dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
local_model = dask_model.to_local()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
assert dask_model._client == client
assert dask_model.client == client
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
@ -523,21 +518,17 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
@ -606,10 +597,8 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
@ -640,14 +629,11 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
serializer=serializer
)
assert model_from_disk._client is None
assert model_from_disk.client is None
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
@ -674,7 +660,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert "client" not in local_model.get_params()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
@ -701,17 +686,14 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
)
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk._client is None
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2