Benchmarks: Add Benchmark - Add FLOPs performance benchmark for CUDA. (#87)

* add CUDA FLOPs performance benchmark.
guoshzhao authored 2021-06-02 09:15:58 +08:00, committed by GitHub
Parent: 331c740a15
Commit: 6c6f526937
No key found matching this signature. GPG key ID: 4AEE18F83AFDEB23
5 changed files: 307 additions and 1 deletion

examples/benchmarks/gemm_flops_cuda_performance.py

@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Model benchmark example for CUTLASS GEMM FLOPs performance.

Commands to run:
  python3 examples/benchmarks/gemm_flops_cuda_performance.py
"""

from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

if __name__ == '__main__':
    parameters = '--n 16384 --k 16384 --m 16384'
    context = BenchmarkRegistry.create_benchmark_context('gemm-flops', platform=Platform.CUDA, parameters=parameters)

    benchmark = BenchmarkRegistry.launch_benchmark(context)
    if benchmark:
        logger.info(
            'benchmark: {}, return code: {}, result: {}'.format(
                benchmark.name, benchmark.return_code, benchmark.result
            )
        )
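The example runs every supported precision by default. Because the benchmark also exposes a --precision argument (see add_parser_arguments in the module below), the run can be narrowed to specific data types by extending the parameter string. A minimal variation of the script above, with illustrative values:

# Run only a subset of precisions (the chosen values are illustrative).
parameters = '--n 16384 --k 16384 --m 16384 --precision FP32 TF32_TC FP16_TC'
context = BenchmarkRegistry.create_benchmark_context('gemm-flops', platform=Platform.CUDA, parameters=parameters)
benchmark = BenchmarkRegistry.launch_benchmark(context)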

superbench/benchmarks/micro_benchmarks/__init__.py

@@ -9,8 +9,9 @@ from superbench.benchmarks.micro_benchmarks.computation_communication_overlap import ComputationCommunicationOverlap
from superbench.benchmarks.micro_benchmarks.kernel_launch_overhead import KernelLaunch
from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
from superbench.benchmarks.micro_benchmarks.cudnn_function import CudnnBenchmark
from superbench.benchmarks.micro_benchmarks.gemm_flops_performance import GemmFlopsCuda

__all__ = [
    'MicroBenchmark', 'MicroBenchmarkWithInvoke', 'ShardingMatmul', 'ComputationCommunicationOverlap', 'KernelLaunch',
    'CublasBenchmark', 'CudnnBenchmark', 'GemmFlopsCuda'
]

superbench/benchmarks/micro_benchmarks/gemm_flops_performance.py

@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the FLOPs performance benchmarks."""

import os

from superbench.common.utils import logger
from superbench.common.utils import nv_helper
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
class GemmFlopsCuda(MicroBenchmarkWithInvoke):
    """The GEMM FLOPs performance benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'cutlass_profiler'
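        # Map each precision to a cutlass_profiler kernel-name pattern; the trailing '*'
        # is a wildcard so every layout/alignment variant of the kernel (nn/nt/tn/tt at
        # each alignment, as visible in the raw CSV output in the tests) gets profiled.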
        self.__kernel_map = {
            'FP64': 'cutlass_simt_dgemm_128x128_8x2_*',
            'FP32': 'cutlass_simt_sgemm_128x128_8x2_*',
            'FP16': 'cutlass_simt_hgemm_256x128_8x2_*',
            'FP64_TC': 'cutlass_tensorop_d884gemm_128x128_16x3_*',
            'TF32_TC': 'cutlass_tensorop_tf32_s1688gemm_tf32_256x128_16x3_*',
            'BF16_TC': 'cutlass_tensorop_bf16_s16816gemm_bf16_256x128_32x3_*',
            'FP16_TC': 'cutlass_tensorop_h16816gemm_256x128_32x3_*',
            'INT8_TC': 'cutlass_tensorop_s8_i16832gemm_s8_256x128_64x3_*',
            'INT4_TC': 'cutlass_tensorop_s4_i16864gemm_s4_256x128_128x3_*',
        }
    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--num_warmup',
            type=int,
            default=5,
            required=False,
            help='The number of warmup steps.',
        )
        self._parser.add_argument(
            '--n',
            type=int,
            default=16384,
            required=False,
            help='The N dim of matmul (N, K) * (K, M).',
        )
        self._parser.add_argument(
            '--k',
            type=int,
            default=16384,
            required=False,
            help='The K dim of matmul (N, K) * (K, M).',
        )
        self._parser.add_argument(
            '--m',
            type=int,
            default=16384,
            required=False,
            help='The M dim of matmul (N, K) * (K, M).',
        )
        self._parser.add_argument(
            '--precision',
            type=str,
            nargs='+',
            default=list(self.__kernel_map.keys()),
            help='Precision for benchmarking. E.g. {}.'.format(' '.join(list(self.__kernel_map.keys()))),
        )
    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeeds.
        """
        if not super()._preprocess():
            return False

        self._args.precision = [p.upper() for p in self._args.precision]

        # TODO - Support more architectures; currently only compute capability 7.0 or 8.0 is supported.
        capability = nv_helper.get_device_compute_capability()
        if capability not in [7.0, 8.0]:
            self._result.set_return_code(ReturnCode.MICROBENCHMARK_UNSUPPORTED_ARCHITECTURE)
            logger.error(
                'Unsupported architecture - benchmark: {}, compute capability: {}, expected: 7.0 or 8.0'.format(
                    self._name, capability
                )
            )
            return False
        if capability == 7.0:
            # Compute capability 7.0 uses the 8x8x4 tensor op shape for FP16 instead of 16x8x16,
            # so remap FP16_TC before the commands are assembled below.
            self.__kernel_map['FP16_TC'] = 'cutlass_tensorop_h884gemm_256x128_32x2_*'

        for p in self._args.precision:
            if p not in list(self.__kernel_map.keys()):
                self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
                logger.error(
                    'Unsupported precision - benchmark: {}, precision: {}, expected: {}.'.format(
                        self._name, p, list(self.__kernel_map.keys())
                    )
                )
                return False
            else:
                command = os.path.join(self._args.bin_dir, self._bin_name)
                command += (' --warmup-iterations=' + str(self._args.num_warmup))
                command += (' --operation=gemm')
                command += (' --n=' + str(self._args.n))
                command += (' --k=' + str(self._args.k))
                command += (' --m=' + str(self._args.m))
                command += (' --kernels=' + self.__kernel_map[p])
                self._commands.append(command)

        return True
    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        precision = self._args.precision[cmd_idx]
        self._result.add_raw_data('raw_output_' + precision, raw_output)

        valid = True
        flops = list()
        content = raw_output.splitlines()
        try:
            for line in content:
                if 'gemm,cutlass_simt_dgemm_128x128_8x2' in line or \
                   'gemm,cutlass_simt_sgemm_128x128_8x2' in line or \
                   'gemm,cutlass_simt_hgemm_256x128_8x2' in line or \
                   'gemm,cutlass_tensorop_d884gemm_128x128_16x3' in line or \
                   'gemm,cutlass_tensorop_tf32_s1688gemm_tf32_256x128_16x3' in line or \
                   'gemm,cutlass_tensorop_bf16_s16816gemm_bf16_256x128_32x3' in line or \
                   'gemm,cutlass_tensorop_h16816gemm_256x128_32x3' in line or \
                   'gemm,cutlass_tensorop_h884gemm_256x128_32x2' in line or \
                   'gemm,cutlass_tensorop_s8_i16832gemm_s8_256x128_64x3' in line or \
                   'gemm,cutlass_tensorop_s4_i16864gemm_s4_256x128_128x3' in line:
                    # The last CSV field of a matching data row is the measured GFLOPs.
                    flops.append(float(line.split(',')[-1]))
        except BaseException:
            valid = False
        finally:
            if valid is False or len(flops) == 0:
                logger.error(
                    'The result format is invalid - round: {}, benchmark: {}, raw output: {}.'.format(
                        self._curr_run_index, self._name, raw_output
                    )
                )
                return False

        # Report the best GFLOPs across all matched layout/alignment variants.
        self._result.add_result(precision, max(flops))

        return True


BenchmarkRegistry.register_benchmark('gemm-flops', GemmFlopsCuda, platform=Platform.CUDA)
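For each precision, _preprocess assembles one profiler invocation such as cutlass_profiler --warmup-iterations=5 --operation=gemm --n=16384 --k=16384 --m=16384 --kernels=cutlass_simt_sgemm_128x128_8x2_*, and _process_raw_result then keeps the last CSV field (GFLOPs) of every matching row and reports the maximum. A minimal standalone sketch of that parsing step, using a hypothetical parse_gflops helper for illustration only:

def parse_gflops(raw_output, kernel_prefix):
    """Best GFLOPs among all layout/alignment variants of one kernel family.

    Hypothetical helper, not part of the benchmark class; it assumes the
    cutlass_profiler CSV format shown in the test fixtures below, where the
    last column of each data row is GFLOPs.
    """
    flops = []
    for line in raw_output.splitlines():
        # Data rows look like: 1,CUTLASS,gemm,cutlass_simt_sgemm_128x128_8x2_nn_align1,...,18287.4
        if 'gemm,' + kernel_prefix in line:
            flops.append(float(line.split(',')[-1]))
    return max(flops) if flops else None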


@@ -28,3 +28,4 @@ class ReturnCode(Enum):
    MICROBENCHMARK_BINARY_NOT_EXIST = 31
    MICROBENCHMARK_EXECUTION_FAILURE = 32
    MICROBENCHMARK_RESULT_PARSING_FAILURE = 33
    MICROBENCHMARK_UNSUPPORTED_ARCHITECTURE = 34
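The new return code lets callers distinguish an unsupported GPU from other failures. A minimal sketch of such a check, reusing the registry API from the example file above (purely illustrative, not part of this commit):

from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.common.utils import logger

context = BenchmarkRegistry.create_benchmark_context('gemm-flops', platform=Platform.CUDA)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark and benchmark.return_code == ReturnCode.MICROBENCHMARK_UNSUPPORTED_ARCHITECTURE:
    logger.info('gemm-flops needs compute capability 7.0 or 8.0; skipping on this device.')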


@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for gemm-flops benchmark."""

import os
import unittest
from pathlib import Path

from tests.helper import decorator
from superbench.common.utils import nv_helper
from superbench.benchmarks import BenchmarkRegistry, ReturnCode, Platform, BenchmarkType
class GemmFlopsCudaTest(unittest.TestCase):
    """Tests for GemmFlopsCuda benchmark."""
    def setUp(self):
        """Method called to prepare the test fixture."""
        # Create fake binary file just for testing.
        os.environ['SB_MICRO_PATH'] = '/tmp/superbench/'
        binary_path = os.path.join(os.getenv('SB_MICRO_PATH'), 'bin')
        Path(binary_path).mkdir(parents=True, exist_ok=True)
        self.__binary_file = Path(os.path.join(binary_path, 'cutlass_profiler'))
        self.__binary_file.touch(mode=0o755, exist_ok=True)

    def tearDown(self):
        """Method called after the test method has been called and the result recorded."""
        self.__binary_file.unlink()

    @decorator.cuda_test
    def test_flops_performance_cuda(self):
        """Test gemm-flops benchmark."""
        benchmark_name = 'gemm-flops'
        (benchmark_class,
         predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA)
        assert (benchmark_class)

        # Negative case - MICROBENCHMARK_UNSUPPORTED_ARCHITECTURE when the device is not compute capability 7.0 or 8.0.
        benchmark = benchmark_class(
            benchmark_name,
            parameters='--num_warmup 200 --n 1024 --k 512 --m 2048 --precision FP32 TF32_TC FP16_TC INT8_TC'
        )

        ret = benchmark._preprocess()
        if nv_helper.get_device_compute_capability() not in [7.0, 8.0]:
            assert (ret is False)
            assert (benchmark.return_code == ReturnCode.MICROBENCHMARK_UNSUPPORTED_ARCHITECTURE)
        else:
            assert (ret is True)
            assert (benchmark.return_code == ReturnCode.SUCCESS)

            # Check basic information.
            assert (benchmark.name == 'gemm-flops')
            assert (benchmark.type == BenchmarkType.MICRO)
            assert (benchmark._bin_name == 'cutlass_profiler')

            # Check parameters specified in BenchmarkContext.
            assert (benchmark._args.num_warmup == 200)
            assert (benchmark._args.n == 1024)
            assert (benchmark._args.k == 512)
            assert (benchmark._args.m == 2048)
            assert (benchmark._args.precision == ['FP32', 'TF32_TC', 'FP16_TC', 'INT8_TC'])

            # Check the command list.
            for i in range(len(benchmark._args.precision)):
                command = '{} --warmup-iterations={} --operation=gemm --n={} --k={} --m={} --kernels={}'.format(
                    benchmark._bin_name, benchmark._args.num_warmup, benchmark._args.n, benchmark._args.k,
                    benchmark._args.m, benchmark._GemmFlopsCuda__kernel_map[benchmark._args.precision[i]]
                )
                expected_cmd = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1]
                assert (command == expected_cmd)

            # Check results and metrics.
            raw_output_FP32 = """
CSV Results:
Problem,Provider,OperationKind,Operation,Disposition,Status,gemm_kind,m,n,k,A,B,C,alpha,beta,split_k_slices,batch_count,op_class,accum,cta_m,cta_n,cta_k,stages,warps_m,warps_n,warps_k,inst_m,inst_n,inst_k,min_cc,max_cc,Bytes,Flops,Runtime,GB/s,GFLOPs
1,CUTLASS,gemm,cutlass_simt_sgemm_128x128_8x2_nn_align1,passed,success,universal,16384,16384,16384,f32:column,f32:column,f32:column,1,0,1,1,simt,f32,128,128,8,2,4,2,1,1,1,1,50,1024,3221225472,8796629893120,481.022,6.23672,18287.4
1,CUTLASS,gemm,cutlass_simt_sgemm_128x128_8x2_nt_align1,passed,success,universal,16384,16384,16384,f32:column,f32:row,f32:column,1,0,1,1,simt,f32,128,128,8,2,4,2,1,1,1,1,50,1024,3221225472,8796629893120,478.866,6.2648,18369.7
1,CUTLASS,gemm,cutlass_simt_sgemm_128x128_8x2_tn_align1,passed,success,universal,16384,16384,16384,f32:row,f32:column,f32:column,1,0,1,1,simt,f32,128,128,8,2,4,2,1,1,1,1,50,1024,3221225472,8796629893120,482.034,6.22363,18249
1,CUTLASS,gemm,cutlass_simt_sgemm_128x128_8x2_tt_align1,passed,success,universal,16384,16384,16384,f32:row,f32:row,f32:column,1,0,1,1,simt,f32,128,128,8,2,4,2,1,1,1,1,50,1024,3221225472,8796629893120,481.838,6.22616,18256.4
"""
            raw_output_TF32_TC = """
CSV Results:
Problem,Provider,OperationKind,Operation,Disposition,Status,gemm_kind,m,n,k,A,B,C,alpha,beta,split_k_slices,batch_count,op_class,accum,cta_m,cta_n,cta_k,stages,warps_m,warps_n,warps_k,inst_m,inst_n,inst_k,min_cc,max_cc,Bytes,Flops,Runtime,GB/s,GFLOPs
1,CUTLASS,gemm,cutlass_tensorop_tf32_s1688gemm_tf32_256x128_16x3_nn_align4,passed,success,universal,16384,16384,16384,tf32:column,tf32:column,tf32:column,1,0,1,1,tensorop,f32,256,128,16,3,4,2,1,16,8,8,80,1024,3221225472,8796629893120,88.5764,33.8691,99311.2
1,CUTLASS,gemm,cutlass_tensorop_tf32_s1688gemm_tf32_256x128_16x3_nt_align4,passed,success,universal,16384,16384,16384,tf32:column,tf32:row,tf32:column,1,0,1,1,tensorop,f32,256,128,16,3,4,2,1,16,8,8,80,1024,3221225472,8796629893120,70.3503,42.6438,125040
1,CUTLASS,gemm,cutlass_tensorop_tf32_s1688gemm_tf32_256x128_16x3_tn_align4,passed,success,universal,16384,16384,16384,tf32:row,tf32:column,tf32:column,1,0,1,1,tensorop,f32,256,128,16,3,4,2,1,16,8,8,80,1024,3221225472,8796629893120,86.5167,34.6754,101676
1,CUTLASS,gemm,cutlass_tensorop_tf32_s1688gemm_tf32_256x128_16x3_tt_align4,passed,success,universal,16384,16384,16384,tf32:row,tf32:row,tf32:column,1,0,1,1,tensorop,f32,256,128,16,3,4,2,1,16,8,8,80,1024,3221225472,8796629893120,68.3621,43.884,128677
"""
            raw_output_FP16_TC = """
CSV Results:
Problem,Provider,OperationKind,Operation,Disposition,Status,gemm_kind,m,n,k,A,B,C,alpha,beta,split_k_slices,batch_count,op_class,accum,cta_m,cta_n,cta_k,stages,warps_m,warps_n,warps_k,inst_m,inst_n,inst_k,min_cc,max_cc,Bytes,Flops,Runtime,GB/s,GFLOPs
1,CUTLASS,gemm,cutlass_tensorop_h16816gemm_256x128_32x3_nn_align8,incorrect,success,universal,16384,16384,16384,f16:column,f16:column,f16:column,1,0,1,1,tensorop,f16,256,128,32,3,4,2,1,16,8,16,80,1024,1610612736,8796629893120,34.1575,43.9142,257531
1,CUTLASS,gemm,cutlass_tensorop_h16816gemm_256x128_32x3_nt_align8,incorrect,success,universal,16384,16384,16384,f16:column,f16:row,f16:column,1,0,1,1,tensorop,f16,256,128,32,3,4,2,1,16,8,16,80,1024,1610612736,8796629893120,34.6153,43.3334,254126
1,CUTLASS,gemm,cutlass_tensorop_h16816gemm_256x128_32x3_tn_align8,incorrect,success,universal,16384,16384,16384,f16:row,f16:column,f16:column,1,0,1,1,tensorop,f16,256,128,32,3,4,2,1,16,8,16,80,1024,1610612736,8796629893120,39.0413,38.4209,225316
1,CUTLASS,gemm,cutlass_tensorop_h16816gemm_256x128_32x3_tt_align8,incorrect,success,universal,16384,16384,16384,f16:row,f16:row,f16:column,1,0,1,1,tensorop,f16,256,128,32,3,4,2,1,16,8,16,80,1024,1610612736,8796629893120,31.2994,47.9243,281048
"""
            assert (benchmark._process_raw_result(0, raw_output_FP32))
            assert (benchmark._process_raw_result(1, raw_output_TF32_TC))
            assert (benchmark._process_raw_result(2, raw_output_FP16_TC))

            assert (benchmark.result['FP32'][0] == 18369.7)
            assert (benchmark.result['TF32_TC'][0] == 128677)
            assert (benchmark.result['FP16_TC'][0] == 281048)

            # Negative case - invalid raw output cannot be parsed.
            assert (benchmark._process_raw_result(3, 'Invalid raw output') is False)