Fixes more issues in dynamic shape benchmarking (#81)

This commit is contained in:
Kern Handa 2022-11-15 17:52:10 -08:00 коммит произвёл GitHub
Родитель f90ac1fb86
Коммит 262252f36f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 4 добавлений и 2 удалений

Просмотреть файл

@ -27,6 +27,8 @@ class ArgValue:
else: else:
# no value provided, allocate the pointer # no value provided, allocate the pointer
self.allocate() self.allocate()
elif type(self.value) in [int, float]:
self.value = self.arg_info.numpy_dtype.type(self.value)
self.dim_values = None self.dim_values = None
def allocate(self): def allocate(self):
@ -169,7 +171,7 @@ def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values = {}) -> L
dim_names_to_values[d] = ArgValue(dim_args[d], shape[-1]) dim_names_to_values[d] = ArgValue(dim_args[d], shape[-1])
else: else:
v = dim_names_to_values[d].value v = dim_names_to_values[d].value
shape.append(v if type(v) == int else v[0]) shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0])
# materialize an array input using the generated shape # materialize an array input using the generated shape
runtime_array_inputs = np.random.random(tuple(shape)).astype(arg.numpy_dtype) runtime_array_inputs = np.random.random(tuple(shape)).astype(arg.numpy_dtype)

Просмотреть файл

@ -87,7 +87,7 @@ def generate_arg_sets_for_func(
shape_idx += 1 shape_idx += 1
else: else:
numerical_shapes = [p.shape if p.shape else [int(p.size)] for p in func.arguments] numerical_shapes = [p.shape if p.shape else ([int(p.size)] if p.size else [1]) for p in func.arguments]
shapes_to_sizes = [reduce(lambda x, y: x * y, shape, 1) for shape in numerical_shapes] shapes_to_sizes = [reduce(lambda x, y: x * y, shape, 1) for shape in numerical_shapes]
set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters)) set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters))