Fix strides of randomly generated data

This commit is contained in:
Ritwik Das 2023-05-08 06:19:26 -07:00
Родитель 00e7471dc2
Коммит 2e2ad05b6e
1 изменённых файлов: 4 добавлений и 4 удалений

Просмотреть файл

@ -145,7 +145,7 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo])
return indices
def _gen_random_data(dtype, shape):
def _gen_random_data(dtype, shape, strides):
dtype = np.uint16 if dtype == "bfloat16" else dtype
if isinstance(dtype, np.dtype):
dtype = dtype.type
@ -157,7 +157,7 @@ def _gen_random_data(dtype, shape):
else:
data = np.random.random(tuple(shape)).astype(dtype)
return data
return np.lib.stride_tricks.as_strided(data, strides=strides)
def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]:
@ -198,7 +198,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> Lis
shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0])
# materialize an array input using the generated shape
runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape)
runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape, arg.numpy_strides)
values.append(ArgValue(arg, runtime_array_inputs))
elif arg.name in dim_names_to_values and arg.usage == hat_file.UsageType.Input:
@ -214,7 +214,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> Lis
arg.numpy_strides = list(map(lambda x: x * arg.element_num_bytes, arg.shape[1:] + [1]))
if arg.usage != hat_file.UsageType.Output:
arg_data = _gen_random_data(arg.numpy_dtype, arg.shape)
arg_data = _gen_random_data(arg.numpy_dtype, arg.shape, arg.numpy_strides)
values.append(ArgValue(arg, arg_data))
else:
values.append(ArgValue(arg))