diff --git a/hatlib/benchmark_hat_package.py b/hatlib/benchmark_hat_package.py index d3c6eed..5604662 100644 --- a/hatlib/benchmark_hat_package.py +++ b/hatlib/benchmark_hat_package.py @@ -184,7 +184,7 @@ class Benchmark: else: print_verbose(verbose, "[Benchmarking] Benchmarking device that does not need cache flushing, skipping generation of multiple datasets") - input_sets = generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn) + input_sets = [generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)] if input_data_process_fn: input_sets = input_data_process_fn(input_sets) diff --git a/hatlib/cuda_loader.py b/hatlib/cuda_loader.py index 19d1b6c..51b32bc 100644 --- a/hatlib/cuda_loader.py +++ b/hatlib/cuda_loader.py @@ -205,7 +205,7 @@ class CudaCallableFunc(CallableFunc): cuda.cuCtxDestroy(self.context) def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]): - self.func_info.verify(args) + self.func_info.verify(args[0] if benchmark else args) self.device_mem = allocate_cuda_mem(self.func_info.arguments) if not benchmark: diff --git a/hatlib/rocm_loader.py b/hatlib/rocm_loader.py index 24684fa..ec5a290 100644 --- a/hatlib/rocm_loader.py +++ b/hatlib/rocm_loader.py @@ -123,7 +123,7 @@ class RocmCallableFunc(CallableFunc): pass def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]): - self.func_info.verify(args) + self.func_info.verify(args[0] if benchmark else args) self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id) if not benchmark: