Benchmarks: Revise Test - Revise benchmark test util to support PyTorch multi-GPU test (#54)

* Superbenchmark: Revise tests - revise benchmark test util to support multi-GPU test

* modify test_sharding_matmul.py to match the test utils
Yuting Jiang 2021-04-14 16:39:53 +08:00 committed by GitHub
Parent cb33c99ccb
Commit 5e6897200b
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 64 additions and 9 deletions

tests/__init__.py (new file, 4 additions)

@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Module for tests."""

test_sharding_matmul.py

@@ -22,7 +22,7 @@ def test_pytorch_sharding_matmul():
     assert (BenchmarkRegistry.is_benchmark_context_valid(context))
-    utils.setup_simulated_ddp_distributed_env()
+    utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port())
     benchmark = BenchmarkRegistry.launch_benchmark(context)
     # Check basic information.
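For context, a minimal sketch (not part of the commit) of how the environment variables set by setup_simulated_ddp_distributed_env are consumed: PyTorch's env:// init method reads WORLD_SIZE, RANK, MASTER_ADDR, and MASTER_PORT when the benchmark initializes its process group, which is why the test must set them before launching the benchmark. The gloo backend and the hard-coded port below are assumptions made purely for illustration.

    import os
    import torch.distributed as dist

    # Simulate what setup_simulated_ddp_distributed_env(1, 0, port) does for a single rank.
    os.environ['WORLD_SIZE'] = '1'
    os.environ['RANK'] = '0'
    os.environ['LOCAL_RANK'] = '0'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'  # any free port works; the commit obtains one via get_free_port()

    # env:// reads the variables above; gloo is used here so no GPU is required.
    dist.init_process_group(backend='gloo', init_method='env://')
    assert dist.get_world_size() == 1 and dist.get_rank() == 0
    dist.destroy_process_group()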

utils.py (benchmark test utilities)

@@ -4,15 +4,12 @@
 """Utilities for benchmark tests."""
 import os
+import socket
+from contextlib import closing
+import multiprocessing as multiprocessing
+from multiprocessing import Process
-def setup_simulated_ddp_distributed_env():
-    """Function to setup the simulated DDP distributed envionment variables."""
-    os.environ['WORLD_SIZE'] = '1'
-    os.environ['RANK'] = '0'
-    os.environ['LOCAL_RANK'] = '0'
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12345'
+from superbench.benchmarks import BenchmarkRegistry
 def clean_simulated_ddp_distributed_env():
@@ -22,3 +19,57 @@ def clean_simulated_ddp_distributed_env():
     os.environ.pop('LOCAL_RANK')
     os.environ.pop('MASTER_ADDR')
     os.environ.pop('MASTER_PORT')
+
+
+def get_free_port():
+    """Get a free port on the current system.
+
+    Return:
+        port (int): a free port on the current system.
+    """
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.bind(('', 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
+
+
+def setup_simulated_ddp_distributed_env(world_size, local_rank, port):
+    """Function to set up the simulated DDP distributed environment variables."""
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['RANK'] = str(local_rank)
+    os.environ['LOCAL_RANK'] = str(local_rank)
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = str(port)
+
+
+def benchmark_in_one_process(context, world_size, local_rank, port, queue):
+    """Function to set up the env for DDP initialization and run the benchmark in a single process."""
+    setup_simulated_ddp_distributed_env(world_size, local_rank, port)
+    benchmark = BenchmarkRegistry.launch_benchmark(context)
+    # The parser object must be removed because it cannot be serialized across processes.
+    benchmark._parser = None
+    queue.put(benchmark)
+    clean_simulated_ddp_distributed_env()
+
+
+def simulated_ddp_distributed_benchmark(context, world_size):
+    """Function to run the benchmark on world_size processes.
+
+    Return:
+        results (list): list of benchmark results, one from each of the world_size processes.
+    """
+    port = get_free_port()
+    process_list = []
+    # Use the 'spawn' start method so each child gets a fresh interpreter (required for CUDA in subprocesses).
+    multiprocessing.set_start_method('spawn')
+    queue = multiprocessing.Queue()
+    for rank in range(world_size):
+        process = Process(target=benchmark_in_one_process, args=(context, world_size, rank, port, queue))
+        process.start()
+        process_list.append(process)
+    for process in process_list:
+        process.join()
+    results = [queue.get() for _ in process_list]
+    return results
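A hypothetical usage sketch (not part of this commit) of the new helper in a multi-GPU style test; the import path, test name, and assertions are assumptions, since the real tests build the benchmark context through BenchmarkRegistry before calling the helper.

    from tests.benchmarks import utils  # assumed import path for the utilities above

    def test_pytorch_benchmark_on_two_ranks(context):  # hypothetical test; context comes from BenchmarkRegistry
        # Launch the benchmark in two spawned processes that simulate a 2-GPU DDP job.
        results = utils.simulated_ddp_distributed_benchmark(context, world_size=2)
        assert len(results) == 2
        for benchmark in results:
            assert benchmark is not None

Note that multiprocessing.set_start_method() raises a RuntimeError if it is called more than once in the same interpreter, so a test module that invokes simulated_ddp_distributed_benchmark repeatedly may need to guard that call (for example with set_start_method('spawn', force=True) or by using multiprocessing.get_context('spawn')).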