[dask] raise more informative error for duplicates in 'machines' (fixes #4057) (#4059)

* [dask] raise more informative error for duplicates in 'machines'

* uncomment

* avoid test failure

* Revert "avoid test failure"

This reverts commit 9442bdf00f.
This commit is contained in:
James Lamb 2021-03-10 12:02:27 -06:00 коммит произвёл GitHub
Родитель b75a43a05b
Коммит 296397df7b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 16 добавлений и 0 удалений

Просмотреть файл

@@ -153,6 +153,10 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
Dictionary where keys are worker addresses in the form expected by Dask and values are a port for LightGBM to use.
"""
machine_addresses = machines.split(",")
if len(set(machine_addresses)) != len(machine_addresses):
raise ValueError(f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination.")
machine_to_port = defaultdict(set)
for address in machine_addresses:
host, port = address.split(":")

Просмотреть файл

@@ -1116,6 +1116,7 @@ def test_machines_should_be_used_if_provided(task, output):
client.rebalance()
n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model = dask_model_factory(
n_estimators=5,
@@ -1134,6 +1135,17 @@ def test_machines_should_be_used_if_provided(task, output):
s.bind(('127.0.0.1', open_ports[0]))
dask_model.fit(dX, dy, group=dg)
# an informative error should be raised if "machines" has duplicates
one_open_port = lgb.dask._find_random_open_port()
dask_model.set_params(
machines=",".join([
"127.0.0.1:" + str(one_open_port)
for _ in range(n_workers)
])
)
with pytest.raises(ValueError, match="Found duplicates in 'machines'"):
dask_model.fit(dX, dy, group=dg)
@pytest.mark.parametrize(
"classes",