зеркало из https://github.com/microsoft/hat.git
Another fix
This commit is contained in:
Родитель
2e2ad05b6e
Коммит
b02b922559
|
@ -145,7 +145,7 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo])
|
||||||
return indices
|
return indices
|
||||||
|
|
||||||
|
|
||||||
def _gen_random_data(dtype, shape, strides):
|
def _gen_random_data(dtype, shape, strides=None):
|
||||||
dtype = np.uint16 if dtype == "bfloat16" else dtype
|
dtype = np.uint16 if dtype == "bfloat16" else dtype
|
||||||
if isinstance(dtype, np.dtype):
|
if isinstance(dtype, np.dtype):
|
||||||
dtype = dtype.type
|
dtype = dtype.type
|
||||||
|
@ -157,7 +157,7 @@ def _gen_random_data(dtype, shape, strides):
|
||||||
else:
|
else:
|
||||||
data = np.random.random(tuple(shape)).astype(dtype)
|
data = np.random.random(tuple(shape)).astype(dtype)
|
||||||
|
|
||||||
return np.lib.stride_tricks.as_strided(data, strides=strides)
|
return np.lib.stride_tricks.as_strided(data, strides=strides) if strides is not None else data
|
||||||
|
|
||||||
|
|
||||||
def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]:
|
def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]:
|
||||||
|
@ -198,7 +198,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> Lis
|
||||||
shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0])
|
shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0])
|
||||||
|
|
||||||
# materialize an array input using the generated shape
|
# materialize an array input using the generated shape
|
||||||
runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape, arg.numpy_strides)
|
runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape)
|
||||||
values.append(ArgValue(arg, runtime_array_inputs))
|
values.append(ArgValue(arg, runtime_array_inputs))
|
||||||
|
|
||||||
elif arg.name in dim_names_to_values and arg.usage == hat_file.UsageType.Input:
|
elif arg.name in dim_names_to_values and arg.usage == hat_file.UsageType.Input:
|
||||||
|
|
Загрузка…
Ссылка в новой задаче