Benchmarks: Micro benchmark - Add hipBLASLt function benchmark (#576)
**Description** hipblaslt function benchmark and rebase cublaslt function benchmark.
This commit is contained in:
Родитель
9f4880cb8e
Коммит
79089b6517
|
@ -33,7 +33,7 @@ steps:
|
|||
- script: |
|
||||
SB_MICRO_PATH=$PWD python3 setup.py test
|
||||
displayName: Run unit tests
|
||||
timeoutInMinutes: 30
|
||||
timeoutInMinutes: 60
|
||||
- script: |
|
||||
bash <(curl -s https://codecov.io/bash) -cF cuda-unit-test
|
||||
displayName: Report coverage results
|
||||
|
|
|
@ -136,7 +136,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
|
|||
WORKDIR ${SB_HOME}
|
||||
|
||||
ADD third_party third_party
|
||||
RUN make -C third_party rocm
|
||||
RUN make -C third_party rocm -o rocm_hipblaslt
|
||||
|
||||
ADD . .
|
||||
RUN python3 -m pip install --upgrade setuptools==65.7 && \
|
||||
|
|
|
@ -141,7 +141,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
|
|||
WORKDIR ${SB_HOME}
|
||||
|
||||
ADD third_party third_party
|
||||
RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm
|
||||
RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm -o rocm_hipblaslt
|
||||
|
||||
ADD . .
|
||||
RUN python3 -m pip install --no-cache-dir .[amdworker] && \
|
||||
|
|
|
@ -9,7 +9,9 @@ from superbench.benchmarks.micro_benchmarks.memory_bw_performance_base import Me
|
|||
|
||||
from superbench.benchmarks.micro_benchmarks.computation_communication_overlap import ComputationCommunicationOverlap
|
||||
from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.blaslt_function_base import BlasLtBaseBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.cublaslt_function import CublasLtBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.hipblaslt_function import HipBlasLtBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark
|
||||
from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
|
||||
|
@ -37,6 +39,7 @@ from superbench.benchmarks.micro_benchmarks.directx_mem_bw_performance import Di
|
|||
from superbench.benchmarks.micro_benchmarks.directx_gemm_flops_performance import DirectXGPUCoreFlops
|
||||
|
||||
__all__ = [
|
||||
'BlasLtBaseBenchmark',
|
||||
'ComputationCommunicationOverlap',
|
||||
'CpuMemBwLatencyBenchmark',
|
||||
'CpuHplBenchmark',
|
||||
|
@ -49,6 +52,7 @@ __all__ = [
|
|||
'CudnnBenchmark',
|
||||
'DiskBenchmark',
|
||||
'DistInference',
|
||||
'HipBlasLtBenchmark',
|
||||
'GPCNetBenchmark',
|
||||
'GemmFlopsBenchmark',
|
||||
'GpuBurnBenchmark',
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Module of the BLASLt GEMM Base Class."""
|
||||
import itertools
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
|
||||
|
||||
|
||||
def mrange(start, stop=-1, multiplication_factor=2, symbol='x'):
    """Range constructor with a multiplicative or additive step.

    The first value yielded is always ``start``. Afterwards the value is repeatedly
    multiplied by (``symbol='x'``) or increased by (``symbol='+'``) ``multiplication_factor``
    and yielded as long as it does not exceed ``stop``.

    Args:
        start (int): Start number.
        stop (int, optional): Stop number (inclusive). Defaults to -1.
        multiplication_factor (int, optional): Step factor (multiplier for 'x', increment for '+').
            Defaults to 2.
        symbol (str, optional): Step kind, 'x' (multiplication) or '+' (addition). Defaults to 'x'.

    Yields:
        int: number in the range.

    Raises:
        ValueError: If symbol is neither 'x' nor '+'. Raised when the generator is first iterated.
    """
    if symbol == 'x':
        while True:
            yield start
            start *= multiplication_factor
            # A non-positive value can never reach `stop` by multiplication: zero stays zero and
            # a negative value only moves further away, so stop instead of looping forever.
            if start > stop or start <= 0 or multiplication_factor < 2:
                break
    elif symbol == '+':
        while True:
            yield start
            start = start + multiplication_factor
            # An increment below 1 would never advance towards `stop`.
            if start > stop or start == 0 or multiplication_factor < 1:
                break
    else:
        raise ValueError(f'Invalid symbol {symbol}.')
|
||||
|
||||
|
||||
def validate_mrange(string):
    """Validate mrange string in format start[[:stop]:multiplication_factor].

    Every field must be a non-negative decimal integer. Only the third field
    (the factor) may carry a single leading '+' or 'x' that selects additive or
    multiplicative stepping.

    Args:
        string (str): mrange string.

    Returns:
        bool: whether the mrange is expected.
    """
    nums = string.split(':')
    if len(nums) > 3:
        return False

    if len(nums) == 3 and nums[2][:1] in ('+', 'x'):
        # Drop exactly one step-kind prefix, so repeated prefixes such as '++4' are rejected.
        nums[2] = nums[2][1:]

    # isdecimal() accepts only characters that int() can parse, unlike isdigit(),
    # which also accepts characters such as superscript digits.
    return all(x.isdecimal() for x in nums)
|
||||
|
||||
|
||||
class BlasLtBaseBenchmark(MicroBenchmarkWithInvoke):
    """The BLASLt GEMM Base class.

    Adds the common ``--shapes``, ``--batch``, ``--num_warmup`` and ``--num_steps`` arguments
    and expands them into ``self._shapes_to_run``. ``_preprocess()`` also reads
    ``self._args.in_types`` and ``self._in_types``, which are not defined here and are
    expected to be provided by the subclass.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--shapes',
            type=str,
            nargs='+',
            default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
            help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
        )
        self._parser.add_argument(
            '--batch',
            type=str,
            default='0',
            required=False,
            help=(
                'Batch size for strided batch GEMM, set 0 to disable.'
                ' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
            ),
        )
        self._parser.add_argument(
            '--num_warmup',
            type=int,
            default=20,
            required=False,
            help='Number of warm up steps.',
        )
        self._parser.add_argument(
            '--num_steps',
            type=int,
            default=50,
            required=False,
            help='Number of steps to measure.',
        )

    @staticmethod
    def _expand_mrange(spec):
        """Turn an already validated mrange string into an mrange generator.

        Args:
            spec (str): mrange string in format start[[:stop]:[+|x]factor].

        Returns:
            generator: values produced by mrange() for the given specification.
        """
        fields = spec.split(':')
        symbol = 'x'
        if len(fields) == 3 and fields[2][:1] in ('+', 'x'):
            # The optional prefix of the factor selects additive or multiplicative stepping.
            symbol = fields[2][0]
            fields[2] = fields[2].lstrip(symbol)
        return mrange(*map(int, fields), symbol=symbol)

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Validates the batch, input type and shape arguments and stores every
        (m, n, k, batch, in_type) combination in ``self._shapes_to_run``.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        if not validate_mrange(self._args.batch):
            logger.error(f'Invalid batch size {self._args.batch}.')
            return False

        for _in_type in self._args.in_types:
            if _in_type not in self._in_types:
                logger.error(f'Invalid input type {_in_type}.')
                return False

        self._shapes_to_run = []

        # Validate and expand the shapes once; they do not depend on input type or batch size.
        expanded_shapes = []
        for shape in self._args.shapes:
            shape_list = shape.replace(',', ' ').split()
            if len(shape_list) != 3 or not all(validate_mrange(x) for x in shape_list):
                logger.error(f'Invalid shape {shape}.')
                return False
            expanded_shapes.extend(itertools.product(*(self._expand_mrange(dim) for dim in shape_list)))

        # Batch uses the same parser as shapes, so an 'x' or '+' prefixed factor
        # accepted by validate_mrange() no longer breaks int() conversion.
        batches = list(self._expand_mrange(self._args.batch))

        for _in_type in self._args.in_types:
            for _b in batches:
                for _m, _n, _k in expanded_shapes:
                    self._shapes_to_run.append((_m, _n, _k, _b, _in_type))

        return True
|
|
@ -4,14 +4,13 @@
|
|||
"""Module of the cuBLASLt GEMM benchmark."""
|
||||
|
||||
import os
|
||||
import itertools
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
|
||||
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
|
||||
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark
|
||||
|
||||
|
||||
class CublasLtBenchmark(MicroBenchmarkWithInvoke):
|
||||
class CublasLtBenchmark(BlasLtBaseBenchmark):
|
||||
"""The cuBLASLt GEMM benchmark class."""
|
||||
def __init__(self, name, parameters=''):
|
||||
"""Constructor.
|
||||
|
@ -25,72 +24,10 @@ class CublasLtBenchmark(MicroBenchmarkWithInvoke):
|
|||
self._bin_name = 'cublaslt_gemm'
|
||||
self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2', 'int8']
|
||||
|
||||
def mrange(self, start, stop=-1, multiplication_factor=2):
|
||||
"""Range constructor with multiplication factor.
|
||||
|
||||
Args:
|
||||
start (int): Start number.
|
||||
stop (int, optional): Stop number. Defaults to -1.
|
||||
multiplication_factor (int, optional): Multiplication factor. Defaults to 2.
|
||||
|
||||
Yields:
|
||||
int: number in the range.
|
||||
"""
|
||||
while True:
|
||||
yield start
|
||||
start *= multiplication_factor
|
||||
if start > stop or start == 0 or multiplication_factor < 2:
|
||||
break
|
||||
|
||||
def validate_mrange(self, string):
|
||||
"""Validate mrange string in format start[[:stop]:multiplication_factor].
|
||||
|
||||
Args:
|
||||
string (str): mrange string.
|
||||
|
||||
Returns:
|
||||
bool: whether the mrange is expected.
|
||||
"""
|
||||
nums = string.split(':')
|
||||
if len(nums) > 3:
|
||||
return False
|
||||
return bool(all(x.isdigit() for x in nums))
|
||||
|
||||
def add_parser_arguments(self):
|
||||
"""Add the specified arguments."""
|
||||
super().add_parser_arguments()
|
||||
|
||||
self._parser.add_argument(
|
||||
'--shapes',
|
||||
type=str,
|
||||
nargs='+',
|
||||
default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
|
||||
help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--batch',
|
||||
type=str,
|
||||
default='0',
|
||||
required=False,
|
||||
help=(
|
||||
'Batch size for strided batch GEMM, set 0 to disable.'
|
||||
' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
|
||||
),
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--num_warmup',
|
||||
type=int,
|
||||
default=20,
|
||||
required=False,
|
||||
help='Number of warm up steps.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--num_steps',
|
||||
type=int,
|
||||
default=50,
|
||||
required=False,
|
||||
help='Number of steps to measure.',
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--in_types',
|
||||
type=str,
|
||||
|
@ -111,28 +48,12 @@ class CublasLtBenchmark(MicroBenchmarkWithInvoke):
|
|||
|
||||
self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
|
||||
|
||||
if not self.validate_mrange(self._args.batch):
|
||||
logger.error(f'Invalid batch size {self._args.batch}.')
|
||||
return False
|
||||
|
||||
self._commands = []
|
||||
for _in_type in self._args.in_types:
|
||||
if _in_type not in self._in_types:
|
||||
logger.error(f'Invalid input type {_in_type}.')
|
||||
return False
|
||||
for _b in self.mrange(*map(int, self._args.batch.split(':'))):
|
||||
for shape in self._args.shapes:
|
||||
shape_list = shape.replace(',', ' ').split()
|
||||
if len(shape_list) != 3 or not all(self.validate_mrange(x) for x in shape_list):
|
||||
logger.error(f'Invalid shape {shape}.')
|
||||
return False
|
||||
for _m, _n, _k in itertools.product(
|
||||
*map(lambda shape: self.mrange(*map(int, shape.split(':'))), shape_list)
|
||||
):
|
||||
self._commands.append(
|
||||
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
|
||||
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
|
||||
)
|
||||
for _m, _n, _k, _b, _in_type in self._shapes_to_run:
|
||||
self._commands.append(
|
||||
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
|
||||
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Module of the hipBlasLt GEMM benchmark."""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
|
||||
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark
|
||||
|
||||
|
||||
class HipBlasLtBenchmark(BlasLtBaseBenchmark):
    """The hipBlasLt GEMM benchmark class.

    Builds one ``hipblaslt-bench`` command line per entry of ``self._shapes_to_run`` and
    parses the Gflops value from each command's output.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'hipblaslt-bench'
        self._in_types = ['fp32', 'fp16', 'bf16']
        # Command line options passed to hipblaslt-bench for each supported input type.
        self._in_type_map = {
            'fp16': '--a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --compute_type f32_r',
            'fp32': '--a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --compute_type f32_r',
            'bf16': '--a_type bf16_r --b_type bf16_r --c_type bf16_r --d_type bf16_r --compute_type f32_r',
        }

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--in_types',
            type=str,
            nargs='+',
            default=['fp16'],
            required=False,
            help='List of input data types, support {}.'.format(' '.join(self._in_types)),
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)

        self._commands = []
        # Input type of each command, indexed like self._commands; used to name the result metric.
        self._precision_in_commands = []
        for (_m, _n, _k, _b, _in_type) in self._shapes_to_run:
            command = (
                f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -j {self._args.num_warmup}'
                f' -i {self._args.num_steps} {self._in_type_map[_in_type]}'
            )
            # The batch option is only appended for a positive batch size.
            if _b > 0:
                command += f' -b {_b}'
            logger.info(command)
            self._commands.append(command)
            self._precision_in_commands.append(_in_type)

        return True

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data)

        try:
            lines = raw_output.splitlines()

            # Find the header line containing 'hipblaslt-Gflops'; the values follow on the next line.
            index = next((i for i, line in enumerate(lines) if 'hipblaslt-Gflops' in line), None)
            if index is None:
                raise ValueError('Line with "hipblaslt-Gflops" not found in the log.')
            if index + 1 >= len(lines):
                raise ValueError('Result line after "hipblaslt-Gflops" header not found in the log.')

            # Split the result line into fields using a comma as the delimiter
            fields = lines[index + 1].strip().split(',')

            # Check the number of fields and that the last two fields are numeric
            if len(fields) != 23 or not all(
                re.match(r'\d*\.\d*$', item.strip()) or item.strip().isdigit() for item in fields[-2:]
            ):
                raise ValueError('Invalid result')

            # NOTE(review): fields[3] and fields[4:7] are expected to be batch_count and m, n, k
            # given the 23-column header layout - confirm against the hipblaslt-bench version in use.
            self._result.add_result(
                f'{self._precision_in_commands[cmd_idx]}_{fields[3]}_{"_".join(fields[4:7])}_flops', float(fields[-2])
            )
        except Exception as e:
            # Only regular errors are reported as parsing failures; KeyboardInterrupt and
            # SystemExit are allowed to propagate.
            self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
                    self._curr_run_index, self._name, raw_output, str(e)
                )
            )
            return False

        return True
|
||||
|
||||
|
||||
# Register HipBlasLtBenchmark under the name 'hipblaslt-gemm' for the ROCm platform.
BenchmarkRegistry.register_benchmark('hipblaslt-gemm', HipBlasLtBenchmark, platform=Platform.ROCM)
|
|
@ -9,6 +9,7 @@ from types import GeneratorType, SimpleNamespace
|
|||
from tests.helper.testcase import BenchmarkTestCase
|
||||
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
|
||||
from superbench.benchmarks.result import BenchmarkResult
|
||||
from superbench.benchmarks.micro_benchmarks.blaslt_function_base import mrange, validate_mrange
|
||||
|
||||
|
||||
class CublasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
|
||||
|
@ -37,26 +38,28 @@ class CublasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
|
|||
|
||||
def test_mrange(self):
|
||||
"""Test mrange generation."""
|
||||
benchmark = self.get_benchmark()
|
||||
self.assertIsInstance(benchmark.mrange(1), GeneratorType)
|
||||
self.assertListEqual([4, 8, 16, 32], list(benchmark.mrange(4, 32, 2)))
|
||||
self.assertListEqual([2, 4, 8, 16], list(benchmark.mrange(2, 31, 2)))
|
||||
self.assertListEqual([2, 4, 8], list(benchmark.mrange(2, 8)))
|
||||
self.assertListEqual([2], list(benchmark.mrange(2, 0, 2)))
|
||||
self.assertListEqual([2], list(benchmark.mrange(2)))
|
||||
self.assertListEqual([2], list(benchmark.mrange(2, 4, 1)))
|
||||
self.assertListEqual([2], list(benchmark.mrange(2, 4, 0)))
|
||||
self.assertListEqual([0], list(benchmark.mrange(0, 0)))
|
||||
self.assertListEqual([0], list(benchmark.mrange(0)))
|
||||
self.assertIsInstance(mrange(1), GeneratorType)
|
||||
self.assertListEqual([4, 8, 16, 32], list(mrange(4, 32, 2)))
|
||||
self.assertListEqual([2, 4, 8, 16], list(mrange(2, 31, 2)))
|
||||
self.assertListEqual([2, 4, 8], list(mrange(2, 8)))
|
||||
self.assertListEqual([2], list(mrange(2, 0, 2)))
|
||||
self.assertListEqual([2], list(mrange(2)))
|
||||
self.assertListEqual([2], list(mrange(2, 4, 1)))
|
||||
self.assertListEqual([2], list(mrange(2, 4, 0)))
|
||||
self.assertListEqual([0], list(mrange(0, 0)))
|
||||
self.assertListEqual([0], list(mrange(0)))
|
||||
self.assertListEqual([4, 8, 16, 32], list(mrange(4, 32, 2, 'x')))
|
||||
self.assertListEqual([4, 8, 12, 16, 20, 24, 28, 32], list(mrange(4, 32, 4, '+')))
|
||||
|
||||
def test_validate_mrange(self):
|
||||
"""Test mrange validation."""
|
||||
benchmark = self.get_benchmark()
|
||||
self.assertTrue(benchmark.validate_mrange('2:32:2'))
|
||||
self.assertTrue(benchmark.validate_mrange('4:32'))
|
||||
self.assertTrue(benchmark.validate_mrange('8'))
|
||||
self.assertFalse(benchmark.validate_mrange('2:32:2:4'))
|
||||
self.assertFalse(benchmark.validate_mrange('2.5:32'))
|
||||
self.assertTrue(validate_mrange('2:32:2'))
|
||||
self.assertTrue(validate_mrange('4:32'))
|
||||
self.assertTrue(validate_mrange('8'))
|
||||
self.assertFalse(validate_mrange('2:32:2:4'))
|
||||
self.assertFalse(validate_mrange('2.5:32'))
|
||||
self.assertFalse(validate_mrange('2:32:2:x4'))
|
||||
self.assertFalse(validate_mrange('2:32:2:+4'))
|
||||
|
||||
def test_cublaslt_gemm_command_generation(self):
|
||||
"""Test cublaslt-gemm benchmark command generation."""
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tests for hipblaslt-bench benchmark."""
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from tests.helper.testcase import BenchmarkTestCase
|
||||
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
|
||||
from superbench.benchmarks.result import BenchmarkResult
|
||||
|
||||
|
||||
class HipblasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
    """Unit tests for the hipblaslt-gemm micro-benchmark."""
    @classmethod
    def setUpClass(cls):
        """Create the mocked environment and a fake hipblaslt-bench binary shared by all tests."""
        super().setUpClass()
        cls.benchmark_name = 'hipblaslt-gemm'
        cls.createMockEnvs(cls)
        cls.createMockFiles(cls, ['bin/hipblaslt-bench'])

    def get_benchmark(self):
        """Return a benchmark instance selected for ROCm and built without parameters."""
        benchmark_cls, _ = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.ROCM)
        return benchmark_cls(self.benchmark_name, parameters='')

    def test_hipblaslt_gemm_cls(self):
        """The benchmark class must be selectable on ROCm only."""
        for platform in Platform:
            benchmark_cls, _ = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, platform)
            check = self.assertIsNotNone if platform is Platform.ROCM else self.assertIsNone
            check(benchmark_cls)

    def test_hipblaslt_gemm_command_generation(self):
        """Check argument validation and the generated hipblaslt-bench command lines."""
        benchmark_cls, _ = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.ROCM)

        # Parameter strings that _preprocess() is expected to reject.
        rejected_parameters = (
            '--batch 4:2:-1 --shapes 2,4,8 --in_types fp16 fp32 fp64 int8',
            ' --shapes 2,4,8 --in_types fp16 fp32 fp64 int8',
            ' --shapes 2:4,4:8 --in_types fp16 fp32',
        )
        for parameters in rejected_parameters:
            benchmark = benchmark_cls(self.benchmark_name, parameters=parameters)
            self.assertFalse(benchmark._preprocess())

        benchmark = benchmark_cls(
            self.benchmark_name,
            parameters='--shapes 2:4,4:8,8:32 2:4,4:8,8:32:+4 --in_types fp16 fp32 bf16',
        )
        self.assertTrue(benchmark._preprocess())
        commands_per_type = 2 * 2 * 3 + 2 * 2 * 7
        self.assertEqual(commands_per_type * len(benchmark._args.in_types), len(benchmark._commands))

        def cmd(t, b, m, n, k):
            """Build the command line expected for one type/batch/shape combination."""
            expected = f'{benchmark._HipBlasLtBenchmark__bin_path} ' + \
                f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]}'
            if b != 0:
                expected = expected + f' -b {b}'
            return expected

        # k values of the multiplicative shape followed by those of the additive shape.
        expected_ks = [8, 16, 32] + [8, 12, 16, 20, 24, 28, 32]
        for _t in ['fp16', 'fp32', 'bf16']:
            for _m in [2, 4]:
                for _n in [4, 8]:
                    for _k in expected_ks:
                        self.assertIn(cmd(_t, 0, _m, _n, _k), benchmark._commands)

    def test_hipblaslt_gemm_result_parsing(self):
        """Check parsing of a valid and of an invalid hipblaslt-bench output."""
        benchmark = self.get_benchmark()
        self.assertTrue(benchmark._preprocess())
        benchmark._args = SimpleNamespace(shapes=['896,896,896'], in_types=['fp16'], log_raw_data=False)
        benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)

        example_raw_output = """
hipBLASLt version: 600
hipBLASLt git version: 52776da
Query device success: there are 1 devices
-------------------------------------------------------------------------------
Device ID 0 : AMD Radeon Graphics gfx942:sramecc+:xnack-
with 206.1 GB memory, max. SCLK 2100 MHz, max. MCLK 1300 MHz, compute capability 9.4
maxGridDimX 2147483647, sharedMemPerBlock 65.5 KB, maxThreadsPerBlock 1024, warpSize 64
-------------------------------------------------------------------------------

Is supported 1 / Total solutions: 1
[0]transA,transB,grouped_gemm,batch_count,m,n,k,alpha,lda,stride_a,beta,ldb,stride_b,ldc,stride_c,ldd,stride_d,d_type,compute_type,activation_type,bias_vector,hipblaslt-Gflops,us
N,N,0,1,896,896,896,1,896,802816,0,896,802816,896,802816,896,802816,fp16_r,f32_r,none,0, 58624.5, 24.54
"""
        # Positive case - valid raw output
        parsed = benchmark._process_raw_result(0, example_raw_output)
        self.assertTrue(parsed)
        self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)

        self.assertEqual(2, len(benchmark.result))
        self.assertEqual(58624.5, benchmark.result['fp16_1_896_896_896_flops'][0])

        # Negative case - invalid raw output
        self.assertFalse(benchmark._process_raw_result(1, 'HipBLAS API failed'))
|
|
@ -11,12 +11,12 @@ HPCX_HOME ?= /opt/hpcx
|
|||
CUDA_VER ?= $(shell nvcc --version | grep 'release' | awk '{print $$6}' | cut -c2- | cut -d '.' -f1-2)
|
||||
ROCBLAS_BRANCH ?= rocm-$(shell dpkg -l | grep 'rocm-dev ' | awk '{print $$3}' | cut -d '.' -f1-3)
|
||||
|
||||
.PHONY: all cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd
|
||||
.PHONY: all cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd rocm_hipblaslt
|
||||
|
||||
# Build all targets.
|
||||
all: cuda rocm
|
||||
cuda: common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest gpcnet cuda_gpuburn
|
||||
rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest
|
||||
rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest rocm_hipblaslt
|
||||
cpu: common cpu_perftest
|
||||
common: cpu_hpl cpu_stream fio
|
||||
directx_amd: directx_amf_encoding_latency
|
||||
|
@ -103,6 +103,18 @@ ifeq (, $(wildcard $(SB_MICRO_PATH)/bin/rocblas-bench))
|
|||
cp -v ./rocBLAS/build/release/clients/staging/rocblas-bench $(SB_MICRO_PATH)/bin/
|
||||
endif
|
||||
|
||||
# Build hipblaslt-bench.
|
||||
# hipBLASLt is released with rocm, like rocm-4.2.0 and so on.
|
||||
# The version we use is the released tag which is consistent with the rocm version in the environment or docker.
|
||||
# Since it takes several hours to build, avoid building again if hipblaslt-bench already exists.
|
||||
# The whole recipe runs in a single shell, so install.sh is executed in a subshell:
# the working directory is unchanged afterwards and the relative ./hipBLASLt path
# used by cp still resolves. The steps are chained with && so that a failed clone
# or build makes the target fail instead of continuing to the next step.
rocm_hipblaslt: sb_micro_path
	@if [ ! -e $(SB_MICRO_PATH)/bin/hipblaslt-bench ] && [ -z "`which hipblaslt-bench`" ]; then \
		if [ -d hipBLASLt ]; then rm -rf hipBLASLt; fi; \
		git clone -b ${ROCBLAS_BRANCH} https://github.com/ROCmSoftwarePlatform/hipBLASLt.git ./hipBLASLt && \
		(cd ./hipBLASLt && ./install.sh -dc) && \
		cp -v ./hipBLASLt/build/release/clients/staging/hipblaslt-bench $(SB_MICRO_PATH)/bin/; \
	fi
|
||||
|
||||
# Build hipBusBandwidth.
|
||||
# HIP is released with rocm, like rocm-4.2.0 and so on.
|
||||
# The version we use is the released tag which is consistent with the rocm version in the environment or docker.
|
||||
|
|
Загрузка…
Ссылка в новой задаче