From 296397df7bf584770a0297d84ce6a8f4ef25c317 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 10 Mar 2021 12:02:27 -0600 Subject: [PATCH] [dask] raise more informative error for duplicates in 'machines' (fixes #4057) (#4059) * [dask] raise more informative error for duplicates in 'machines' * uncomment * avoid test failure * Revert "avoid test failure" This reverts commit 9442bdf00f193a19a923dc0deb46b7822cb6f601. --- python-package/lightgbm/dask.py | 4 ++++ tests/python_package_test/test_dask.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 0d6511a52..72ecb6977 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -153,6 +153,10 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[ Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use. """ machine_addresses = machines.split(",") + + if len(set(machine_addresses)) != len(machine_addresses): + raise ValueError(f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination.") + machine_to_port = defaultdict(set) for address in machine_addresses: host, port = address.split(":") diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 1b5754f3c..d4b22f675 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1116,6 +1116,7 @@ def test_machines_should_be_used_if_provided(task, output): client.rebalance() n_workers = len(client.scheduler_info()['workers']) + assert n_workers > 1 open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)] dask_model = dask_model_factory( n_estimators=5, @@ -1134,6 +1135,17 @@ def test_machines_should_be_used_if_provided(task, output): s.bind(('127.0.0.1', open_ports[0])) dask_model.fit(dX, dy, group=dg) + # an informative error should be raised if "machines" has duplicates + one_open_port = lgb.dask._find_random_open_port() + dask_model.set_params( + machines=",".join([ + "127.0.0.1:" + str(one_open_port) + for _ in range(n_workers) + ]) + ) + with pytest.raises(ValueError, match="Found duplicates in 'machines'"): + dask_model.fit(dX, dy, group=dg) + @pytest.mark.parametrize( "classes",