add latency metric
This commit is contained in:
Родитель
93eaae32a2
Коммит
d73bafa7a2
|
@ -4,6 +4,7 @@
|
|||
"""Module of the hipBlasLt GEMM benchmark."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
|
||||
|
@ -132,6 +133,7 @@ class HipBlasLtBenchmark(BlasLtBaseBenchmark):
|
|||
lines = raw_output.splitlines()
|
||||
index = None
|
||||
tflops = -1
|
||||
time = sys.maxsize
|
||||
metric = None
|
||||
|
||||
# Find the line containing 'hipblaslt-Gflops'
|
||||
|
@ -144,10 +146,13 @@ class HipBlasLtBenchmark(BlasLtBaseBenchmark):
|
|||
if len(fields) < 23:
|
||||
raise ValueError('Invalid result')
|
||||
metric = f'{self._precision_in_commands[cmd_idx]}_{fields[3]}_{"_".join(fields[4:7])}'
|
||||
tflops = max(tflops, float(fields[21])/1000)
|
||||
if float(fields[21])/1000 > tflops:
|
||||
tflops = float(fields[21])/1000
|
||||
time = float(fields[22])
|
||||
if index is None:
|
||||
raise ValueError('Line with "hipblaslt-Gflops" not found in the log.')
|
||||
self._result.add_result(f'{metric}_tflops', tflops)
|
||||
self._result.add_result(f'{metric}_flops', tflops)
|
||||
self._result.add_result(f'{metric}_time', time)
|
||||
|
||||
except BaseException as e:
|
||||
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
|
||||
|
|
|
@ -159,21 +159,22 @@ class RocmComposableKernelBenchmark(BlasLtBaseBenchmark):
|
|||
|
||||
try:
|
||||
lines = raw_output.splitlines()
|
||||
index = None
|
||||
line = None
|
||||
|
||||
# Find the line containing 'Best Perf'
|
||||
for i, line in enumerate(lines):
|
||||
if 'Best Perf' in line:
|
||||
index = i
|
||||
for i in lines:
|
||||
if 'Best Perf' in i:
|
||||
line = i
|
||||
break
|
||||
|
||||
if index is not None:
|
||||
if line is not None:
|
||||
# Search the text for each pattern
|
||||
datatype_match = re.search(r"datatype = (\w+)", line)
|
||||
m_match = re.search(r"M = (\d+)", line)
|
||||
n_match = re.search(r"N = (\d+)", line)
|
||||
k_match = re.search(r"K = (\d+)", line)
|
||||
flops_match = re.search(r"(\d+\.?\d*) TFlops", line)
|
||||
time_match = re.search(r"(\d+\.?\d*) ms", line)
|
||||
|
||||
# Extract the matched groups
|
||||
datatype = datatype_match.group(1) if datatype_match else None
|
||||
|
@ -181,9 +182,14 @@ class RocmComposableKernelBenchmark(BlasLtBaseBenchmark):
|
|||
n = int(n_match.group(1)) if n_match else None
|
||||
k = int(k_match.group(1)) if k_match else None
|
||||
flops = float(flops_match.group(1)) if flops_match else None
|
||||
time = float(time_match.group(1)) if time_match else None
|
||||
name = (line.split(',')[-2]).split()[0].strip()
|
||||
|
||||
metric = f'{datatype}_{m}_{n}_{k}_flops'
|
||||
time_metric = f'{datatype}_{m}_{n}_{k}_time'
|
||||
self._result.add_result(metric, flops)
|
||||
self._result.add_result(time_metric, time)
|
||||
self._result.add_result(f'{datatype}_{m}_{n}_{k}_kernel', name)
|
||||
else:
|
||||
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
|
||||
logger.error(
|
||||
|
|
|
@ -68,7 +68,7 @@ class HipblasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
|
|||
f' --initialization {benchmark._args.initialization}'
|
||||
else:
|
||||
return f'{benchmark._HipBlasLtBenchmark__bin_path} ' + \
|
||||
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} -b {b}' + \
|
||||
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} --batch_count {b}' + \
|
||||
f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
|
||||
f' --initialization {benchmark._args.initialization}'
|
||||
|
||||
|
@ -105,8 +105,9 @@ N,N,0,1,896,896,896,1,896,802816,0,896,802816,896,802816,896,802816,fp16_r,f32_r
|
|||
self.assertTrue(benchmark._process_raw_result(0, example_raw_output))
|
||||
self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)
|
||||
|
||||
self.assertEqual(2, len(benchmark.result))
|
||||
self.assertEqual(3, len(benchmark.result))
|
||||
self.assertEqual(58.6245, benchmark.result['fp16_1_896_896_896_flops'][0])
|
||||
self.assertEqual(24.54, benchmark.result['fp16_1_896_896_896_time'][0])
|
||||
|
||||
# Negative case - invalid raw output
|
||||
self.assertFalse(benchmark._process_raw_result(1, 'HipBLAS API failed'))
|
||||
|
|
|
@ -92,5 +92,7 @@ Best Perf for datatype = f16 ALayout = RowMajor BLayout = RowMajor M = 8192 N
|
|||
self.assertTrue(benchmark._process_raw_result(0, example_raw_output))
|
||||
self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)
|
||||
|
||||
self.assertEqual(2, len(benchmark.result))
|
||||
self.assertEqual(4, len(benchmark.result))
|
||||
self.assertEqual(506.113, benchmark.result['f16_8192_8192_8192_flops'][0])
|
||||
self.assertEqual(2.17246, benchmark.result['f16_8192_8192_8192_time'][0])
|
||||
self.assertEqual('GemmXdlSplitKCShuffle_Default_RRR_B256_Vec8x2x8_256x128x4x8', benchmark.result['f16_8192_8192_8192_kernel'][0])
|
||||
|
|
Загрузка…
Ссылка в новой задаче