Mirror of https://github.com/microsoft/LightGBM.git
[dask] use random ports in network setup (#3823)
* use socket.bind with port 0 and client.run to find random open ports
* include test for found ports
* find random open ports as default
* parametrize local_listen_port, add type hint to _find_random_open_port, find open ports only on workers with data
* make indentation consistent and pass list of workers to client.run
* remove socket import
* change random port implementation
* fix test
Parent: 7777852a19
Commit: 0e57657585
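The approach in this commit: instead of scanning sequentially upward from local_listen_port, each worker binds a socket to port 0 and lets the operating system hand back a free ephemeral port. A minimal standalone sketch of that trick (illustrative only, not part of the diff; the helper name find_free_port is made up here):

import socket

def find_free_port() -> int:
    # binding to port 0 asks the OS for any currently free ephemeral port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]

print(find_free_port())  # e.g. 49731; the exact number varies per call

The port is released as soon as the socket is closed, so there is a short window in which another process could grab it before LightGBM binds it again; the test suite covers that case in test_training_does_not_fail_on_port_conflicts below.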
@@ -45,83 +45,18 @@ def _get_dask_client(client: Optional[Client]) -> Client:
     return client
 
 
-def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
-    """Find an open port.
-
-    This function tries to find a free port on the machine it's run on. It is intended to
-    be run once on each Dask worker, sequentially.
-
-    Parameters
-    ----------
-    worker_ip : str
-        IP address for the Dask worker.
-    local_listen_port : int
-        First port to try when searching for open ports.
-    ports_to_skip: Iterable[int]
-        An iterable of integers referring to ports that should be skipped. Since multiple Dask
-        workers can run on the same physical machine, this method may be called multiple times
-        on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
-        the same port for two worker processes running on the same machine.
+def _find_random_open_port() -> int:
+    """Find a random open port on localhost.
 
     Returns
     -------
     port : int
-        A free port on the machine referenced by ``worker_ip``.
+        A free port on localhost
     """
-    max_tries = 1000
-    found_port = False
-    for i in range(max_tries):
-        out_port = local_listen_port + i
-        if out_port in ports_to_skip:
-            continue
-        try:
-            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                s.bind((worker_ip, out_port))
-                found_port = True
-                break
-        # if unavailable, you'll get OSError: Address already in use
-        except OSError:
-            continue
-    if not found_port:
-        msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
-        raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
-    return out_port
-
-
-def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
-    """Find an open port on each worker.
-
-    LightGBM distributed training uses TCP sockets by default, and this method is used to
-    identify open ports on each worker so LightGBM can reliable create those sockets.
-
-    Parameters
-    ----------
-    client : dask.distributed.Client
-        Dask client.
-    worker_addresses : Iterable[str]
-        An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``.
-    local_listen_port : int
-        First port to try when searching for open ports.
-
-    Returns
-    -------
-    result : Dict[str, int]
-        Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
-    """
-    lightgbm_ports: Set[int] = set()
-    worker_ip_to_port = {}
-    for worker_address in worker_addresses:
-        port = client.submit(
-            func=_find_open_port,
-            workers=[worker_address],
-            worker_ip=urlparse(worker_address).hostname,
-            local_listen_port=local_listen_port,
-            ports_to_skip=lightgbm_ports
-        ).result()
-        lightgbm_ports.add(port)
-        worker_ip_to_port[worker_address] = port
-
-    return worker_ip_to_port
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(('', 0))
+        port = s.getsockname()[1]
+    return port
 
 
 def _concat(seq: List[_DaskPart]) -> _DaskPart:
@@ -415,10 +350,9 @@ def _train(
         }
     else:
         _log_info("Finding random open ports for workers")
-        worker_address_to_port = _find_ports_for_workers(
-            client=client,
-            worker_addresses=worker_addresses,
-            local_listen_port=local_listen_port
+        worker_address_to_port = client.run(
+            _find_random_open_port,
+            workers=list(worker_addresses)
         )
     machines = ','.join([
         '%s:%d' % (urlparse(worker_address).hostname, port)
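For context on the replacement in _train above: distributed.Client.run executes a function once on each targeted worker and returns a dict keyed by worker address, which is exactly the {address: port} mapping the machines string is built from. A hedged usage sketch against a local cluster (cluster size and printed values are illustrative; _find_random_open_port is the private helper added in this commit):

from distributed import Client, LocalCluster

import lightgbm as lgb

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

# run the port finder once on every worker; the result maps
# worker address -> open port, for example
# {'tcp://127.0.0.1:42317': 50412, 'tcp://127.0.0.1:44519': 50583}
worker_address_to_port = client.run(lgb.dask._find_random_open_port)
print(worker_address_to_port)

client.close()
cluster.close()

Passing workers=list(worker_addresses), as the hunk does, restricts the call to the workers that actually hold data instead of every worker in the cluster.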
@@ -174,14 +174,6 @@ def _accuracy_score(dy_true, dy_pred):
     return da.average(dy_true == dy_pred).compute()
 
 
-def _find_random_open_port() -> int:
-    """Find a random open port on localhost"""
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(('', 0))
-        port = s.getsockname()[1]
-    return port
-
-
 def _pickle(obj, filepath, serializer):
     if serializer == 'pickle':
         with open(filepath, 'wb') as f:
@@ -343,6 +335,19 @@ def test_classifier_pred_contrib(output, centers, client):
     client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 
 
+def test_find_random_open_port(client):
+    for _ in range(5):
+        worker_address_to_port = client.run(lgb.dask._find_random_open_port)
+        found_ports = worker_address_to_port.values()
+        # check that found ports are different for same address (LocalCluster)
+        assert len(set(found_ports)) == len(found_ports)
+        # check that the ports are indeed open
+        for port in found_ports:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(('', port))
+    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
+
+
 def test_training_does_not_fail_on_port_conflicts(client):
     _, _, _, dX, dy, dw = _create_data('classification', output='array')
 
@@ -885,29 +890,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
     assert_eq(preds_orig_local, preds_loaded_model_local)
 
 
-def test_find_open_port_works(listen_port):
-    worker_ip = '127.0.0.1'
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind((worker_ip, listen_port))
-        new_port = lgb.dask._find_open_port(
-            worker_ip=worker_ip,
-            local_listen_port=listen_port,
-            ports_to_skip=set()
-        )
-        assert listen_port < new_port < listen_port + 1000
-
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_1:
-        s_1.bind((worker_ip, listen_port))
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_2:
-            s_2.bind((worker_ip, listen_port + 1))
-            new_port = lgb.dask._find_open_port(
-                worker_ip=worker_ip,
-                local_listen_port=listen_port,
-                ports_to_skip=set()
-            )
-            assert listen_port + 1 < new_port < listen_port + 1000
-
-
 def test_warns_and_continues_on_unrecognized_tree_learner(client):
     X = da.random.random((1e3, 10))
     y = da.random.random((1e3, 1))
@@ -1075,7 +1057,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
 
     # model 2 - machines given
     n_workers = len(client.scheduler_info()['workers'])
-    open_ports = [_find_random_open_port() for _ in range(n_workers)]
+    open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
     dask_model2 = dask_model_factory(
         n_estimators=5,
         num_leaves=5,
@@ -1143,7 +1125,7 @@ def test_machines_should_be_used_if_provided(task, output):
         client.rebalance()
 
         n_workers = len(client.scheduler_info()['workers'])
-        open_ports = [_find_random_open_port() for _ in range(n_workers)]
+        open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
         dask_model = dask_model_factory(
             n_estimators=5,
             num_leaves=5,
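The two hunks above update tests that pass ports explicitly: they collect one open port per worker and hand LightGBM a machines string instead of letting it search. A hedged sketch of what those tests set up (the estimator and cluster setup here are illustrative; only the machines construction mirrors the test code):

from urllib.parse import urlparse

from distributed import Client, LocalCluster

import lightgbm as lgb

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

# one open port per worker, found with the helper added in this commit;
# on a LocalCluster every worker shares the client's host, so finding the
# ports locally is good enough
worker_addresses = list(client.scheduler_info()['workers'])
open_ports = [lgb.dask._find_random_open_port() for _ in worker_addresses]

# 'machines' is a comma-separated host:port list with one entry per worker,
# e.g. '127.0.0.1:50412,127.0.0.1:50583'
machines = ','.join(
    '%s:%d' % (urlparse(address).hostname, port)
    for address, port in zip(worker_addresses, open_ports)
)

dask_model = lgb.DaskLGBMClassifier(n_estimators=5, num_leaves=5, machines=machines)
# dask_model.fit(dX, dy) would then pin each worker's socket to its assigned port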