Benchmarks: Add Benchmark - Add gemm flops microbenchmark for amd (#152)
**Description** Add gemm flops microbenchmark for amd. **Major Revision** - Add gemm flops microbenchmark for amd. - Add related example and test file.
This commit is contained in:
Родитель
b0df66f7a2
Коммит
f3d53c3d5f
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Model benchmark example for XDLOPS GEMM FLOPs performance.
|
||||
|
||||
Commands to run:
|
||||
python3 examples/benchmarks/rocm_gemm_flops_performance.py
|
||||
"""
|
||||
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform
|
||||
from superbench.common.utils import logger
|
||||
|
||||
if __name__ == '__main__':
|
||||
context = BenchmarkRegistry.create_benchmark_context('gemm-flops', platform=Platform.ROCM)
|
||||
|
||||
benchmark = BenchmarkRegistry.launch_benchmark(context)
|
||||
if benchmark:
|
||||
logger.info(
|
||||
'benchmark: {}, return code: {}, result: {}'.format(
|
||||
benchmark.name, benchmark.return_code, benchmark.result
|
||||
)
|
||||
)
|
|
@ -17,9 +17,11 @@ from superbench.benchmarks.micro_benchmarks.disk_performance import DiskBenchmar
|
|||
from superbench.benchmarks.micro_benchmarks.ib_loopback_performance import IBLoopbackBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.rocm_memory_bw_performance import RocmMemBwBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.rocm_gemm_flops_performance import RocmGemmFlopsBenchmark
|
||||
|
||||
__all__ = [
|
||||
'MicroBenchmark', 'MicroBenchmarkWithInvoke', 'ShardingMatmul', 'ComputationCommunicationOverlap', 'KernelLaunch',
|
||||
'CublasBenchmark', 'CudnnBenchmark', 'GemmFlopsBenchmark', 'CudaGemmFlopsBenchmark', 'MemBwBenchmark',
|
||||
'CudaMemBwBenchmark', 'DiskBenchmark', 'IBLoopbackBenchmark', 'CudaNcclBwBenchmark', 'RocmMemBwBenchmark'
|
||||
'CudaMemBwBenchmark', 'DiskBenchmark', 'IBLoopbackBenchmark', 'CudaNcclBwBenchmark', 'RocmMemBwBenchmark',
|
||||
'RocmGemmFlopsBenchmark'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Module of the FLOPs performance benchmarks."""
|
||||
|
||||
import os
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform
|
||||
from superbench.benchmarks.micro_benchmarks import GemmFlopsBenchmark
|
||||
|
||||
|
||||
class RocmGemmFlopsBenchmark(GemmFlopsBenchmark):
|
||||
"""The GEMM FLOPs performance benchmark class."""
|
||||
def __init__(self, name, parameters=''):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
name (str): benchmark name.
|
||||
parameters (str): benchmark parameters.
|
||||
"""
|
||||
super().__init__(name, parameters)
|
||||
|
||||
self._bin_name = 'rocblas-bench'
|
||||
self._support_precisions = ['FP64', 'FP32_xDLOPS', 'FP16_xDLOPS', 'BF16_xDLOPS', 'INT8_xDLOPS']
|
||||
self.__precision_and_kernel_map = {
|
||||
'FP64': '-r f64_r -f gemm',
|
||||
'FP32_xDLOPS': '-r f32_r -f gemm_ex --compute_type f32_r',
|
||||
'FP16_xDLOPS': '-r f16_r -f gemm_ex --compute_type f32_r',
|
||||
'BF16_xDLOPS': '-r bf16_r -f gemm_ex --compute_type f32_r',
|
||||
'INT8_xDLOPS': '--a_type i8_r --b_type i8_r --c_type i32_r --d_type i32_r -f gemm_ex --compute_type i32_r'
|
||||
}
|
||||
|
||||
def add_parser_arguments(self):
|
||||
"""Add the specified arguments."""
|
||||
super().add_parser_arguments()
|
||||
|
||||
self._parser.add_argument(
|
||||
'--transposeA',
|
||||
type=str.upper,
|
||||
choices=['N', 'T', 'C'],
|
||||
default='N',
|
||||
help='Transpose type of Matrix A, N = no transpose, T = transpose, C = conjugate transpose',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--transposeB',
|
||||
type=str.upper,
|
||||
choices=['N', 'T', 'C'],
|
||||
default='T',
|
||||
help='Transpose type of Matrix B, N = no transpose, T = transpose, C = conjugate transpose',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--lda',
|
||||
type=int,
|
||||
default=8384,
|
||||
required=False,
|
||||
help='Leading dimension of matrix A.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--ldb',
|
||||
type=int,
|
||||
default=8384,
|
||||
required=False,
|
||||
help='Leading dimension of matrix B.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--ldc',
|
||||
type=int,
|
||||
default=8384,
|
||||
required=False,
|
||||
help='Leading dimension of matrix C.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--ldd',
|
||||
type=int,
|
||||
default=8384,
|
||||
required=False,
|
||||
help='Leading dimension of matrix D.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--alpha',
|
||||
type=int,
|
||||
default=1,
|
||||
required=False,
|
||||
help='Specifies the scalar alpha.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--beta',
|
||||
type=int,
|
||||
default=0,
|
||||
required=False,
|
||||
help='Specifies the scalar beta.',
|
||||
)
|
||||
|
||||
def _preprocess(self):
|
||||
"""Preprocess/preparation operations before the benchmarking.
|
||||
|
||||
Return:
|
||||
True if _preprocess() succeed.
|
||||
"""
|
||||
if not super()._preprocess():
|
||||
return False
|
||||
|
||||
for p in self._precision_need_to_run:
|
||||
command = os.path.join(self._args.bin_dir, self._bin_name)
|
||||
command += ' ' + self.__precision_and_kernel_map[p]
|
||||
command += ' --transposeA {} --transposeB {}'.format(self._args.transposeA, self._args.transposeB)
|
||||
command += ' -m {} -n {} -k {}'.format(self._args.m, self._args.n, self._args.k)
|
||||
command += ' --alpha {} --beta {}'.format(self._args.alpha, self._args.beta)
|
||||
command += ' --lda {} --ldb {} --ldc {} --ldd {}'.format(
|
||||
self._args.lda, self._args.ldb, self._args.ldc, self._args.ldd
|
||||
)
|
||||
self._commands.append(command)
|
||||
|
||||
return True
|
||||
|
||||
def _process_raw_result(self, cmd_idx, raw_output):
|
||||
"""Function to parse raw results and save the summarized results.
|
||||
|
||||
self._result.add_raw_data() and self._result.add_result() need to be called to save the results.
|
||||
|
||||
Args:
|
||||
cmd_idx (int): the index of command corresponding with the raw_output.
|
||||
raw_output (str): raw output string of the micro-benchmark.
|
||||
|
||||
Return:
|
||||
True if the raw output string is valid and result can be extracted.
|
||||
"""
|
||||
precision = self._precision_need_to_run[cmd_idx]
|
||||
self._result.add_raw_data('raw_output_' + precision, raw_output)
|
||||
|
||||
content = raw_output.splitlines()
|
||||
gflops_index = None
|
||||
gflops = -1
|
||||
|
||||
for line in content:
|
||||
try:
|
||||
if 'rocblas-Gflops' in line:
|
||||
line = line.split(',')
|
||||
gflops_index = line.index('rocblas-Gflops')
|
||||
if gflops_index is not None:
|
||||
line = line.split(',')
|
||||
gflops = float(line[gflops_index])
|
||||
if gflops != -1:
|
||||
break
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
if gflops == -1:
|
||||
logger.error(
|
||||
'The result format is invalid - round: {}, benchmark: {}, raw output: {}.'.format(
|
||||
self._curr_run_index, self._name, raw_output
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
self._result.add_result(precision, gflops)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
BenchmarkRegistry.register_benchmark('gemm-flops', RocmGemmFlopsBenchmark, platform=Platform.ROCM)
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tests for gemm-flops benchmark."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from superbench.benchmarks import BenchmarkRegistry, ReturnCode, Platform, BenchmarkType
|
||||
|
||||
|
||||
class RocmGemmFlopsTest(unittest.TestCase):
|
||||
"""Tests for RocmGemmFlops benchmark."""
|
||||
def setUp(self):
|
||||
"""Method called to prepare the test fixture."""
|
||||
# Create fake binary file just for testing.
|
||||
os.environ['SB_MICRO_PATH'] = '/tmp/superbench/'
|
||||
binary_path = os.path.join(os.getenv('SB_MICRO_PATH'), 'bin')
|
||||
Path(binary_path).mkdir(parents=True, exist_ok=True)
|
||||
self.__binary_file = Path(os.path.join(binary_path, 'rocblas-bench'))
|
||||
self.__binary_file.touch(mode=0o755, exist_ok=True)
|
||||
|
||||
def tearDown(self):
|
||||
"""Method called after the test method has been called and the result recorded."""
|
||||
self.__binary_file.unlink()
|
||||
|
||||
def test_rocm_flops_performance(self):
|
||||
"""Test gemm-flops benchmark."""
|
||||
benchmark_name = 'gemm-flops'
|
||||
(benchmark_class,
|
||||
predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.ROCM)
|
||||
assert (benchmark_class)
|
||||
|
||||
# Negative case - MICROBENCHMARK_UNSUPPORTED_ARCHITECTURE.
|
||||
benchmark = benchmark_class(benchmark_name, parameters='--m 7680 --n 8192 --k 8192')
|
||||
|
||||
ret = benchmark._preprocess()
|
||||
assert (ret is True)
|
||||
assert (benchmark.return_code == ReturnCode.SUCCESS)
|
||||
|
||||
# Check basic information.
|
||||
assert (benchmark.name == 'gemm-flops')
|
||||
assert (benchmark.type == BenchmarkType.MICRO)
|
||||
assert (benchmark._bin_name == 'rocblas-bench')
|
||||
|
||||
# Check parameters specified in BenchmarkContext.
|
||||
assert (benchmark._args.m == 7680)
|
||||
assert (benchmark._args.n == 8192)
|
||||
assert (benchmark._args.k == 8192)
|
||||
|
||||
params = '--transposeA N --transposeB T -m 7680 -n 8192 -k 8192' + \
|
||||
' --alpha 1 --beta 0 --lda 8384 --ldb 8384 --ldc 8384 --ldd 8384'
|
||||
# Check command list
|
||||
expected_command = [
|
||||
'rocblas-bench -r f64_r -f gemm ' + params,
|
||||
'rocblas-bench -r f32_r -f gemm_ex --compute_type f32_r ' + params,
|
||||
'rocblas-bench -r f16_r -f gemm_ex --compute_type f32_r ' + params,
|
||||
'rocblas-bench -r bf16_r -f gemm_ex --compute_type f32_r ' + params,
|
||||
'rocblas-bench --a_type i8_r --b_type i8_r --c_type i32_r --d_type i32_r -f gemm_ex --compute_type i32_r ' +
|
||||
params
|
||||
]
|
||||
for i in range(len(expected_command)):
|
||||
commnad = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1]
|
||||
print(benchmark._commands)
|
||||
assert (commnad == expected_command[i])
|
||||
|
||||
# Check results and metrics.
|
||||
raw_output_FP64 = """
|
||||
transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,rocblas-Gflops,us
|
||||
N,T,7680,8192,8192,1,8384,0,8384,8384, 10037.5, 102694
|
||||
"""
|
||||
raw_output_FP32_X = """
|
||||
transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,ldd,batch_count,rocblas-Gflops,us
|
||||
N,T,8640,8640,8640,1,8640,0,8640,8640,8640,1, 39441.6, 32705.2
|
||||
"""
|
||||
raw_output_FP16_X = """
|
||||
transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,ldd,batch_count,rocblas-Gflops,us
|
||||
N,T,7680,8192,8192,1,8384,0,8384,8384,8384,1, 153728, 6705.3
|
||||
"""
|
||||
raw_output_BF16_X = """
|
||||
transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,ldd,batch_count,rocblas-Gflops,us
|
||||
N,T,7680,8192,8192,1,8384,0,8384,8384,8384,1, 81374.3, 12667.3
|
||||
"""
|
||||
raw_output_INT8_X = """
|
||||
transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,ldd,batch_count,rocblas-Gflops,us
|
||||
T,N,7680,8192,8192,1,8416,0,8416,8416,8416,1, 162675, 6336.5
|
||||
"""
|
||||
assert (benchmark._process_raw_result(0, raw_output_FP64))
|
||||
assert (benchmark._process_raw_result(1, raw_output_FP32_X))
|
||||
assert (benchmark._process_raw_result(2, raw_output_FP16_X))
|
||||
assert (benchmark._process_raw_result(3, raw_output_BF16_X))
|
||||
assert (benchmark._process_raw_result(4, raw_output_INT8_X))
|
||||
|
||||
assert (benchmark.result['FP64'][0] == 10037.5)
|
||||
assert (benchmark.result['FP32_xDLOPS'][0] == 39441.6)
|
||||
assert (benchmark.result['FP16_xDLOPS'][0] == 153728)
|
||||
assert (benchmark.result['BF16_xDLOPS'][0] == 81374.3)
|
||||
assert (benchmark.result['INT8_xDLOPS'][0] == 162675)
|
||||
|
||||
# Negative case - Add invalid raw output.
|
||||
assert (benchmark._process_raw_result(4, 'Invalid raw output') is False)
|
Загрузка…
Ссылка в новой задаче