From 5e6897200bc5b37ffeb1c32d7496642a8e7f27fe Mon Sep 17 00:00:00 2001 From: Yuting Jiang <37182275+yukirora@users.noreply.github.com> Date: Wed, 14 Apr 2021 16:39:53 +0800 Subject: [PATCH] Benchmarks: Revise Test - Revise benchmark test util to support pytorch multi-GPU test (#54) * Superbenchmark: Revise tests - revise benchmark test util to support multi gpu test * modify test_sharding_matmul.py to match the tests util --- tests/__init__.py | 4 ++ .../micro_benchmarks/test_sharding_matmul.py | 2 +- tests/benchmarks/utils.py | 67 ++++++++++++++++--- 3 files changed, 64 insertions(+), 9 deletions(-) create mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..ddd4a039 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Module for tests.""" diff --git a/tests/benchmarks/micro_benchmarks/test_sharding_matmul.py b/tests/benchmarks/micro_benchmarks/test_sharding_matmul.py index f862a546..c828d7e1 100644 --- a/tests/benchmarks/micro_benchmarks/test_sharding_matmul.py +++ b/tests/benchmarks/micro_benchmarks/test_sharding_matmul.py @@ -22,7 +22,7 @@ def test_pytorch_sharding_matmul(): assert (BenchmarkRegistry.is_benchmark_context_valid(context)) - utils.setup_simulated_ddp_distributed_env() + utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port()) benchmark = BenchmarkRegistry.launch_benchmark(context) # Check basic information. 
diff --git a/tests/benchmarks/utils.py b/tests/benchmarks/utils.py index b2a95e35..3ef897c5 100644 --- a/tests/benchmarks/utils.py +++ b/tests/benchmarks/utils.py @@ -4,15 +4,12 @@ """Utilities for benchmark tests.""" import os +import socket +from contextlib import closing +import multiprocessing as multiprocessing +from multiprocessing import Process - -def setup_simulated_ddp_distributed_env(): - """Function to setup the simulated DDP distributed envionment variables.""" - os.environ['WORLD_SIZE'] = '1' - os.environ['RANK'] = '0' - os.environ['LOCAL_RANK'] = '0' - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12345' +from superbench.benchmarks import BenchmarkRegistry def clean_simulated_ddp_distributed_env(): @@ -22,3 +19,57 @@ def clean_simulated_ddp_distributed_env(): os.environ.pop('LOCAL_RANK') os.environ.pop('MASTER_ADDR') os.environ.pop('MASTER_PORT') + + +def get_free_port(): + """Get a free port in current system. + + Return: + port (int): a free port in current system. + """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def setup_simulated_ddp_distributed_env(world_size, local_rank, port): + """Function to setup the simulated DDP distributed environment variables.""" + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(local_rank) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + + +def benchmark_in_one_process(context, world_size, local_rank, port, queue): + """Function to setup env for DDP initialization and run the benchmark in each single process.""" + setup_simulated_ddp_distributed_env(world_size, local_rank, port) + benchmark = BenchmarkRegistry.launch_benchmark(context) + # parser object must be removed because it can not be serialized. 
+ benchmark._parser = None + queue.put(benchmark) + clean_simulated_ddp_distributed_env() + + +def simulated_ddp_distributed_benchmark(context, world_size): + """Function to run the benchmark on #world_size number of processes. + + Return: + results (list): list of benchmark results from #world_size number of processes. + """ + port = get_free_port() + process_list = [] + multiprocessing.set_start_method('spawn') + + queue = multiprocessing.Queue() + + for rank in range(world_size): + process = Process(target=benchmark_in_one_process, args=(context, world_size, rank, port, queue)) + process.start() + process_list.append(process) + + for process in process_list: + process.join() + results = [queue.get(1) for p in process_list] + return results