Benchmarks: Revise Test - Revise benchmark test util to support pytorch multi-GPU test (#54)
* Superbenchmark: Revise tests - revise benchmark test util to support multi gpu test * modify test_sharding_matmul.py to match the tests util
This commit is contained in:
Родитель
cb33c99ccb
Коммит
5e6897200b
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Module for tests."""
|
|
@ -22,7 +22,7 @@ def test_pytorch_sharding_matmul():
|
|||
|
||||
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
|
||||
|
||||
utils.setup_simulated_ddp_distributed_env()
|
||||
utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port())
|
||||
benchmark = BenchmarkRegistry.launch_benchmark(context)
|
||||
|
||||
# Check basic information.
|
||||
|
|
|
@ -4,15 +4,12 @@
|
|||
"""Utilities for benchmark tests."""
|
||||
|
||||
import os
|
||||
import socket
|
||||
from contextlib import closing
|
||||
import multiprocessing as multiprocessing
|
||||
from multiprocessing import Process
|
||||
|
||||
|
||||
def setup_simulated_ddp_distributed_env(world_size=1, local_rank=0, port=12345):
    """Set up the simulated DDP distributed environment variables.

    Called without arguments, this writes the same values as before
    (WORLD_SIZE=1, RANK=0, LOCAL_RANK=0, MASTER_PORT=12345).

    Args:
        world_size (int): value written to WORLD_SIZE, defaults to 1.
        local_rank (int): value written to both RANK and LOCAL_RANK, defaults to 0.
        port (int): value written to MASTER_PORT, defaults to 12345.
    """
    # os.environ only accepts strings, so every value is converted explicitly.
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(local_rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
|
||||
from superbench.benchmarks import BenchmarkRegistry
|
||||
|
||||
|
||||
def clean_simulated_ddp_distributed_env():
|
||||
|
@ -22,3 +19,57 @@ def clean_simulated_ddp_distributed_env():
|
|||
os.environ.pop('LOCAL_RANK')
|
||||
os.environ.pop('MASTER_ADDR')
|
||||
os.environ.pop('MASTER_PORT')
|
||||
|
||||
|
||||
def get_free_port():
    """Get a free port in current system.

    A TCP socket is bound to port 0 so that the operating system picks an unused
    port. The socket is closed before returning, so the port is only known to be
    free at the time of the call.

    Return:
        port (int): a free port in current system.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR influences how bind() validates the address, so it has to
        # be set before bind() is called to have any effect.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]
|
||||
|
||||
|
||||
def setup_simulated_ddp_distributed_env(world_size=1, local_rank=0, port=12345):
    """Set up the simulated DDP distributed environment variables.

    The defaults match the values that were previously hard-coded, so a call
    without arguments keeps working.

    Args:
        world_size (int): value written to WORLD_SIZE, defaults to 1.
        local_rank (int): value written to both RANK and LOCAL_RANK, defaults to 0.
        port (int): value written to MASTER_PORT, defaults to 12345.
    """
    env = {
        'WORLD_SIZE': world_size,
        'RANK': local_rank,
        'LOCAL_RANK': local_rank,
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': port,
    }
    # os.environ only accepts strings, so every value is converted explicitly.
    for key, value in env.items():
        os.environ[key] = str(value)
|
||||
|
||||
|
||||
def benchmark_in_one_process(context, world_size, local_rank, port, queue):
    """Set up the env for DDP initialization and run the benchmark in a single process.

    Args:
        context: benchmark context handed to BenchmarkRegistry.launch_benchmark.
        world_size (int): number of simulated processes, written to WORLD_SIZE.
        local_rank (int): rank of this process, written to RANK and LOCAL_RANK.
        port (int): port written to MASTER_PORT.
        queue: queue on which the launched benchmark object is put for the caller.
    """
    setup_simulated_ddp_distributed_env(world_size, local_rank, port)
    try:
        benchmark = BenchmarkRegistry.launch_benchmark(context)
        # The parser object must be removed because it cannot be serialized.
        benchmark._parser = None
        queue.put(benchmark)
    finally:
        # Clean the environment even when launching or queueing the benchmark fails.
        clean_simulated_ddp_distributed_env()
|
||||
|
||||
|
||||
def simulated_ddp_distributed_benchmark(context, world_size):
    """Run the benchmark on #world_size number of processes.

    Args:
        context: benchmark context forwarded to every worker process.
        world_size (int): number of worker processes to start.

    Return:
        results (list): list of benchmark results from #world_size number of processes.
    """
    port = get_free_port()

    # Use a local 'spawn' context instead of multiprocessing.set_start_method(),
    # which raises RuntimeError when it is called more than once per interpreter
    # and would therefore break a second call of this helper.
    ctx = multiprocessing.get_context('spawn')
    queue = ctx.Queue()

    process_list = []
    for rank in range(world_size):
        process = ctx.Process(target=benchmark_in_one_process, args=(context, world_size, rank, port, queue))
        process.start()
        process_list.append(process)

    # Drain the queue before joining: a process that has put items on a queue
    # waits for them to be flushed before it terminates, so joining first can
    # deadlock. One blocking get() per worker is issued (the former get(1)
    # passed 1 as the 'block' flag, not as a timeout).
    results = [queue.get() for _ in process_list]

    for process in process_list:
        process.join()

    return results
|
||||
|
|
Загрузка…
Ссылка в новой задаче