зеркало из https://github.com/microsoft/hat.git
Merge pull request #100 from microsoft/dev/ritdas/fix_strides
Fix strides of randomly generated data
This commit is contained in:
Коммит
b5145a4b8e
|
@ -145,7 +145,7 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo])
|
|||
return indices
|
||||
|
||||
|
||||
def _gen_random_data(dtype, shape):
|
||||
def _gen_random_data(dtype, shape, strides=None):
|
||||
dtype = np.uint16 if dtype == "bfloat16" else dtype
|
||||
if isinstance(dtype, np.dtype):
|
||||
dtype = dtype.type
|
||||
|
@ -157,7 +157,7 @@ def _gen_random_data(dtype, shape):
|
|||
else:
|
||||
data = np.random.random(tuple(shape)).astype(dtype)
|
||||
|
||||
return data
|
||||
return np.lib.stride_tricks.as_strided(data, strides=strides) if strides is not None else data
|
||||
|
||||
|
||||
def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]:
|
||||
|
@ -214,7 +214,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> Lis
|
|||
arg.numpy_strides = list(map(lambda x: x * arg.element_num_bytes, arg.shape[1:] + [1]))
|
||||
|
||||
if arg.usage != hat_file.UsageType.Output:
|
||||
arg_data = _gen_random_data(arg.numpy_dtype, arg.shape)
|
||||
arg_data = _gen_random_data(arg.numpy_dtype, arg.shape, arg.numpy_strides)
|
||||
values.append(ArgValue(arg, arg_data))
|
||||
else:
|
||||
values.append(ArgValue(arg))
|
||||
|
|
|
@ -184,7 +184,7 @@ class Benchmark:
|
|||
else:
|
||||
print_verbose(verbose, "[Benchmarking] Benchmarking device that does not need cache flushing, skipping generation of multiple datasets")
|
||||
|
||||
input_sets = [generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)]
|
||||
input_sets = generate_arg_sets_for_func(func, dyn_func_shape_fn=dyn_func_shape_fn)
|
||||
|
||||
if input_data_process_fn:
|
||||
input_sets = input_data_process_fn(input_sets)
|
||||
|
|
|
@ -205,7 +205,7 @@ class CudaCallableFunc(CallableFunc):
|
|||
cuda.cuCtxDestroy(self.context)
|
||||
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.func_info.verify(args)
|
||||
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
|
||||
|
||||
if not benchmark:
|
||||
|
|
|
@ -123,7 +123,7 @@ class RocmCallableFunc(CallableFunc):
|
|||
pass
|
||||
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.func_info.verify(args)
|
||||
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
|
||||
|
||||
if not benchmark:
|
||||
|
|
Загрузка…
Ссылка в новой задаче