зеркало из https://github.com/microsoft/hat.git
Merge pull request #101 from microsoft/dev/ritdas/fix_args
Revert some of the changes from the previous commit regarding input_set args
This commit is contained in:
Коммит
faa1e5c618
|
@ -184,7 +184,7 @@ class Benchmark:
|
|||
else:
|
||||
print_verbose(verbose, "[Benchmarking] Benchmarking device that does not need cache flushing, skipping generation of multiple datasets")
|
||||
|
||||
input_sets = generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)
|
||||
input_sets = [generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)]
|
||||
|
||||
if input_data_process_fn:
|
||||
input_sets = input_data_process_fn(input_sets)
|
||||
|
|
|
@ -205,7 +205,7 @@ class CudaCallableFunc(CallableFunc):
|
|||
cuda.cuCtxDestroy(self.context)
|
||||
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args)
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
|
||||
|
||||
if not benchmark:
|
||||
|
|
|
@ -123,7 +123,7 @@ class RocmCallableFunc(CallableFunc):
|
|||
pass
|
||||
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args)
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
|
||||
|
||||
if not benchmark:
|
||||
|
|
Загрузка…
Ссылка в новой задаче