Merge pull request #100 from microsoft/dev/ritdas/fix_strides

Fix strides of randomly generated data
This commit is contained in:
Captain Jack Sparrow 2023-05-08 08:15:27 -07:00 коммит произвёл GitHub
Родитель 00e7471dc2 d12a4cb7b3
Коммит b5145a4b8e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 6 добавлений и 6 удалений

Просмотреть файл

@ -145,7 +145,7 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo])
return indices
def _gen_random_data(dtype, shape):
def _gen_random_data(dtype, shape, strides=None):
dtype = np.uint16 if dtype == "bfloat16" else dtype
if isinstance(dtype, np.dtype):
dtype = dtype.type
@ -157,7 +157,7 @@ def _gen_random_data(dtype, shape):
else:
data = np.random.random(tuple(shape)).astype(dtype)
return data
return np.lib.stride_tricks.as_strided(data, strides=strides) if strides is not None else data
def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]:
@ -214,7 +214,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> Lis
arg.numpy_strides = list(map(lambda x: x * arg.element_num_bytes, arg.shape[1:] + [1]))
if arg.usage != hat_file.UsageType.Output:
arg_data = _gen_random_data(arg.numpy_dtype, arg.shape)
arg_data = _gen_random_data(arg.numpy_dtype, arg.shape, arg.numpy_strides)
values.append(ArgValue(arg, arg_data))
else:
values.append(ArgValue(arg))

Просмотреть файл

@ -184,7 +184,7 @@ class Benchmark:
else:
print_verbose(verbose, "[Benchmarking] Benchmarking device that does not need cache flushing, skipping generation of multiple datasets")
input_sets = [generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)]
input_sets = generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)
if input_data_process_fn:
input_sets = input_data_process_fn(input_sets)

Просмотреть файл

@ -205,7 +205,7 @@ class CudaCallableFunc(CallableFunc):
cuda.cuCtxDestroy(self.context)
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
self.func_info.verify(args[0] if benchmark else args)
self.func_info.verify(args)
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
if not benchmark:

Просмотреть файл

@ -123,7 +123,7 @@ class RocmCallableFunc(CallableFunc):
pass
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
self.func_info.verify(args[0] if benchmark else args)
self.func_info.verify(args)
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
if not benchmark: