Merge pull request #101 from microsoft/dev/ritdas/fix_args

Revert some of the changes from the previous commit regarding input_set args
This commit is contained in:
Captain Jack Sparrow 2023-05-09 09:08:05 -07:00 коммит произвёл GitHub
Родитель b5145a4b8e 93335be767
Коммит faa1e5c618
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 3 добавлений и 3 удалений

Просмотреть файл

@@ -184,7 +184,7 @@ class Benchmark:
else: else:
print_verbose(verbose, "[Benchmarking] Benchmarking device that does not need cache flushing, skipping generation of multiple datasets") print_verbose(verbose, "[Benchmarking] Benchmarking device that does not need cache flushing, skipping generation of multiple datasets")
input_sets = generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn) input_sets = [generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)]
if input_data_process_fn: if input_data_process_fn:
input_sets = input_data_process_fn(input_sets) input_sets = input_data_process_fn(input_sets)

Просмотреть файл

@@ -205,7 +205,7 @@ class CudaCallableFunc(CallableFunc):
cuda.cuCtxDestroy(self.context) cuda.cuCtxDestroy(self.context)
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]): def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
self.func_info.verify(args) self.func_info.verify(args[0] if benchmark else args)
self.device_mem = allocate_cuda_mem(self.func_info.arguments) self.device_mem = allocate_cuda_mem(self.func_info.arguments)
if not benchmark: if not benchmark:

Просмотреть файл

@@ -123,7 +123,7 @@ class RocmCallableFunc(CallableFunc):
pass pass
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]): def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
self.func_info.verify(args) self.func_info.verify(args[0] if benchmark else args)
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id) self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
if not benchmark: if not benchmark: