ml-agents/conftest.py

85 lines
2.8 KiB
Python

"""
This module provides the 'base_port' pytest fixture for mlagents tests.
This is useful because each mlagents environment requires a unique port to communicate over and will fail on port collisions.
Normally this would prevent tests from being run in parallel, but with the help of this fixture we can guarantee every
test gets the ports it needs.
See the base_port function for usage details.
"""
import tempfile
from pathlib import Path
import pytest
from filelock import FileLock
# TODO: Use the base_port fixture in all ml-agents tests so they can all run in parallel.
import mlagents.plugins.trainer_type
# First port handed out when no (valid) allocation state exists yet.
_BASE_PORT = 6005
# Highest valid TCP/UDP port number; ranges past this wrap back to _BASE_PORT.
_MAX_PORT = 65535


# Hook for pytest-xdist ("Test Node is ready to operate").
# https://github.com/pytest-dev/pytest-xdist/blob/master/src/xdist/newhooks.py
def pytest_testnodeready():
    """Delete the shared port-allocation state files left over from previous runs."""
    PortAllocator().setup_once_per_node()


class PortAllocator:
    """
    WARNING: Should only be used within this file.

    Hands out unique port ranges to tests that need ports to test.
    Shares state between parallel test processes on the same machine via a text
    file (holding the next free port) and a lock file, both stored in the system
    temp directory.
    Should only be used through the base_port test fixture.
    """

    def __init__(self):
        self._port_alloc_file_path: Path = (
            Path(tempfile.gettempdir()) / "next_mla_test_port.txt"
        )
        self._port_alloc_lock_path: Path = self._port_alloc_file_path.with_suffix(
            ".lock"
        )
        self.lock = FileLock(str(self._port_alloc_lock_path))

    def reserve_n_ports(self, n: int) -> int:
        """
        Reserve ``n`` consecutive ports and return the first one.

        The next free port is read from the shared allocation file, advanced by
        ``n`` and written back while holding the file lock, so concurrent callers
        receive disjoint ranges. If the file is missing or does not contain an
        integer (e.g. a previous write was interrupted), allocation starts at
        ``_BASE_PORT``. If the requested range would extend past ``_MAX_PORT``,
        allocation restarts from ``_BASE_PORT``.

        :param n: The number of consecutive ports needed.
        :return: The first port of the reserved range.
        """
        with self.lock:
            base_port = self._read_next_port()
            # The counter persists across runs when the xdist reset hook never
            # fires, so guard against handing out invalid port numbers.
            if base_port + n - 1 > _MAX_PORT:
                base_port = _BASE_PORT
            self._port_alloc_file_path.write_text(str(base_port + n))
            return base_port

    def _read_next_port(self) -> int:
        """Return the next free port from the allocation file, or _BASE_PORT if unavailable."""
        try:
            return int(self._port_alloc_file_path.read_text())
        except (FileNotFoundError, ValueError):
            # Missing file: fresh node. ValueError: empty/corrupt file, e.g. the
            # non-atomic write_text was interrupted mid-write.
            return _BASE_PORT

    def setup_once_per_node(self) -> None:
        """
        Clean up state files from previous runs; should only be called once per node.

        Intended to only be called via xdist hooks. Files that are already
        missing are ignored.
        """
        for path in (self._port_alloc_lock_path, self._port_alloc_file_path):
            # EAFP instead of exists()+unlink(): avoids a race if the file vanishes
            # in between. Equivalent to unlink(missing_ok=True), which needs 3.8+.
            try:
                path.unlink()
            except FileNotFoundError:
                pass
@pytest.fixture
def base_port(n_ports: int) -> int:
    """
    Reserve a range of ports for testing (allows parallel testing even with envs).

    The range is reserved through a fresh PortAllocator, which coordinates with
    other test processes via a shared file in the system temp directory.

    Usage:
        @pytest.mark.parametrize("n_ports", [2])
        def test_something(base_port: int) -> None:
            do_something(base_port)
            do_something(base_port + 1)

    :param n_ports: The number of consecutive ports needed. It must be supplied
        by the requesting test, typically via
        ``@pytest.mark.parametrize("n_ports", [...])``.
    :return: The first port of the reserved range; the test may use ports
        ``base_port`` through ``base_port + n_ports - 1``.
    """
    return PortAllocator().reserve_n_ports(n_ports)
@pytest.fixture(scope="session", autouse=True)
def setup_plugin_trainers():
    """
    Call the ml-agents trainer plugin registration once for the whole test session.

    Being a session-scoped autouse fixture, this runs automatically before tests
    without being requested. The call's result is unpacked into two values that
    are both discarded; presumably only the registration side effect matters.
    """
    registration_result = mlagents.plugins.trainer_type.register_trainer_plugins()
    _unused_first, _unused_second = registration_result