This commit is contained in:
yukirora 2024-05-23 15:24:11 +00:00
Родитель 93eaae32a2
Коммит d73bafa7a2
4 изменённых файлов: 24 добавлений и 10 удалений

Просмотреть файл

@ -4,6 +4,7 @@
"""Module of the hipBlasLt GEMM benchmark."""
import os
import sys
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
@ -132,6 +133,7 @@ class HipBlasLtBenchmark(BlasLtBaseBenchmark):
lines = raw_output.splitlines()
index = None
tflops = -1
time = sys.maxsize
metric = None
# Find the line containing 'hipblaslt-Gflops'
@ -144,10 +146,13 @@ class HipBlasLtBenchmark(BlasLtBaseBenchmark):
if len(fields) < 23:
raise ValueError('Invalid result')
metric = f'{self._precision_in_commands[cmd_idx]}_{fields[3]}_{"_".join(fields[4:7])}'
tflops = max(tflops, float(fields[21])/1000)
if float(fields[21])/1000 > tflops:
tflops = float(fields[21])/1000
time = float(fields[22])
if index is None:
raise ValueError('Line with "hipblaslt-Gflops" not found in the log.')
self._result.add_result(f'{metric}_tflops', tflops)
self._result.add_result(f'{metric}_flops', tflops)
self._result.add_result(f'{metric}_time', time)
except BaseException as e:
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)

Просмотреть файл

@ -159,21 +159,22 @@ class RocmComposableKernelBenchmark(BlasLtBaseBenchmark):
try:
lines = raw_output.splitlines()
index = None
line = None
# Find the line containing 'hipblaslt-Gflops'
for i, line in enumerate(lines):
if 'Best Perf' in line:
index = i
for i in lines:
if 'Best Perf' in i:
line = i
break
if index is not None:
if line is not None:
# Search the text for each pattern
datatype_match = re.search(r"datatype = (\w+)", line)
m_match = re.search(r"M = (\d+)", line)
n_match = re.search(r"N = (\d+)", line)
k_match = re.search(r"K = (\d+)", line)
flops_match = re.search(r"(\d+\.?\d*) TFlops", line)
time_match = re.search(r"(\d+\.?\d*) ms", line)
# Extract the matched groups
datatype = datatype_match.group(1) if datatype_match else None
@ -181,9 +182,14 @@ class RocmComposableKernelBenchmark(BlasLtBaseBenchmark):
n = int(n_match.group(1)) if n_match else None
k = int(k_match.group(1)) if k_match else None
flops = float(flops_match.group(1)) if flops_match else None
time = float(time_match.group(1)) if time_match else None
name = (line.split(',')[-2]).split()[0].strip()
metric = f'{datatype}_{m}_{n}_{k}_flops'
time_metric = f'{datatype}_{m}_{n}_{k}_time'
self._result.add_result(metric, flops)
self._result.add_result(time_metric, time)
self._result.add_result(f'{datatype}_{m}_{n}_{k}_kernel', name)
else:
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
logger.error(

Просмотреть файл

@ -68,7 +68,7 @@ class HipblasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
f' --initialization {benchmark._args.initialization}'
else:
return f'{benchmark._HipBlasLtBenchmark__bin_path} ' + \
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} -b {b}' + \
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} --batch_count {b}' + \
f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
f' --initialization {benchmark._args.initialization}'
@ -105,8 +105,9 @@ N,N,0,1,896,896,896,1,896,802816,0,896,802816,896,802816,896,802816,fp16_r,f32_r
self.assertTrue(benchmark._process_raw_result(0, example_raw_output))
self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)
self.assertEqual(2, len(benchmark.result))
self.assertEqual(3, len(benchmark.result))
self.assertEqual(58.6245, benchmark.result['fp16_1_896_896_896_flops'][0])
self.assertEqual(24.54, benchmark.result['fp16_1_896_896_896_time'][0])
# Negative case - invalid raw output
self.assertFalse(benchmark._process_raw_result(1, 'HipBLAS API failed'))

Просмотреть файл

@ -92,5 +92,7 @@ Best Perf for datatype = f16 ALayout = RowMajor BLayout = RowMajor M = 8192 N
self.assertTrue(benchmark._process_raw_result(0, example_raw_output))
self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)
self.assertEqual(2, len(benchmark.result))
self.assertEqual(4, len(benchmark.result))
self.assertEqual(506.113, benchmark.result['f16_8192_8192_8192_flops'][0])
self.assertEqual(2.17246, benchmark.result['f16_8192_8192_8192_time'][0])
self.assertEqual('GemmXdlSplitKCShuffle_Default_RRR_B256_Vec8x2x8_256x128x4x8', benchmark.result['f16_8192_8192_8192_kernel'][0])