Runner: Support `topo-aware` and `k-batch` pattern in 'mpi' mode (#437)
**Description** Support the following patterns in `mpi` mode: * `k-batch` * `topo-aware`
This commit is contained in:
Родитель
fc661f7db3
Коммит
65e433c0c6
|
@ -463,5 +463,10 @@ Pattern variables to run benchmarks with nodes in specified traffic pattern comb
|
|||
Only available for `mpi` mode.
|
||||
|
||||
Available variables in the formatted string include:
|
||||
+ `name`
|
||||
* accepted values: `all-nodes`, `pair-wise`
|
||||
+ `type`: the traffic pattern type, required.
|
||||
* accepted values: `all-nodes`, `pair-wise`, `k-batch`, `topo-aware`
|
||||
+ `batch`: the scale of batch, required in `k-batch` pattern.
|
||||
+ `ibstat`: the path of ibstat output, will be auto-generated as `ibstate_file.txt` under the output directory if not specified, optional in `topo-aware` pattern.
|
||||
+ `ibnetdiscover`: the path of ibnetdiscover output `ibnetdiscover_file.txt`, required in `topo-aware` pattern.
|
||||
+ `min_dist`: minimum distance of VM pair, required in `topo-aware` pattern.
|
||||
+ `max_dist`: maximum distance of VM pair, required in `topo-aware` pattern.
|
||||
|
|
|
@ -10,7 +10,9 @@ from superbench.common.utils.file_handler import rotate_dir, create_sb_output_di
|
|||
from superbench.common.utils.lazy_import import LazyImport
|
||||
from superbench.common.utils.process import run_command
|
||||
from superbench.common.utils.topo_aware import gen_topo_aware_config
|
||||
from superbench.common.utils.gen_traffic_pattern_config import gen_pair_wise_config, gen_traffic_pattern_host_group
|
||||
from superbench.common.utils.gen_traffic_pattern_config import (
|
||||
gen_pair_wise_config, gen_traffic_pattern_host_group, gen_ibstat
|
||||
)
|
||||
|
||||
device_manager = LazyImport('superbench.common.utils.device_manager')
|
||||
|
||||
|
@ -30,4 +32,5 @@ __all__ = [
|
|||
'gen_topo_aware_config',
|
||||
'gen_pair_wise_config',
|
||||
'gen_traffic_pattern_host_group',
|
||||
'gen_ibstat',
|
||||
]
|
||||
|
|
|
@ -2,7 +2,11 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
"""Utilities for traffic pattern config."""
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.common.utils import gen_topo_aware_config
|
||||
|
||||
|
||||
def gen_all_nodes_config(n):
|
||||
|
@ -61,6 +65,35 @@ def gen_pair_wise_config(n):
|
|||
return config
|
||||
|
||||
|
||||
def gen_k_batch_config(n, batch):
    """Partition participants ``0 .. n-1`` into consecutive groups of ``batch`` members.

    Only complete groups are produced; the trailing ``n % batch`` participants
    that cannot fill a whole group are left out of the result.

    Args:
        n (int): the number of participants.
        batch (int): the scale of batch, i.e. the number of participants per group.

    Returns:
        config (list): an empty list when the arguments are rejected (batch is None,
            batch or n is not positive, or batch exceeds n); otherwise a list holding
            a single str such as "0,1;2,3", where ';' separates groups and ','
            separates the members of one group.
    """
    # Guard clauses: reject unusable arguments with a warning and an empty config.
    if batch is None:
        logger.warning('scale is not specified')
        return []
    if batch <= 0 or n <= 0:
        logger.warning('scale or n is not positive')
        return []
    if batch > n:
        logger.warning('scale large than n')
        return []

    # Build one comma-joined string per complete group of `batch` consecutive ids.
    full_group_count = n // batch
    groups = [
        ','.join(str(member) for member in range(index * batch, (index + 1) * batch))
        for index in range(full_group_count)
    ]
    # All groups run in the same round, hence a single ';'-joined entry.
    return [';'.join(groups)]
|
||||
|
||||
|
||||
def __convert_config_to_host_group(config, host_list):
|
||||
"""Convert config format to host node.
|
||||
|
||||
|
@ -84,6 +117,43 @@ def __convert_config_to_host_group(config, host_list):
|
|||
return host_groups
|
||||
|
||||
|
||||
def gen_ibstat(ansible_config, ibstat_path):    # pragma: no cover
    """Collect InfiniBand image GUIDs from the hosts and write them to a file.

    Runs a shell command through Ansible that prints every
    ``/sys/class/infiniband/*/sys_image_guid`` value with the ':' characters removed,
    parses the captured stdout, and writes the cleaned lines to ``ibstat_path``.
    The file is written even when the Ansible run reports a non-zero return code;
    in that case an error is logged first.

    Args:
        ansible_config (DictConfig): Ansible config object.
        ibstat_path (str): the expected path of ibstat file.

    Returns:
        ibstat_path (str): the generated path of ibstat file.
    """
    # Imported here rather than at module level, as in the original implementation
    # (presumably to avoid a circular import with superbench.runner -- TODO confirm).
    from superbench.runner import AnsibleClient

    collected_lines = []
    # Matches terminal escape sequences (ESC '[' ... final byte, or ESC ']' ... BEL / ESC '\').
    escape_pattern = re.compile(r'\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))')
    # Marker found on the header line that precedes a host's command output.
    changed_marker = ' | CHANGED | rc=0 >>'
    client = AnsibleClient(ansible_config)

    def _ibstat_parser(artifact_dir):
        """Read the 'stdout' artifact and append its cleaned lines to collected_lines."""
        with (Path(artifact_dir) / 'stdout').open(mode='r') as stdout_file:
            for line in stdout_file:
                cleaned = escape_pattern.sub('', line).strip()
                if changed_marker in cleaned:
                    # Rewrite the header line as 'VM_hostname <rest of the line>'.
                    cleaned = 'VM_hostname ' + cleaned.replace(changed_marker, '')
                collected_lines.append(cleaned)

    config = client.get_shell_config('cat /sys/class/infiniband/*/sys_image_guid | tr -d :')
    config['artifacts_handler'] = _ibstat_parser
    if client.run(config) != 0:
        logger.error('Failed to gather ibstat with config: {}'.format(config))

    with Path(ibstat_path).open(mode='w') as ibstat_file:
        ibstat_file.writelines(entry + '\n' for entry in collected_lines)
    return ibstat_path
|
||||
|
||||
|
||||
def gen_traffic_pattern_host_group(host_list, pattern):
|
||||
"""Generate host group from specified traffic pattern.
|
||||
|
||||
|
@ -96,11 +166,17 @@ def gen_traffic_pattern_host_group(host_list, pattern):
|
|||
"""
|
||||
config = []
|
||||
n = len(host_list)
|
||||
if pattern.name == 'all-nodes':
|
||||
if pattern.type == 'all-nodes':
|
||||
config = gen_all_nodes_config(n)
|
||||
elif pattern.name == 'pair-wise':
|
||||
elif pattern.type == 'pair-wise':
|
||||
config = gen_pair_wise_config(n)
|
||||
elif pattern.type == 'k-batch':
|
||||
config = gen_k_batch_config(n, pattern.batch)
|
||||
elif pattern.type == 'topo-aware':
|
||||
config = gen_topo_aware_config(
|
||||
host_list, pattern.ibstat, pattern.ibnetdiscover, pattern.min_dist, pattern.max_dist
|
||||
)
|
||||
else:
|
||||
logger.error('Unsupported traffic pattern: {}'.format(pattern.name))
|
||||
logger.error('Unsupported traffic pattern: {}'.format(pattern.type))
|
||||
host_group = __convert_config_to_host_group(config, host_list)
|
||||
return host_group
|
||||
|
|
|
@ -3,6 +3,10 @@
|
|||
|
||||
"""SuperBench runner module."""
|
||||
|
||||
from superbench.runner.ansible import AnsibleClient
|
||||
from superbench.runner.runner import SuperBenchRunner
|
||||
|
||||
__all__ = ['SuperBenchRunner']
|
||||
__all__ = [
|
||||
'AnsibleClient',
|
||||
'SuperBenchRunner',
|
||||
]
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
"""SuperBench Runner."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
@ -14,7 +15,7 @@ from natsort import natsorted
|
|||
from joblib import Parallel, delayed
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from superbench.common.utils import SuperBenchLogger, logger, gen_traffic_pattern_host_group
|
||||
from superbench.common.utils import SuperBenchLogger, logger, gen_ibstat, gen_traffic_pattern_host_group
|
||||
from superbench.runner.ansible import AnsibleClient
|
||||
from superbench.benchmarks import ReduceType, Reducer
|
||||
from superbench.monitor import MonitorRecord
|
||||
|
@ -88,6 +89,11 @@ class SuperBenchRunner():
|
|||
}
|
||||
for key in ['PATH', 'LD_LIBRARY_PATH', 'SB_MICRO_PATH', 'SB_WORKSPACE']:
|
||||
self._sb_benchmarks[name].modes[idx].env.setdefault(key, None)
|
||||
if mode.pattern:
|
||||
if mode.pattern.type == 'topo-aware' and not mode.pattern.ibstat:
|
||||
self._sb_benchmarks[name].modes[idx].pattern.ibstat = gen_ibstat(
|
||||
self._ansible_config, str(self._output_path / 'ibstate_file.txt')
|
||||
)
|
||||
|
||||
def __get_enabled_benchmarks(self):
|
||||
"""Get enabled benchmarks list.
|
||||
|
@ -449,6 +455,9 @@ class SuperBenchRunner():
|
|||
if not mode.pattern:
|
||||
ansible_rc = self._run_proc(benchmark_name, mode, {'proc_rank': 0})
|
||||
else:
|
||||
if not os.path.exists(self._output_path / 'hostfile'):
|
||||
logger.warning('No hostfile under %s.', self._output_path)
|
||||
continue
|
||||
with open(self._output_path / 'hostfile', 'r') as f:
|
||||
host_list = f.read().splitlines()
|
||||
pattern_hostx = gen_traffic_pattern_host_group(host_list, mode.pattern)
|
||||
|
|
|
@ -5,18 +5,20 @@
|
|||
import argparse
|
||||
import unittest
|
||||
|
||||
from tests.helper import decorator
|
||||
from superbench.common.utils import gen_traffic_pattern_host_group
|
||||
|
||||
|
||||
class GenConfigTest(unittest.TestCase):
|
||||
"""Test the utils for generating config."""
|
||||
def test_gen_traffic_pattern_host_group(self):
|
||||
@decorator.load_data('tests/data/ib_traffic_topo_aware_hostfile') # noqa: C901
|
||||
def test_gen_traffic_pattern_host_group(self, tp_hostfile):
|
||||
"""Test the function of generating traffic pattern config from specified mode."""
|
||||
# Test for all-nodes pattern
|
||||
hostx = ['node0', 'node1', 'node2', 'node3', 'node4', 'node5', 'node6', 'node7']
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--name',
|
||||
'--type',
|
||||
type=str,
|
||||
default='all-nodes',
|
||||
)
|
||||
|
@ -27,7 +29,7 @@ class GenConfigTest(unittest.TestCase):
|
|||
# Test for pair-wise pattern
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--name',
|
||||
'--type',
|
||||
type=str,
|
||||
default='pair-wise',
|
||||
)
|
||||
|
@ -42,3 +44,78 @@ class GenConfigTest(unittest.TestCase):
|
|||
[['node0', 'node6'], ['node7', 'node5'], ['node1', 'node4'], ['node2', 'node3']]
|
||||
]
|
||||
self.assertEqual(gen_traffic_pattern_host_group(hostx, pattern), expected_host_group)
|
||||
|
||||
# Test for k-batch pattern
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--type',
|
||||
type=str,
|
||||
default='k-batch',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--batch',
|
||||
type=int,
|
||||
default=3,
|
||||
)
|
||||
pattern, _ = parser.parse_known_args()
|
||||
expected_host_group = [[['node0', 'node1', 'node2'], ['node3', 'node4', 'node5']]]
|
||||
self.assertEqual(gen_traffic_pattern_host_group(hostx, pattern), expected_host_group)
|
||||
|
||||
# Test for topo-aware pattern
|
||||
tp_ibstat_path = 'tests/data/ib_traffic_topo_aware_ibstat.txt'
|
||||
tp_ibnetdiscover_path = 'tests/data/ib_traffic_topo_aware_ibnetdiscover.txt'
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--type',
|
||||
type=str,
|
||||
default='topo-aware',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--ibstat',
|
||||
type=str,
|
||||
default=tp_ibstat_path,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--ibnetdiscover',
|
||||
type=str,
|
||||
default=tp_ibnetdiscover_path,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--min_dist',
|
||||
type=int,
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max_dist',
|
||||
type=int,
|
||||
default=6,
|
||||
)
|
||||
hostx = tp_hostfile.split()
|
||||
pattern, _ = parser.parse_known_args()
|
||||
expected_host_group = [
|
||||
[
|
||||
['vma414bbc00005I', 'vma414bbc00005J'], ['vma414bbc00005K', 'vma414bbc00005L'],
|
||||
['vma414bbc00005M', 'vma414bbc00005N'], ['vma414bbc00005O', 'vma414bbc00005P'],
|
||||
['vma414bbc00005Q', 'vma414bbc00005R']
|
||||
],
|
||||
[
|
||||
['vma414bbc00005I', 'vma414bbc00005K'], ['vma414bbc00005J', 'vma414bbc00005L'],
|
||||
['vma414bbc00005O', 'vma414bbc00005Q'], ['vma414bbc00005P', 'vma414bbc00005R']
|
||||
],
|
||||
[
|
||||
['vma414bbc00005I', 'vma414bbc00005O'], ['vma414bbc00005J', 'vma414bbc00005P'],
|
||||
['vma414bbc00005K', 'vma414bbc00005Q'], ['vma414bbc00005L', 'vma414bbc00005R']
|
||||
]
|
||||
]
|
||||
self.assertEqual(gen_traffic_pattern_host_group(hostx, pattern), expected_host_group)
|
||||
|
||||
# Test for invalid pattern
|
||||
hostx = ['node0', 'node1', 'node2', 'node3', 'node4', 'node5', 'node6', 'node7']
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--type',
|
||||
type=str,
|
||||
default='invalid pattern',
|
||||
)
|
||||
pattern, _ = parser.parse_known_args()
|
||||
gen_traffic_pattern_host_group(hostx, pattern)
|
||||
|
|
|
@ -0,0 +1,297 @@
|
|||
# SuperBench Config
|
||||
version: v0.6
|
||||
superbench:
|
||||
enable: null
|
||||
monitor:
|
||||
enable: true
|
||||
sample_duration: 1
|
||||
sample_interval: 10
|
||||
var:
|
||||
default_local_mode: &default_local_mode
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 8
|
||||
prefix: CUDA_VISIBLE_DEVICES={proc_rank}
|
||||
parallel: yes
|
||||
default_pytorch_mode: &default_pytorch_mode
|
||||
enable: true
|
||||
modes:
|
||||
- name: torch.distributed
|
||||
proc_num: 8
|
||||
node_num: 1
|
||||
frameworks:
|
||||
- pytorch
|
||||
common_model_config: &common_model_config
|
||||
duration: 0
|
||||
num_warmup: 16
|
||||
num_steps: 128
|
||||
batch_size: 1
|
||||
precision:
|
||||
- float32
|
||||
- float16
|
||||
model_action:
|
||||
- train
|
||||
benchmarks:
|
||||
gpu-burn:
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 1
|
||||
parallel: no
|
||||
parameters:
|
||||
time: 300
|
||||
doubles: true
|
||||
tensor_core: true
|
||||
nccl-bw:default:
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 1
|
||||
parallel: no
|
||||
parameters:
|
||||
ngpus: 8
|
||||
nccl-bw:gdr-only:
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 1
|
||||
parallel: no
|
||||
env:
|
||||
NCCL_IB_PCI_RELAXED_ORDERING: '1'
|
||||
NCCL_NET_GDR_LEVEL: '5'
|
||||
NCCL_P2P_DISABLE: '1'
|
||||
NCCL_SHM_DISABLE: '1'
|
||||
NCCL_MIN_NCHANNELS: '16'
|
||||
NCCL_IB_DISABLE: '0'
|
||||
parameters:
|
||||
ngpus: 8
|
||||
nccl-bw:all-nodes:
|
||||
enable: true
|
||||
modes:
|
||||
- name: mpi
|
||||
proc_num: 8
|
||||
node_num: all
|
||||
pattern:
|
||||
type: all-nodes
|
||||
parameters:
|
||||
ngpus: 8
|
||||
nccl-bw:pair-wise:
|
||||
enable: true
|
||||
modes:
|
||||
- name: mpi
|
||||
proc_num: 8
|
||||
node_num: all
|
||||
pattern:
|
||||
type: pair-wise
|
||||
parameters:
|
||||
ngpus: 8
|
||||
nccl-bw:k-batch:
|
||||
enable: true
|
||||
modes:
|
||||
- name: mpi
|
||||
proc_num: 8
|
||||
node_num: all
|
||||
pattern:
|
||||
type: k-batch
|
||||
parameters:
|
||||
ngpus: 8
|
||||
nccl-bw:topo-aware:
|
||||
enable: true
|
||||
modes:
|
||||
- name: mpi
|
||||
proc_num: 8
|
||||
node_num: all
|
||||
pattern:
|
||||
type: topo-aware
|
||||
ibstat: ibstat_file.txt
|
||||
ibnetdiscover: ibnetdiscover_file.txt
|
||||
min_dist: 2
|
||||
max_dist: 6
|
||||
parameters:
|
||||
ngpus: 8
|
||||
ib-loopback:
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 4
|
||||
prefix: PROC_RANK={proc_rank} IB_DEVICES=0,2,4,6 NUMA_NODES=1,0,3,2
|
||||
parallel: yes
|
||||
- name: local
|
||||
proc_num: 4
|
||||
prefix: PROC_RANK={proc_rank} IB_DEVICES=1,3,5,7 NUMA_NODES=1,0,3,2
|
||||
parallel: yes
|
||||
disk-benchmark:
|
||||
enable: false
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 1
|
||||
parallel: no
|
||||
parameters:
|
||||
block_devices:
|
||||
- /dev/nvme0n1
|
||||
cpu-memory-bw-latency:
|
||||
enable: false
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 1
|
||||
parallel: no
|
||||
parameters:
|
||||
tests:
|
||||
- bandwidth_matrix
|
||||
- latency_matrix
|
||||
- max_bandwidth
|
||||
mem-bw:
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
proc_num: 8
|
||||
prefix: CUDA_VISIBLE_DEVICES={proc_rank} numactl -N $(({proc_rank}/2))
|
||||
parallel: no
|
||||
gpu-copy-bw:correctness:
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
parallel: no
|
||||
parameters:
|
||||
mem_type:
|
||||
- htod
|
||||
- dtoh
|
||||
- dtod
|
||||
copy_type:
|
||||
- sm
|
||||
- dma
|
||||
size: 4096
|
||||
num_warm_up: 0
|
||||
num_loops: 1
|
||||
check_data: true
|
||||
gpu-copy-bw:perf:
|
||||
enable: true
|
||||
modes:
|
||||
- name: local
|
||||
parallel: no
|
||||
parameters:
|
||||
mem_type:
|
||||
- htod
|
||||
- dtoh
|
||||
- dtod
|
||||
copy_type:
|
||||
- sm
|
||||
- dma
|
||||
kernel-launch:
|
||||
<<: *default_local_mode
|
||||
gemm-flops:
|
||||
<<: *default_local_mode
|
||||
cudnn-function:
|
||||
<<: *default_local_mode
|
||||
cublas-function:
|
||||
<<: *default_local_mode
|
||||
matmul:
|
||||
<<: *default_local_mode
|
||||
frameworks:
|
||||
- pytorch
|
||||
sharding-matmul:
|
||||
<<: *default_pytorch_mode
|
||||
computation-communication-overlap:
|
||||
<<: *default_pytorch_mode
|
||||
ib-traffic:
|
||||
enable: false
|
||||
modes:
|
||||
- name: mpi
|
||||
proc_num: 8
|
||||
parameters:
|
||||
msg_size: 8388608
|
||||
ib_dev: mlx5_$LOCAL_RANK
|
||||
gpu_dev: $LOCAL_RANK
|
||||
numa_dev: $((LOCAL_RANK/2))
|
||||
gpcnet-network-test:
|
||||
enable: false
|
||||
modes:
|
||||
- name: mpi
|
||||
proc_num: 1
|
||||
mca:
|
||||
pml: ucx
|
||||
btl: ^uct
|
||||
btl_tcp_if_include: eth0
|
||||
env:
|
||||
UCX_NET_DEVICES: mlx5_0:1
|
||||
gpcnet-network-load-test:
|
||||
enable: false
|
||||
modes:
|
||||
- name: mpi
|
||||
proc_num: 1
|
||||
mca:
|
||||
pml: ucx
|
||||
btl: ^uct
|
||||
btl_tcp_if_include: eth0
|
||||
env:
|
||||
UCX_NET_DEVICES: mlx5_0:1
|
||||
tcp-connectivity:
|
||||
enable: false
|
||||
modes:
|
||||
- name: local
|
||||
parallel: no
|
||||
parameters:
|
||||
port: 22
|
||||
ort-inference:
|
||||
<<: *default_local_mode
|
||||
parameters:
|
||||
batch_size: 1
|
||||
tensorrt-inference:
|
||||
<<: *default_local_mode
|
||||
parameters:
|
||||
pytorch_models:
|
||||
- resnet50
|
||||
- resnet101
|
||||
- resnet152
|
||||
- densenet169
|
||||
- densenet201
|
||||
- bert-base
|
||||
- bert-large
|
||||
seq_length: 224
|
||||
batch_size: 1
|
||||
precision: int8
|
||||
gpt_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
- gpt2-small
|
||||
- gpt2-large
|
||||
parameters:
|
||||
<<: *common_model_config
|
||||
bert_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
- bert-base
|
||||
- bert-large
|
||||
parameters:
|
||||
<<: *common_model_config
|
||||
lstm_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
- lstm
|
||||
parameters:
|
||||
<<: *common_model_config
|
||||
resnet_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
- resnet50
|
||||
- resnet101
|
||||
- resnet152
|
||||
parameters:
|
||||
<<: *common_model_config
|
||||
densenet_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
- densenet169
|
||||
- densenet201
|
||||
parameters:
|
||||
<<: *common_model_config
|
||||
vgg_models:
|
||||
<<: *default_pytorch_mode
|
||||
models:
|
||||
- vgg11
|
||||
- vgg13
|
||||
- vgg16
|
||||
- vgg19
|
||||
parameters:
|
||||
<<: *common_model_config
|
|
@ -20,13 +20,13 @@ class RunnerTestCase(unittest.TestCase):
|
|||
"""A class for runner test cases."""
|
||||
def setUp(self):
|
||||
"""Hook method for setting up the test fixture before exercising it."""
|
||||
default_config_file = Path(__file__).parent / '../../superbench/config/default.yaml'
|
||||
with default_config_file.open() as fp:
|
||||
self.default_config = OmegaConf.create(yaml.load(fp, Loader=yaml.SafeLoader))
|
||||
test_config_file = Path(__file__).parent / '../../tests/data/test.yaml'
|
||||
with test_config_file.open() as fp:
|
||||
self.test_config = OmegaConf.create(yaml.load(fp, Loader=yaml.SafeLoader))
|
||||
self.sb_output_dir = tempfile.mkdtemp()
|
||||
|
||||
self.runner = SuperBenchRunner(
|
||||
self.default_config,
|
||||
self.test_config,
|
||||
OmegaConf.create({}),
|
||||
OmegaConf.create({}),
|
||||
self.sb_output_dir,
|
||||
|
@ -205,7 +205,7 @@ class RunnerTestCase(unittest.TestCase):
|
|||
'proc_rank': 1,
|
||||
'mca': {},
|
||||
'pattern': {
|
||||
'name': 'all-nodes',
|
||||
'type': 'all-nodes',
|
||||
},
|
||||
'env': {
|
||||
'PATH': None,
|
||||
|
|
Загрузка…
Ссылка в новой задаче