Benchmarks: Add Benchmark - Add memory bandwidth benchmark for cuda. (#114)

Add microbenchmark, example, test, config for cuda memory performance and Add cuda-samples(tag with cuda version) as git submodule and update related makefile
This commit is contained in:
Yuting Jiang 2021-07-13 17:30:19 +08:00 коммит произвёл GitHub
Родитель 71c1617b2e
Коммит f9550bd693
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 530 добавлений и 3 удалений

3
.gitmodules поставляемый
Просмотреть файл

@ -2,3 +2,6 @@
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
branch = v2.4.0
[submodule "third_party/cuda-samples"]
path = third_party/cuda-samples
url = https://github.com/NVIDIA/cuda-samples.git

Просмотреть файл

@ -0,0 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Micro benchmark example for device memory bandwidth performance.
Commands to run:
python3 examples/benchmarks/cuda_memory_bw_performance.py
"""
from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger
if __name__ == '__main__':
context = BenchmarkRegistry.create_benchmark_context('mem-bw', platform=Platform.CUDA)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)

Просмотреть файл

@ -10,8 +10,9 @@ from superbench.benchmarks.micro_benchmarks.kernel_launch_overhead import Kernel
from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
from superbench.benchmarks.micro_benchmarks.cudnn_function import CudnnBenchmark
from superbench.benchmarks.micro_benchmarks.gemm_flops_performance import GemmFlopsCuda
from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark
__all__ = [
'MicroBenchmark', 'MicroBenchmarkWithInvoke', 'ShardingMatmul', 'ComputationCommunicationOverlap', 'KernelLaunch',
'CublasBenchmark', 'CudnnBenchmark', 'GemmFlopsCuda'
'CublasBenchmark', 'CudnnBenchmark', 'GemmFlopsCuda', 'CudaMemBwBenchmark'
]

Просмотреть файл

@ -0,0 +1,145 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the Cuda memory performance benchmarks."""
import os
import re
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
class CudaMemBwBenchmark(MicroBenchmarkWithInvoke):
"""The Cuda memory performance benchmark class."""
def __init__(self, name, parameters=''):
"""Constructor.
Args:
name (str): benchmark name.
parameters (str): benchmark parameters.
"""
super().__init__(name, parameters)
self._bin_name = 'bandwidthTest'
self.__mem_types = ['htod', 'dtoh', 'dtod']
self.__memory = ['pageable', 'pinned']
def add_parser_arguments(self):
"""Add the specified arguments."""
super().add_parser_arguments()
self._parser.add_argument(
'--mem_type',
type=str,
nargs='+',
default=self.__mem_types,
help='Memory types to benchmark. E.g. {}.'.format(' '.join(self.__mem_types)),
)
self._parser.add_argument(
'--shmoo_mode',
action='store_true',
default=False,
help='Enable shmoo mode for bandwidthtest.',
)
self._parser.add_argument(
'--memory',
type=str,
default=None,
help='Memory argument for bandwidthtest. E.g. {}.'.format(' '.join(self.__memory)),
)
def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.
Return:
True if _preprocess() succeed.
"""
if not super()._preprocess():
return False
# Format the arguments
if not isinstance(self._args.mem_type, list):
self._args.mem_type = [self._args.mem_type]
self._args.mem_type = [p.lower() for p in self._args.mem_type]
# Check the arguments and generate the commands
for mem_type in self._args.mem_type:
if mem_type not in self.__mem_types:
self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
logger.error(
'Unsupported mem_type of bandwidth test - benchmark: {}, mem_type: {}, expected: {}.'.format(
self._name, mem_type, ' '.join(self.__mem_types)
)
)
return False
else:
command = os.path.join(self._args.bin_dir, self._bin_name)
command += ' --' + mem_type
if self._args.shmoo_mode:
command += ' mode=shmoo'
if self._args.memory:
if self._args.memory in self.__memory:
command += ' memory=' + self._args.memory
else:
self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
logger.error(
'Unsupported memory argument of bandwidth test - benchmark: {}, memory: {}, expected: {}.'.
format(self._name, self._args.memory, ' '.join(self.__memory))
)
return False
command += ' --csv'
self._commands.append(command)
return True
def _process_raw_result(self, cmd_idx, raw_output):
"""Function to parse raw results and save the summarized results.
self._result.add_raw_data() and self._result.add_result() need to be called to save the results.
Args:
cmd_idx (int): the index of command corresponding with the raw_output.
raw_output (str): raw output string of the micro-benchmark.
Return:
True if the raw output string is valid and result can be extracted.
"""
self._result.add_raw_data('raw_output_' + self._args.mem_type[cmd_idx], raw_output)
mem_bw = -1
metric = ''
valid = True
content = raw_output.splitlines()
try:
for index, line in enumerate(content):
if 'H2D' in line:
metric = 'H2D_Mem_BW'
elif 'D2H' in line:
metric = 'D2H_Mem_BW'
elif 'D2D' in line:
metric = 'D2D_Mem_BW'
else:
continue
line = line.split(',')[1]
value = re.search(r'(\d+.\d+)', line)
if value:
mem_bw = max(mem_bw, float(value.group(0)))
except BaseException:
valid = False
finally:
if valid is False or mem_bw == -1:
logger.error(
'The result format is invalid - round: {}, benchmark: {}, raw output: {}.'.format(
self._curr_run_index, self._name, raw_output
)
)
return False
self._result.add_result(metric, mem_bw)
return True
BenchmarkRegistry.register_benchmark('mem-bw', CudaMemBwBenchmark, platform=Platform.CUDA)

Просмотреть файл

@ -28,6 +28,13 @@ superbench:
model_action:
- train
benchmarks:
mem-bw:
enable: true
modes:
- name: local
proc_num: 8
prefix: CUDA_VISIBLE_DEVICES={proc_rank} numactl -c $(({proc_rank}/2))
parallel: yes
kernel-launch:
<<: *default_local_mode
gemm-flops:

Просмотреть файл

@ -0,0 +1,335 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for mem-bw benchmark."""
import numbers
from pathlib import Path
import os
import unittest
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
class CudaMemBwTest(unittest.TestCase):
"""Test class for cuda mem-bw benchmark."""
def setUp(self):
"""Method called to prepare the test fixture."""
# Create fake binary file just for testing.
os.environ['SB_MICRO_PATH'] = '/tmp/superbench/'
binary_path = os.path.join(os.getenv('SB_MICRO_PATH'), 'bin')
Path(os.getenv('SB_MICRO_PATH'), 'bin').mkdir(parents=True, exist_ok=True)
self.__binary_file = Path(binary_path, 'bandwidthTest')
self.__binary_file.touch(mode=0o755, exist_ok=True)
def tearDown(self):
"""Method called after the test method has been called and the result recorded."""
self.__binary_file.unlink()
def test_cuda_memory_bw_performance(self):
"""Test cuda mem-bw benchmark."""
benchmark_name = 'mem-bw'
(benchmark_class,
predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA)
assert (benchmark_class)
benchmark = benchmark_class(benchmark_name, parameters='--shmoo_mode --memory=pinned')
ret = benchmark._preprocess()
assert (ret is True)
assert (benchmark.return_code == ReturnCode.SUCCESS)
# Check basic information.
assert (benchmark)
assert (benchmark.name == 'mem-bw')
assert (benchmark.type == BenchmarkType.MICRO)
# Check command list
expected_command = [
'bandwidthTest --htod mode=shmoo memory=pinned --csv',
'bandwidthTest --dtoh mode=shmoo memory=pinned --csv', 'bandwidthTest --dtod mode=shmoo memory=pinned --csv'
]
for i in range(len(expected_command)):
commnad = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1]
assert (commnad == expected_command[i])
# Check results and metrics.
raw_output = {}
raw_output[0] = """
[CUDA Bandwidth Test] - Starting...
Running on...
Device 0: Tesla V100-PCIE-32GB
Shmoo Mode
.................................................................................
bandwidthTest-H2D-Pinned, Bandwidth = 0.4 GB/s, Time = 0.00000 s, Size = 1000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 0.7 GB/s, Time = 0.00000 s, Size = 2000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 1.0 GB/s, Time = 0.00000 s, Size = 3000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 1.4 GB/s, Time = 0.00000 s, Size = 4000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 1.7 GB/s, Time = 0.00000 s, Size = 5000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 2.0 GB/s, Time = 0.00000 s, Size = 6000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 2.3 GB/s, Time = 0.00000 s, Size = 7000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 2.5 GB/s, Time = 0.00000 s, Size = 8000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 2.7 GB/s, Time = 0.00000 s, Size = 9000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 2.9 GB/s, Time = 0.00000 s, Size = 10000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 3.2 GB/s, Time = 0.00000 s, Size = 11000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 3.4 GB/s, Time = 0.00000 s, Size = 12000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 3.5 GB/s, Time = 0.00000 s, Size = 13000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 3.5 GB/s, Time = 0.00000 s, Size = 14000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 3.8 GB/s, Time = 0.00000 s, Size = 15000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 4.0 GB/s, Time = 0.00000 s, Size = 16000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 4.1 GB/s, Time = 0.00000 s, Size = 17000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 4.3 GB/s, Time = 0.00000 s, Size = 18000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 4.4 GB/s, Time = 0.00000 s, Size = 19000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 4.6 GB/s, Time = 0.00000 s, Size = 20000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 4.8 GB/s, Time = 0.00000 s, Size = 22000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 5.0 GB/s, Time = 0.00000 s, Size = 24000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 5.2 GB/s, Time = 0.00000 s, Size = 26000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 5.4 GB/s, Time = 0.00001 s, Size = 28000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 5.7 GB/s, Time = 0.00001 s, Size = 30000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 5.9 GB/s, Time = 0.00001 s, Size = 32000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 6.1 GB/s, Time = 0.00001 s, Size = 34000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 6.3 GB/s, Time = 0.00001 s, Size = 36000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 6.4 GB/s, Time = 0.00001 s, Size = 38000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 6.6 GB/s, Time = 0.00001 s, Size = 40000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 6.7 GB/s, Time = 0.00001 s, Size = 42000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 6.9 GB/s, Time = 0.00001 s, Size = 44000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 7.0 GB/s, Time = 0.00001 s, Size = 46000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 7.1 GB/s, Time = 0.00001 s, Size = 48000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 7.3 GB/s, Time = 0.00001 s, Size = 50000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 7.8 GB/s, Time = 0.00001 s, Size = 60000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 8.2 GB/s, Time = 0.00001 s, Size = 70000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 8.6 GB/s, Time = 0.00001 s, Size = 80000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 8.9 GB/s, Time = 0.00001 s, Size = 90000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 9.2 GB/s, Time = 0.00001 s, Size = 100000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 10.5 GB/s, Time = 0.00002 s, Size = 200000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.1 GB/s, Time = 0.00003 s, Size = 300000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.4 GB/s, Time = 0.00004 s, Size = 400000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.6 GB/s, Time = 0.00004 s, Size = 500000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.7 GB/s, Time = 0.00005 s, Size = 600000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.8 GB/s, Time = 0.00006 s, Size = 700000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.9 GB/s, Time = 0.00007 s, Size = 800000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.9 GB/s, Time = 0.00008 s, Size = 900000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.7 GB/s, Time = 0.00009 s, Size = 1000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.1 GB/s, Time = 0.00016 s, Size = 2000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.3 GB/s, Time = 0.00024 s, Size = 3000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.3 GB/s, Time = 0.00033 s, Size = 4000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.5 GB/s, Time = 0.00043 s, Size = 5000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.3 GB/s, Time = 0.00049 s, Size = 6000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.3 GB/s, Time = 0.00057 s, Size = 7000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.3 GB/s, Time = 0.00065 s, Size = 8000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.3 GB/s, Time = 0.00073 s, Size = 9000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00081 s, Size = 10000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00089 s, Size = 11000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00097 s, Size = 12000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00105 s, Size = 13000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00113 s, Size = 14000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00121 s, Size = 15000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00129 s, Size = 16000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00145 s, Size = 18000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00162 s, Size = 20000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00178 s, Size = 22000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00194 s, Size = 24000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00210 s, Size = 26000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00226 s, Size = 28000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00242 s, Size = 30000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 10.5 GB/s, Time = 0.00304 s, Size = 32000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.2 GB/s, Time = 0.00295 s, Size = 36000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 10.8 GB/s, Time = 0.00369 s, Size = 40000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00355 s, Size = 44000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00387 s, Size = 48000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.1 GB/s, Time = 0.00431 s, Size = 52000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 11.7 GB/s, Time = 0.00480 s, Size = 56000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00484 s, Size = 60000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.1 GB/s, Time = 0.00528 s, Size = 64000000 bytes, NumDevsUsed = 1
bandwidthTest-H2D-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00549 s, Size = 68000000 bytes, NumDevsUsed = 1
Result = PASS
"""
raw_output[1] = """
[CUDA Bandwidth Test] - Starting...
Running on...
Device 0: Tesla V100-PCIE-32GB
Shmoo Mode
.................................................................................
bandwidthTest-D2H-Pinned, Bandwidth = 0.4 GB/s, Time = 0.00000 s, Size = 1000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 0.5 GB/s, Time = 0.00000 s, Size = 2000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 0.9 GB/s, Time = 0.00000 s, Size = 3000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 1.1 GB/s, Time = 0.00000 s, Size = 4000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 1.4 GB/s, Time = 0.00000 s, Size = 5000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 1.9 GB/s, Time = 0.00000 s, Size = 6000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 2.6 GB/s, Time = 0.00000 s, Size = 7000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 2.9 GB/s, Time = 0.00000 s, Size = 8000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 3.3 GB/s, Time = 0.00000 s, Size = 9000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 3.7 GB/s, Time = 0.00000 s, Size = 10000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 4.0 GB/s, Time = 0.00000 s, Size = 11000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 4.5 GB/s, Time = 0.00000 s, Size = 12000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 4.9 GB/s, Time = 0.00000 s, Size = 13000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 5.3 GB/s, Time = 0.00000 s, Size = 14000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 5.3 GB/s, Time = 0.00000 s, Size = 15000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 5.6 GB/s, Time = 0.00000 s, Size = 16000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 5.7 GB/s, Time = 0.00000 s, Size = 17000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 6.0 GB/s, Time = 0.00000 s, Size = 18000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 6.2 GB/s, Time = 0.00000 s, Size = 19000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 6.3 GB/s, Time = 0.00000 s, Size = 20000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 6.5 GB/s, Time = 0.00000 s, Size = 22000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 6.9 GB/s, Time = 0.00000 s, Size = 24000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 7.1 GB/s, Time = 0.00000 s, Size = 26000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 7.4 GB/s, Time = 0.00000 s, Size = 28000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 7.6 GB/s, Time = 0.00000 s, Size = 30000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 7.9 GB/s, Time = 0.00000 s, Size = 32000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 8.0 GB/s, Time = 0.00000 s, Size = 34000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 8.3 GB/s, Time = 0.00000 s, Size = 36000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 8.5 GB/s, Time = 0.00000 s, Size = 38000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 8.6 GB/s, Time = 0.00000 s, Size = 40000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 8.7 GB/s, Time = 0.00000 s, Size = 42000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 9.3 GB/s, Time = 0.00000 s, Size = 44000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 9.4 GB/s, Time = 0.00000 s, Size = 46000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 9.5 GB/s, Time = 0.00001 s, Size = 48000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 9.5 GB/s, Time = 0.00001 s, Size = 50000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 10.1 GB/s, Time = 0.00001 s, Size = 60000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 10.4 GB/s, Time = 0.00001 s, Size = 70000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 10.6 GB/s, Time = 0.00001 s, Size = 80000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 10.9 GB/s, Time = 0.00001 s, Size = 90000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 11.1 GB/s, Time = 0.00001 s, Size = 100000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.0 GB/s, Time = 0.00002 s, Size = 200000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.4 GB/s, Time = 0.00002 s, Size = 300000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.6 GB/s, Time = 0.00003 s, Size = 400000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.6 GB/s, Time = 0.00004 s, Size = 500000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.7 GB/s, Time = 0.00005 s, Size = 600000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.7 GB/s, Time = 0.00006 s, Size = 700000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.8 GB/s, Time = 0.00006 s, Size = 800000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.9 GB/s, Time = 0.00007 s, Size = 900000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.8 GB/s, Time = 0.00008 s, Size = 1000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.0 GB/s, Time = 0.00015 s, Size = 2000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.0 GB/s, Time = 0.00023 s, Size = 3000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00031 s, Size = 4000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00038 s, Size = 5000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00046 s, Size = 6000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00053 s, Size = 7000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00061 s, Size = 8000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.5 GB/s, Time = 0.00072 s, Size = 9000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00076 s, Size = 10000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00084 s, Size = 11000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00091 s, Size = 12000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00099 s, Size = 13000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00106 s, Size = 14000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00114 s, Size = 15000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00122 s, Size = 16000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00137 s, Size = 18000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00152 s, Size = 20000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00167 s, Size = 22000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00183 s, Size = 24000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 12.9 GB/s, Time = 0.00202 s, Size = 26000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00213 s, Size = 28000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00228 s, Size = 30000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00243 s, Size = 32000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00273 s, Size = 36000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00304 s, Size = 40000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00334 s, Size = 44000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00364 s, Size = 48000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00395 s, Size = 52000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00425 s, Size = 56000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.2 GB/s, Time = 0.00455 s, Size = 60000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00487 s, Size = 64000000 bytes, NumDevsUsed = 1
bandwidthTest-D2H-Pinned, Bandwidth = 13.1 GB/s, Time = 0.00520 s, Size = 68000000 bytes, NumDevsUsed = 1
Result = PASS
"""
raw_output[2] = """
[CUDA Bandwidth Test] - Starting...
Running on...
Device 0: Tesla V100-PCIE-32GB
Shmoo Mode
.................................................................................
bandwidthTest-D2D, Bandwidth = 0.4 GB/s, Time = 0.00000 s, Size = 1000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 0.1 GB/s, Time = 0.00004 s, Size = 2000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 0.8 GB/s, Time = 0.00000 s, Size = 3000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 1.2 GB/s, Time = 0.00000 s, Size = 4000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 0.4 GB/s, Time = 0.00001 s, Size = 5000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 1.7 GB/s, Time = 0.00000 s, Size = 6000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 7.0 GB/s, Time = 0.00000 s, Size = 7000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 8.0 GB/s, Time = 0.00000 s, Size = 8000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 9.0 GB/s, Time = 0.00000 s, Size = 9000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 10.0 GB/s, Time = 0.00000 s, Size = 10000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 6.1 GB/s, Time = 0.00000 s, Size = 11000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 12.0 GB/s, Time = 0.00000 s, Size = 12000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 13.1 GB/s, Time = 0.00000 s, Size = 13000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 5.3 GB/s, Time = 0.00000 s, Size = 14000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 8.0 GB/s, Time = 0.00000 s, Size = 15000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 8.9 GB/s, Time = 0.00000 s, Size = 16000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 9.5 GB/s, Time = 0.00000 s, Size = 17000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 9.8 GB/s, Time = 0.00000 s, Size = 18000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 19.0 GB/s, Time = 0.00000 s, Size = 19000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 5.3 GB/s, Time = 0.00000 s, Size = 20000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 22.0 GB/s, Time = 0.00000 s, Size = 22000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 6.3 GB/s, Time = 0.00000 s, Size = 24000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 0.7 GB/s, Time = 0.00004 s, Size = 26000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 28.1 GB/s, Time = 0.00000 s, Size = 28000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 30.1 GB/s, Time = 0.00000 s, Size = 30000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 32.0 GB/s, Time = 0.00000 s, Size = 32000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 14.6 GB/s, Time = 0.00000 s, Size = 34000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 20.9 GB/s, Time = 0.00000 s, Size = 36000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 22.7 GB/s, Time = 0.00000 s, Size = 38000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 23.5 GB/s, Time = 0.00000 s, Size = 40000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 24.8 GB/s, Time = 0.00000 s, Size = 42000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 44.1 GB/s, Time = 0.00000 s, Size = 44000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 27.2 GB/s, Time = 0.00000 s, Size = 46000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 48.0 GB/s, Time = 0.00000 s, Size = 48000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 28.5 GB/s, Time = 0.00000 s, Size = 50000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 60.2 GB/s, Time = 0.00000 s, Size = 60000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 42.7 GB/s, Time = 0.00000 s, Size = 70000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 8.4 GB/s, Time = 0.00001 s, Size = 80000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 55.6 GB/s, Time = 0.00000 s, Size = 90000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 59.6 GB/s, Time = 0.00000 s, Size = 100000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 127.9 GB/s, Time = 0.00000 s, Size = 200000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 183.1 GB/s, Time = 0.00000 s, Size = 300000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 270.2 GB/s, Time = 0.00000 s, Size = 400000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 15.5 GB/s, Time = 0.00003 s, Size = 500000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 399.2 GB/s, Time = 0.00000 s, Size = 600000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 172.1 GB/s, Time = 0.00000 s, Size = 700000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 27.5 GB/s, Time = 0.00003 s, Size = 800000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 71.3 GB/s, Time = 0.00001 s, Size = 900000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 502.2 GB/s, Time = 0.00000 s, Size = 1000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 59.4 GB/s, Time = 0.00003 s, Size = 2000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 348.7 GB/s, Time = 0.00001 s, Size = 3000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 519.4 GB/s, Time = 0.00001 s, Size = 4000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 422.3 GB/s, Time = 0.00001 s, Size = 5000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 447.9 GB/s, Time = 0.00001 s, Size = 6000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 225.3 GB/s, Time = 0.00003 s, Size = 7000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 146.0 GB/s, Time = 0.00005 s, Size = 8000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 190.9 GB/s, Time = 0.00005 s, Size = 9000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 301.1 GB/s, Time = 0.00003 s, Size = 10000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 192.8 GB/s, Time = 0.00006 s, Size = 11000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 243.9 GB/s, Time = 0.00005 s, Size = 12000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 328.7 GB/s, Time = 0.00004 s, Size = 13000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 621.2 GB/s, Time = 0.00002 s, Size = 14000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 682.5 GB/s, Time = 0.00002 s, Size = 15000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 686.3 GB/s, Time = 0.00002 s, Size = 16000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 693.1 GB/s, Time = 0.00003 s, Size = 18000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 707.0 GB/s, Time = 0.00003 s, Size = 20000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 714.4 GB/s, Time = 0.00003 s, Size = 22000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 719.4 GB/s, Time = 0.00003 s, Size = 24000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 723.2 GB/s, Time = 0.00004 s, Size = 26000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 726.7 GB/s, Time = 0.00004 s, Size = 28000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 728.8 GB/s, Time = 0.00004 s, Size = 30000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 724.2 GB/s, Time = 0.00004 s, Size = 32000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 735.3 GB/s, Time = 0.00005 s, Size = 36000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 741.1 GB/s, Time = 0.00005 s, Size = 40000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 748.9 GB/s, Time = 0.00006 s, Size = 44000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 748.9 GB/s, Time = 0.00006 s, Size = 48000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 754.1 GB/s, Time = 0.00007 s, Size = 52000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 757.4 GB/s, Time = 0.00007 s, Size = 56000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 758.5 GB/s, Time = 0.00008 s, Size = 60000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 772.0 GB/s, Time = 0.00008 s, Size = 64000000 bytes, NumDevsUsed = 1
bandwidthTest-D2D, Bandwidth = 762.8 GB/s, Time = 0.00009 s, Size = 68000000 bytes, NumDevsUsed = 1
Result = PASS
"""
for i, metric in enumerate(['H2D_Mem_BW', 'D2H_Mem_BW', 'D2D_Mem_BW']):
assert (benchmark._process_raw_result(i, raw_output[i]))
assert (metric in benchmark.result)
assert (len(benchmark.result[metric]) == 1)
assert (isinstance(benchmark.result[metric][0], numbers.Number))

18
third_party/Makefile поставляемый
Просмотреть файл

@ -4,11 +4,15 @@
SB_MICRO_PATH ?= "/usr/local"
.PHONY: all cutlass
.PHONY: all cutlass bandwidthTest
# Build all targets.
all: cutlass
all: cutlass bandwidthTest
# Create $(SB_MICRO_PATH)/bin and $(SB_MICRO_PATH)/lib, no error if existing, make parent directories as needed.
sb_micro_path:
mkdir -p $(SB_MICRO_PATH)/bin
mkdir -p $(SB_MICRO_PATH)/lib
# Build cutlass.
cutlass:
ifneq (,$(wildcard cutlass/CMakeLists.txt))
@ -16,3 +20,13 @@ ifneq (,$(wildcard cutlass/CMakeLists.txt))
-DCUTLASS_NVCC_ARCHS='70;80' -DCUTLASS_ENABLE_EXAMPLES=OFF -DCUTLASS_ENABLE_TESTS=OFF -S ./cutlass -B ./cutlass/build
cmake --build ./cutlass/build -j 8 --target install
endif
# Build cuda-samples/Samples/bandwidthTest.
# cuda-samples is released together with CUDA, they have the exact same version. Like v10.0, v11.1 and so on.
# The version we use is the released tag of cuda-samples which is consistent with the cuda version in the environment or docker.
# The Makefile of bandwidthTest does not have 'install' target, so need to copy bin to $(SB_MICRO_PATH)/bin/ and create $(SB_MICRO_PATH)/bin/ if not existing.
bandwidthTest: sb_micro_path
ifneq (,$(wildcard cuda-samples/Samples/bandwidthTest/Makefile))
cd cuda-samples && git checkout v$(shell nvcc --version | grep 'release' | awk '{print $$6}' | cut -c2- | cut -d '.' -f1-2)
cd ./cuda-samples/Samples/bandwidthTest && make clean && make TARGET_ARCH=x86_64 SMS="70 75 80 86"
cp -v ./cuda-samples/Samples/bandwidthTest/bandwidthTest $(SB_MICRO_PATH)/bin/
endif