зеркало из https://github.com/microsoft/hat.git
Revert some of the changes from the previous commit regarding input_set args
This commit is contained in:
Родитель
b5145a4b8e
Коммит
93335be767
|
@ -184,7 +184,7 @@ class Benchmark:
|
|||
else:
|
||||
print_verbose(verbose, "[Benchmarking] Benchmarking device that does not need cache flushing, skipping generation of multiple datasets")
|
||||
|
||||
input_sets = generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)
|
||||
input_sets = [generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)]
|
||||
|
||||
if input_data_process_fn:
|
||||
input_sets = input_data_process_fn(input_sets)
|
||||
|
|
|
@ -205,7 +205,7 @@ class CudaCallableFunc(CallableFunc):
|
|||
cuda.cuCtxDestroy(self.context)
|
||||
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args)
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
|
||||
|
||||
if not benchmark:
|
||||
|
|
|
@ -123,7 +123,7 @@ class RocmCallableFunc(CallableFunc):
|
|||
pass
|
||||
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args)
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
|
||||
|
||||
if not benchmark:
|
||||
|
|
Загрузка…
Ссылка в новой задаче