diff --git a/hatlib/arg_value.py b/hatlib/arg_value.py index 399f96d..20ab370 100644 --- a/hatlib/arg_value.py +++ b/hatlib/arg_value.py @@ -145,7 +145,7 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo]) return indices -def _gen_random_data(dtype, shape): +def _gen_random_data(dtype, shape, strides): dtype = np.uint16 if dtype == "bfloat16" else dtype if isinstance(dtype, np.dtype): dtype = dtype.type @@ -157,7 +157,7 @@ def _gen_random_data(dtype, shape): else: data = np.random.random(tuple(shape)).astype(dtype) - return data + return np.lib.stride_tricks.as_strided(data, strides=strides) def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]: @@ -198,7 +198,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> Lis shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0]) # materialize an array input using the generated shape - runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape) + runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape, arg.numpy_strides) values.append(ArgValue(arg, runtime_array_inputs)) elif arg.name in dim_names_to_values and arg.usage == hat_file.UsageType.Input: @@ -214,7 +214,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> Lis arg.numpy_strides = list(map(lambda x: x * arg.element_num_bytes, arg.shape[1:] + [1])) if arg.usage != hat_file.UsageType.Output: - arg_data = _gen_random_data(arg.numpy_dtype, arg.shape) + arg_data = _gen_random_data(arg.numpy_dtype, arg.shape, arg.numpy_strides) values.append(ArgValue(arg, arg_data)) else: values.append(ArgValue(arg))