Benchmarks - Add TensorRT inference benchmark (#236)
__Description__ Add TensorRT inference benchmark for torchvision models.

__Major Revision__ - Measure TensorRT inference performance.
Parent: b913e1f668
Commit: 8a00c8a03b
@@ -62,7 +62,6 @@ cover/
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
@@ -66,6 +66,23 @@ TODO

TODO

### `tensorrt-inference`

#### Introduction

Inference PyTorch/ONNX models on NVIDIA GPUs with [TensorRT](https://developer.nvidia.com/tensorrt).
#### Metrics

| Name | Unit | Description |
|------|------|-------------|
| tensorrt-inference/gpu_lat_ms_mean | time (ms) | The mean GPU latency to execute the kernels for a query. |
| tensorrt-inference/gpu_lat_ms_99 | time (ms) | The 99th percentile GPU latency to execute the kernels for a query. |
| tensorrt-inference/host_lat_ms_mean | time (ms) | The mean H2D, GPU, and D2H latency to execute the kernels for a query. |
| tensorrt-inference/host_lat_ms_99 | time (ms) | The 99th percentile H2D, GPU, and D2H latency to execute the kernels for a query. |
| tensorrt-inference/end_to_end_lat_ms_mean | time (ms) | The mean duration from when the H2D of a query is called to when the D2H of the same query is completed. |
| tensorrt-inference/end_to_end_lat_ms_99 | time (ms) | The 99th percentile duration from when the H2D of a query is called to when the D2H of the same query is completed. |
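
For example, a minimal sketch of launching this benchmark from Python with non-default parameters (the flag names come from the benchmark's argument parser in this commit; the model list and values here are only illustrative):

```python
from superbench.benchmarks import BenchmarkRegistry, Platform

context = BenchmarkRegistry.create_benchmark_context(
    'tensorrt-inference',
    platform=Platform.CUDA,
    parameters='--pytorch_models resnet50 vgg16 --precision fp16 --batch_size 32 --iterations 256',
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
```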

## Communication Benchmarks

### `mem-bw`
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Micro benchmark example for TensorRT inference performance.

Commands to run:
  python3 examples/benchmarks/tensorrt_inference_performance.py
"""

from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

if __name__ == '__main__':
    context = BenchmarkRegistry.create_benchmark_context('tensorrt-inference', platform=Platform.CUDA)
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    if benchmark:
        logger.info(
            'benchmark: {}, return code: {}, result: {}'.format(
                benchmark.name, benchmark.return_code, benchmark.result
            )
        )
@@ -4,28 +4,47 @@
 """A module containing all the micro-benchmarks."""

 from superbench.benchmarks.micro_benchmarks.micro_base import MicroBenchmark, MicroBenchmarkWithInvoke
-from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMatmul
-from superbench.benchmarks.micro_benchmarks.computation_communication_overlap import ComputationCommunicationOverlap
-from superbench.benchmarks.micro_benchmarks.kernel_launch_overhead import KernelLaunch
-from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
-from superbench.benchmarks.micro_benchmarks.cudnn_function import CudnnBenchmark
 from superbench.benchmarks.micro_benchmarks.gemm_flops_performance_base import GemmFlopsBenchmark
-from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark
 from superbench.benchmarks.micro_benchmarks.memory_bw_performance_base import MemBwBenchmark
+
+from superbench.benchmarks.micro_benchmarks.computation_communication_overlap import ComputationCommunicationOverlap
+from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
+from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark
 from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark
-from superbench.benchmarks.micro_benchmarks.disk_performance import DiskBenchmark
-from superbench.benchmarks.micro_benchmarks.ib_loopback_performance import IBLoopbackBenchmark
 from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
-from superbench.benchmarks.micro_benchmarks.rocm_memory_bw_performance import RocmMemBwBenchmark
-from superbench.benchmarks.micro_benchmarks.rocm_gemm_flops_performance import RocmGemmFlopsBenchmark
-from superbench.benchmarks.micro_benchmarks.ib_validation_performance import IBBenchmark
-from superbench.benchmarks.micro_benchmarks.gpu_copy_bw_performance import GpuCopyBwBenchmark
-from superbench.benchmarks.micro_benchmarks.tcp_connectivity import TCPConnectivityBenchmark
+from superbench.benchmarks.micro_benchmarks.cudnn_function import CudnnBenchmark
+from superbench.benchmarks.micro_benchmarks.disk_performance import DiskBenchmark
 from superbench.benchmarks.micro_benchmarks.gpcnet_performance import GPCNetBenchmark
+from superbench.benchmarks.micro_benchmarks.gpu_copy_bw_performance import GpuCopyBwBenchmark
+from superbench.benchmarks.micro_benchmarks.ib_loopback_performance import IBLoopbackBenchmark
+from superbench.benchmarks.micro_benchmarks.ib_validation_performance import IBBenchmark
+from superbench.benchmarks.micro_benchmarks.kernel_launch_overhead import KernelLaunch
+from superbench.benchmarks.micro_benchmarks.rocm_gemm_flops_performance import RocmGemmFlopsBenchmark
+from superbench.benchmarks.micro_benchmarks.rocm_memory_bw_performance import RocmMemBwBenchmark
+from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMatmul
+from superbench.benchmarks.micro_benchmarks.tcp_connectivity import TCPConnectivityBenchmark
+from superbench.benchmarks.micro_benchmarks.tensorrt_inference_performance import TensorRTInferenceBenchmark

 __all__ = [
-    'MicroBenchmark', 'MicroBenchmarkWithInvoke', 'ShardingMatmul', 'ComputationCommunicationOverlap', 'KernelLaunch',
-    'CublasBenchmark', 'CudnnBenchmark', 'GemmFlopsBenchmark', 'CudaGemmFlopsBenchmark', 'MemBwBenchmark',
-    'CudaMemBwBenchmark', 'DiskBenchmark', 'IBLoopbackBenchmark', 'CudaNcclBwBenchmark', 'RocmMemBwBenchmark',
-    'RocmGemmFlopsBenchmark', 'IBBenchmark', 'GpuCopyBwBenchmark', 'TCPConnectivityBenchmark', 'GPCNetBenchmark'
+    'ComputationCommunicationOverlap',
+    'CublasBenchmark',
+    'CudaGemmFlopsBenchmark',
+    'CudaMemBwBenchmark',
+    'CudaNcclBwBenchmark',
+    'CudnnBenchmark',
+    'DiskBenchmark',
+    'GPCNetBenchmark',
+    'GemmFlopsBenchmark',
+    'GpuCopyBwBenchmark',
+    'IBBenchmark',
+    'IBLoopbackBenchmark',
+    'KernelLaunch',
+    'MemBwBenchmark',
+    'MicroBenchmark',
+    'MicroBenchmarkWithInvoke',
+    'RocmGemmFlopsBenchmark',
+    'RocmMemBwBenchmark',
+    'ShardingMatmul',
+    'TCPConnectivityBenchmark',
+    'TensorRTInferenceBenchmark',
 ]
@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""TensorRT inference micro-benchmark."""

import re
from pathlib import Path

import torch.hub
import torch.onnx
import torchvision.models

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke


class TensorRTInferenceBenchmark(MicroBenchmarkWithInvoke):
    """TensorRT inference micro-benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'trtexec'
        self._pytorch_models = [
            'resnet50',
            'resnet101',
            'resnet152',
            'densenet169',
            'densenet201',
            'vgg11',
            'vgg13',
            'vgg16',
            'vgg19',
        ]
        self.__model_cache_path = Path(torch.hub.get_dir()) / 'checkpoints'

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--pytorch_models',
            type=str,
            nargs='+',
            default=self._pytorch_models,
            help='ONNX models for TensorRT inference benchmark, e.g., {}.'.format(', '.join(self._pytorch_models)),
        )

        self._parser.add_argument(
            '--precision',
            type=str,
            choices=['int8', 'fp16', 'fp32'],
            default='int8',
            required=False,
            help='Precision for inference, allow int8, fp16, or fp32 only.',
        )

        self._parser.add_argument(
            '--batch_size',
            type=int,
            default=32,
            required=False,
            help='Set batch size for implicit batch engines.',
        )

        self._parser.add_argument(
            '--iterations',
            type=int,
            default=256,
            required=False,
            help='Run at least N inference iterations.',
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeeds.
        """
        if not super()._preprocess():
            return False

        self.__bin_path = str(Path(self._args.bin_dir) / self._bin_name)

        for model in self._args.pytorch_models:
            if hasattr(torchvision.models, model):
                # Export the pretrained torchvision model to ONNX in the torch hub cache directory.
                torch.onnx.export(
                    getattr(torchvision.models, model)(pretrained=True).cuda(),
                    torch.randn(self._args.batch_size, 3, 224, 224, device='cuda'),
                    f'{self.__model_cache_path / (model + ".onnx")}',
                )
                # Assemble one trtexec command per model; the precision flag is
                # omitted for fp32 since that is trtexec's default.
                self._commands.append(
                    ' '.join(
                        filter(
                            None, [
                                self.__bin_path,
                                None if self._args.precision == 'fp32' else f'--{self._args.precision}',
                                f'--batch={self._args.batch_size}',
                                f'--iterations={self._args.iterations}',
                                '--workspace=1024',
                                '--percentile=99',
                                f'--onnx={self.__model_cache_path / (model + ".onnx")}',
                            ]
                        )
                    )
                )
            else:
                logger.error('Cannot find PyTorch model %s.', model)
                return False
        return True
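
    # For the default arguments, each generated command looks roughly like the
    # following (illustrative cache path; the actual path comes from torch.hub.get_dir()):
    #   <bin_dir>/trtexec --int8 --batch=32 --iterations=256 --workspace=1024 \
    #       --percentile=99 --onnx=~/.cache/torch/hub/checkpoints/resnet50.onnx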

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding to the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        self._result.add_raw_data(f'raw_output_{self._args.pytorch_models[cmd_idx]}', raw_output)

        success = False
        try:
            for line in raw_output.strip().splitlines():
                line = line.strip()
                if '[I] mean:' in line or '[I] percentile:' in line:
                    tag = 'mean' if '[I] mean:' in line else '99'
                    lats = re.findall(r'(\d+\.\d+) ms', line)
                    # trtexec's "GPU Compute" section reports a single latency per line,
                    # while its "Host Latency" section reports both the host latency and
                    # the end-to-end latency on the same line.
                    if len(lats) == 1:
                        self._result.add_result(f'gpu_lat_ms_{tag}', float(lats[0]))
                    elif len(lats) == 2:
                        self._result.add_result(f'host_lat_ms_{tag}', float(lats[0]))
                        self._result.add_result(f'end_to_end_lat_ms_{tag}', float(lats[1]))
                    success = True
        except BaseException as e:
            self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
                    self._curr_run_index, self._name, raw_output, str(e)
                )
            )
            return False
        return success


BenchmarkRegistry.register_benchmark(
    'tensorrt-inference',
    TensorRTInferenceBenchmark,
    platform=Platform.CUDA,
)
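
As a quick sketch of the parsing heuristic in `_process_raw_result` (line formats taken from the sample trtexec log in this commit; values are illustrative):

    import re

    gpu_line = '[11/02/2021-09:17:09] [I] mean: 0.5 ms'
    host_line = '[11/02/2021-09:17:09] [I] mean: 0.6 ms (end to end 1.0 ms)'

    # One latency on a line -> GPU compute; two -> host and end-to-end.
    assert re.findall(r'(\d+\.\d+) ms', gpu_line) == ['0.5']
    assert re.findall(r'(\d+\.\d+) ms', host_line) == ['0.6', '1.0']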
@@ -0,0 +1,143 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for tensorrt-inference benchmark."""

import os
import shutil
import tempfile
import unittest
from pathlib import Path
from types import SimpleNamespace

from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
from superbench.benchmarks.result import BenchmarkResult


class TensorRTInferenceBenchmarkTestCase(unittest.TestCase):
    """Class for tensorrt-inference benchmark test cases."""
    def setUp(self):
        """Hook method for setting up the test fixture before exercising it."""
        self.benchmark_name = 'tensorrt-inference'
        self.__tmp_dir = tempfile.mkdtemp()
        self.__curr_micro_path = os.environ.get('SB_MICRO_PATH', '')
        os.environ['SB_MICRO_PATH'] = self.__tmp_dir
        os.environ['TORCH_HOME'] = self.__tmp_dir
        (Path(self.__tmp_dir) / 'bin').mkdir(parents=True, exist_ok=True)
        (Path(self.__tmp_dir) / 'bin' / 'trtexec').touch(mode=0o755, exist_ok=True)

    def tearDown(self):
        """Hook method for deconstructing the test fixture after testing it."""
        shutil.rmtree(self.__tmp_dir)
        os.environ['SB_MICRO_PATH'] = self.__curr_micro_path
        del os.environ['TORCH_HOME']

    def test_tensorrt_inference_cls(self):
        """Test tensorrt-inference benchmark class."""
        for platform in Platform:
            (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, platform)
            if platform is Platform.CUDA:
                self.assertIsNotNone(benchmark_cls)
            else:
                self.assertIsNone(benchmark_cls)

    @decorator.cuda_test
    @decorator.pytorch_test
    def test_tensorrt_inference_params(self):
        """Test tensorrt-inference benchmark preprocess with different parameters."""
        (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)

        test_cases = [
            {
                'precision': 'fp32',
            },
            {
                'pytorch_models': ['resnet50', 'mnasnet0_5'],
                'precision': 'fp16',
            },
            {
                'pytorch_models': ['resnet50'],
                'batch_size': 4,
            },
            {
                'batch_size': 4,
                'iterations': 128,
            },
        ]
        for test_case in test_cases:
            with self.subTest(msg='Testing with case', test_case=test_case):
                parameter_list = []
                if 'pytorch_models' in test_case:
                    parameter_list.append(f'--pytorch_models {" ".join(test_case["pytorch_models"])}')
                if 'precision' in test_case:
                    parameter_list.append(f'--precision {test_case["precision"]}')
                if 'batch_size' in test_case:
                    parameter_list.append(f'--batch_size {test_case["batch_size"]}')
                if 'iterations' in test_case:
                    parameter_list.append(f'--iterations {test_case["iterations"]}')

                # Check basic information
                benchmark = benchmark_cls(self.benchmark_name, parameters=' '.join(parameter_list))
                self.assertTrue(benchmark)

                # Limit model number
                benchmark._pytorch_models = benchmark._pytorch_models[:1]
                benchmark._TensorRTInferenceBenchmark__model_cache_path = Path(self.__tmp_dir) / 'hub/checkpoints'

                # Preprocess
                ret = benchmark._preprocess()
                self.assertTrue(ret)
                self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)
                self.assertEqual(BenchmarkType.MICRO, benchmark.type)
                self.assertEqual(self.benchmark_name, benchmark.name)

                # Check parameters
                self.assertEqual(
                    test_case.get('pytorch_models', benchmark._pytorch_models),
                    benchmark._args.pytorch_models,
                )
                self.assertEqual(
                    test_case.get('precision', 'int8'),
                    benchmark._args.precision,
                )
                self.assertEqual(
                    test_case.get('batch_size', 32),
                    benchmark._args.batch_size,
                )
                self.assertEqual(
                    test_case.get('iterations', 256),
                    benchmark._args.iterations,
                )

                # Check models
                for model in benchmark._args.pytorch_models:
                    self.assertTrue(
                        (benchmark._TensorRTInferenceBenchmark__model_cache_path / f'{model}.onnx').is_file()
                    )

                # One command should be generated per requested (or default) model
                self.assertEqual(
                    len(test_case.get('pytorch_models', benchmark._pytorch_models)), len(benchmark._commands)
                )

    @decorator.load_data('tests/data/tensorrt_inference.log')
    def test_tensorrt_inference_result_parsing(self, test_raw_log):
        """Test tensorrt-inference benchmark result parsing."""
        (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
        benchmark = benchmark_cls(self.benchmark_name, parameters='')
        benchmark._args = SimpleNamespace(pytorch_models=['model_0', 'model_1'])
        benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)

        # Positive case - valid raw output
        self.assertTrue(benchmark._process_raw_result(0, test_raw_log))
        self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)

        self.assertEqual(6, len(benchmark.result))
        for tag in ['mean', '99']:
            self.assertEqual(0.5, benchmark.result[f'gpu_lat_ms_{tag}'][0])
            self.assertEqual(0.6, benchmark.result[f'host_lat_ms_{tag}'][0])
            self.assertEqual(1.0, benchmark.result[f'end_to_end_lat_ms_{tag}'][0])

        # Negative case - invalid raw output
        self.assertFalse(benchmark._process_raw_result(1, 'Invalid raw output'))
@@ -0,0 +1,104 @@
[11/02/2021-09:15:15] [I] === Model Options ===
[11/02/2021-09:15:15] [I] Format: ONNX
[11/02/2021-09:15:15] [I] Model: resnet/model/resnet50-v1.onnx
[11/02/2021-09:15:15] [I] Output:
[11/02/2021-09:15:15] [I] === Build Options ===
[11/02/2021-09:15:15] [I] Max batch: explicit
[11/02/2021-09:15:15] [I] Workspace: 1024 MiB
[11/02/2021-09:15:15] [I] minTiming: 1
[11/02/2021-09:15:15] [I] avgTiming: 8
[11/02/2021-09:15:15] [I] Precision: FP32+INT8
[11/02/2021-09:15:15] [I] Calibration: Dynamic
[11/02/2021-09:15:15] [I] Refit: Disabled
[11/02/2021-09:15:15] [I] Safe mode: Disabled
[11/02/2021-09:15:15] [I] Save engine:
[11/02/2021-09:15:15] [I] Load engine:
[11/02/2021-09:15:15] [I] Builder Cache: Enabled
[11/02/2021-09:15:15] [I] NVTX verbosity: 0
[11/02/2021-09:15:15] [I] Tactic sources: Using default tactic sources
[11/02/2021-09:15:15] [I] Input(s)s format: fp32:CHW
[11/02/2021-09:15:15] [I] Output(s)s format: fp32:CHW
[11/02/2021-09:15:15] [I] Input build shapes: model
[11/02/2021-09:15:15] [I] Input calibration shapes: model
[11/02/2021-09:15:15] [I] === System Options ===
[11/02/2021-09:15:15] [I] Device: 0
[11/02/2021-09:15:15] [I] DLACore:
[11/02/2021-09:15:15] [I] Plugins:
[11/02/2021-09:15:15] [I] === Inference Options ===
[11/02/2021-09:15:15] [I] Batch: Explicit
[11/02/2021-09:15:15] [I] Input inference shapes: model
[11/02/2021-09:15:15] [I] Iterations: 1024
[11/02/2021-09:15:15] [I] Duration: 3s (+ 200ms warm up)
[11/02/2021-09:15:15] [I] Sleep time: 0ms
[11/02/2021-09:15:15] [I] Streams: 1
[11/02/2021-09:15:15] [I] ExposeDMA: Disabled
[11/02/2021-09:15:15] [I] Data transfers: Enabled
[11/02/2021-09:15:15] [I] Spin-wait: Disabled
[11/02/2021-09:15:15] [I] Multithreading: Disabled
[11/02/2021-09:15:15] [I] CUDA Graph: Disabled
[11/02/2021-09:15:15] [I] Separate profiling: Disabled
[11/02/2021-09:15:15] [I] Skip inference: Disabled
[11/02/2021-09:15:15] [I] Inputs:
[11/02/2021-09:15:15] [I] === Reporting Options ===
[11/02/2021-09:15:15] [I] Verbose: Disabled
[11/02/2021-09:15:15] [I] Averages: 10 inferences
[11/02/2021-09:15:15] [I] Percentile: 99
[11/02/2021-09:15:15] [I] Dump refittable layers:Disabled
[11/02/2021-09:15:15] [I] Dump output: Disabled
[11/02/2021-09:15:15] [I] Profile: Disabled
[11/02/2021-09:15:15] [I] Export timing to JSON file:
[11/02/2021-09:15:15] [I] Export output to JSON file:
[11/02/2021-09:15:15] [I] Export profile to JSON file:
[11/02/2021-09:15:15] [I]
[11/02/2021-09:15:16] [I] === Device Information ===
[11/02/2021-09:15:16] [I] Selected Device: A100-SXM4-40GB
[11/02/2021-09:15:16] [I] Compute Capability: 8.0
[11/02/2021-09:15:16] [I] SMs: 108
[11/02/2021-09:15:16] [I] Compute Clock Rate: 1.41 GHz
[11/02/2021-09:15:16] [I] Device Global Memory: 40536 MiB
[11/02/2021-09:15:16] [I] Shared Memory per SM: 164 KiB
[11/02/2021-09:15:16] [I] Memory Bus Width: 5120 bits (ECC enabled)
[11/02/2021-09:15:16] [I] Memory Clock Rate: 1.215 GHz
[11/02/2021-09:15:16] [I]
----------------------------------------------------------------
Input filename: resnet/model/resnet50-v1.onnx
ONNX IR version: 0.0.3
Opset version: 8
Producer name:
Producer version:
Domain:
Model version: 0
Doc string:
----------------------------------------------------------------
[11/02/2021-09:15:26] [I] FP32 and INT8 precisions have been specified - more performance might be enabled by additionally specifying --fp16 or --best
[11/02/2021-09:15:26] [W] [TRT] Calibrator is not being used. Users must provide dynamic range for all tensors that are not Int32.
[11/02/2021-09:16:39] [I] [TRT] Some tactics do not have sufficient workspace memory to run. Increasing workspace size may increase performance, please check verbose output.
[11/02/2021-09:17:06] [I] [TRT] Detected 1 inputs and 1 output network tensors.
[11/02/2021-09:17:06] [I] Engine built in 109.833 sec.
[11/02/2021-09:17:06] [I] Starting inference
[11/02/2021-09:17:09] [I] Warmup completed 0 queries over 200 ms
[11/02/2021-09:17:09] [I] Timing trace has 0 queries over 3.00142 s
[11/02/2021-09:17:09] [I] Trace averages of 10 runs:
[11/02/2021-09:17:09] [I] Average on 10 runs - GPU latency: 0.5 ms - Host latency: 0.6 ms (end to end 1.0 ms, enqueue 0.2 ms)
[11/02/2021-09:17:09] [I] Average on 10 runs - GPU latency: 0.5 ms - Host latency: 0.6 ms (end to end 1.0 ms, enqueue 0.2 ms)
[11/02/2021-09:17:09] [I] Average on 10 runs - GPU latency: 0.5 ms - Host latency: 0.6 ms (end to end 1.0 ms, enqueue 0.2 ms)
[11/02/2021-09:17:09] [I] Host Latency
[11/02/2021-09:17:09] [I] min: 0.6 ms (end to end 1.0 ms)
[11/02/2021-09:17:09] [I] max: 0.6 ms (end to end 1.0 ms)
[11/02/2021-09:17:09] [I] mean: 0.6 ms (end to end 1.0 ms)
[11/02/2021-09:17:09] [I] median: 0.6 ms (end to end 1.0 ms)
[11/02/2021-09:17:09] [I] percentile: 0.6 ms at 99% (end to end 1.0 ms at 99%)
[11/02/2021-09:17:09] [I] throughput: 0 qps
[11/02/2021-09:17:09] [I] walltime: 3.00142 s
[11/02/2021-09:17:09] [I] Enqueue Time
[11/02/2021-09:17:09] [I] min: 0.2 ms
[11/02/2021-09:17:09] [I] max: 0.2 ms
[11/02/2021-09:17:09] [I] median: 0.2 ms
[11/02/2021-09:17:09] [I] GPU Compute
[11/02/2021-09:17:09] [I] min: 0.5 ms
[11/02/2021-09:17:09] [I] max: 0.5 ms
[11/02/2021-09:17:09] [I] mean: 0.5 ms
[11/02/2021-09:17:09] [I] median: 0.5 ms
[11/02/2021-09:17:09] [I] percentile: 0.5 ms at 99%
[11/02/2021-09:17:09] [I] total compute time: 2.96622 s
&&&& PASSED TensorRT.trtexec # trtexec --batch=32 --iterations=1024 --workspace=1024 --percentile=99 --onnx=resnet/model/resnet50-v1.onnx --int8
@@ -5,8 +5,32 @@

import os
import unittest
import functools
from pathlib import Path

cuda_test = unittest.skipIf(os.environ.get('SB_TEST_CUDA', '1') == '0', 'Skip CUDA tests.')
rocm_test = unittest.skipIf(os.environ.get('SB_TEST_ROCM', '0') == '0', 'Skip ROCm tests.')

pytorch_test = unittest.skipIf(os.environ.get('SB_TEST_PYTORCH', '1') == '0', 'Skip PyTorch tests.')


def load_data(filepath):
    """Decorator to load data file.

    Args:
        filepath (str): Data file path, e.g., tests/data/output.log.

    Returns:
        func: decorated function, data variable is assigned to last argument.
    """
    with Path(filepath).open() as fp:
        data = fp.read()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, data, **kwargs)

        return wrapper

    return decorator
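
# Example usage, mirroring the tensorrt-inference test above (the argument name
# is illustrative): the file content is passed as the last positional argument
# of the decorated test method.
#
#   @load_data('tests/data/tensorrt_inference.log')
#   def test_result_parsing(self, raw_log):
#       ...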