[dask] use random ports in network setup (#3823)

* use socket.bind with port 0 and client.run to find random open ports

* include test for found ports

* find random open ports as default

* parametrize local_listen_port. type hint to _find_random_open_port. fid open ports only on workers with data.

* make indentation consistent and pass list of workers to client.run

* remove socket import

* change random port implementation

* fix test
This commit is contained in:
jmoralez 2021-02-23 22:14:12 -06:00 коммит произвёл GitHub
Родитель 7777852a19
Коммит 0e57657585
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 25 добавлений и 109 удалений

Просмотреть файл

@ -45,83 +45,18 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port.
This function tries to find a free port on the machine it's run on. It is intended to
be run once on each Dask worker, sequentially.
Parameters
----------
worker_ip : str
IP address for the Dask worker.
local_listen_port : int
First port to try when searching for open ports.
ports_to_skip: Iterable[int]
An iterable of integers referring to ports that should be skipped. Since multiple Dask
workers can run on the same physical machine, this method may be called multiple times
on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
the same port for two worker processes running on the same machine.
def _find_random_open_port() -> int:
"""Find a random open port on localhost.
Returns
-------
port : int
A free port on the machine referenced by ``worker_ip``.
A free port on localhost
"""
max_tries = 1000
found_port = False
for i in range(max_tries):
out_port = local_listen_port + i
if out_port in ports_to_skip:
continue
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, out_port))
found_port = True
break
# if unavailable, you'll get OSError: Address already in use
except OSError:
continue
if not found_port:
msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
return out_port
def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
"""Find an open port on each worker.
LightGBM distributed training uses TCP sockets by default, and this method is used to
identify open ports on each worker so LightGBM can reliable create those sockets.
Parameters
----------
client : dask.distributed.Client
Dask client.
worker_addresses : Iterable[str]
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``.
local_listen_port : int
First port to try when searching for open ports.
Returns
-------
result : Dict[str, int]
Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
"""
lightgbm_ports: Set[int] = set()
worker_ip_to_port = {}
for worker_address in worker_addresses:
port = client.submit(
func=_find_open_port,
workers=[worker_address],
worker_ip=urlparse(worker_address).hostname,
local_listen_port=local_listen_port,
ports_to_skip=lightgbm_ports
).result()
lightgbm_ports.add(port)
worker_ip_to_port[worker_address] = port
return worker_ip_to_port
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
port = s.getsockname()[1]
return port
def _concat(seq: List[_DaskPart]) -> _DaskPart:
@ -415,10 +350,9 @@ def _train(
}
else:
_log_info("Finding random open ports for workers")
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_addresses,
local_listen_port=local_listen_port
worker_address_to_port = client.run(
_find_random_open_port,
workers=list(worker_addresses)
)
machines = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)

Просмотреть файл

@ -174,14 +174,6 @@ def _accuracy_score(dy_true, dy_pred):
return da.average(dy_true == dy_pred).compute()
def _find_random_open_port() -> int:
"""Find a random open port on localhost"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
port = s.getsockname()[1]
return port
def _pickle(obj, filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'wb') as f:
@ -343,6 +335,19 @@ def test_classifier_pred_contrib(output, centers, client):
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def test_find_random_open_port(client):
for _ in range(5):
worker_address_to_port = client.run(lgb.dask._find_random_open_port)
found_ports = worker_address_to_port.values()
# check that found ports are different for same address (LocalCluster)
assert len(set(found_ports)) == len(found_ports)
# check that the ports are indeed open
for port in found_ports:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', port))
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('classification', output='array')
@ -885,29 +890,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert_eq(preds_orig_local, preds_loaded_model_local)
def test_find_open_port_works(listen_port):
worker_ip = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, listen_port))
new_port = lgb.dask._find_open_port(
worker_ip=worker_ip,
local_listen_port=listen_port,
ports_to_skip=set()
)
assert listen_port < new_port < listen_port + 1000
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_1:
s_1.bind((worker_ip, listen_port))
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_2:
s_2.bind((worker_ip, listen_port + 1))
new_port = lgb.dask._find_open_port(
worker_ip=worker_ip,
local_listen_port=listen_port,
ports_to_skip=set()
)
assert listen_port + 1 < new_port < listen_port + 1000
def test_warns_and_continues_on_unrecognized_tree_learner(client):
X = da.random.random((1e3, 10))
y = da.random.random((1e3, 1))
@ -1075,7 +1057,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
# model 2 - machines given
n_workers = len(client.scheduler_info()['workers'])
open_ports = [_find_random_open_port() for _ in range(n_workers)]
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
@ -1143,7 +1125,7 @@ def test_machines_should_be_used_if_provided(task, output):
client.rebalance()
n_workers = len(client.scheduler_info()['workers'])
open_ports = [_find_random_open_port() for _ in range(n_workers)]
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model = dask_model_factory(
n_estimators=5,
num_leaves=5,