зеркало из https://github.com/microsoft/hat.git
Fixes more issues in dynamic shape benchmarking (#81)
This commit is contained in:
Родитель
f90ac1fb86
Коммит
262252f36f
|
@ -27,6 +27,8 @@ class ArgValue:
|
|||
else:
|
||||
# no value provided, allocate the pointer
|
||||
self.allocate()
|
||||
elif type(self.value) in [int, float]:
|
||||
self.value = self.arg_info.numpy_dtype.type(self.value)
|
||||
self.dim_values = None
|
||||
|
||||
def allocate(self):
|
||||
|
@ -169,7 +171,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values = {}) -> L
|
|||
dim_names_to_values[d] = ArgValue(dim_args[d], shape[-1])
|
||||
else:
|
||||
v = dim_names_to_values[d].value
|
||||
shape.append(v if type(v) == int else v[0])
|
||||
shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0])
|
||||
|
||||
# materialize an array input using the generated shape
|
||||
runtime_array_inputs = np.random.random(tuple(shape)).astype(arg.numpy_dtype)
|
||||
|
|
|
@ -87,7 +87,7 @@ def generate_arg_sets_for_func(
|
|||
shape_idx += 1
|
||||
|
||||
else:
|
||||
numerical_shapes = [p.shape if p.shape else [int(p.size)] for p in func.arguments]
|
||||
numerical_shapes = [p.shape if p.shape else ([int(p.size)] if p.size else [1]) for p in func.arguments]
|
||||
|
||||
shapes_to_sizes = [reduce(lambda x, y: x * y, shape, 1) for shape in numerical_shapes]
|
||||
set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters))
|
||||
|
|
Загрузка…
Ссылка в новой задаче