Merge branch 'main' into dev/ritdas/optimizations

This commit is contained in:
Ritwik Das 2022-05-11 10:08:52 -07:00
Родитель ddc0711527 5349e07c12
Коммит 52c1c4046f
2 изменённых файлов: 4 добавлений и 8 удалений

Просмотреть файл

@ -33,6 +33,7 @@ class ArgInfo:
numpy_dtype: type
element_num_bytes: int
element_strides: Tuple[int, ...]
total_element_count: int
total_byte_size: int
ctypes_pointer_type: Any
usage: hat_file.UsageType = None
@ -51,11 +52,9 @@ class ArgInfo:
self.element_strides = param_description.affine_map
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
def product(l):
return reduce(lambda x1, x2: x1*x2, l)
major_dim = self.element_strides.index(max(self.element_strides))
self.total_byte_size = self.element_num_bytes * self.numpy_shape[major_dim] * product(self.element_strides)
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
self.total_byte_size = self.element_num_bytes * self.total_element_count
# TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name

Просмотреть файл

@ -40,12 +40,9 @@ def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_siz
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
def product(l):
return reduce(lambda x1, x2: x1 * x2, l)
input_sets = [[
np.lib.stride_tricks.as_strided(
np.random.rand(p.numpy_shape[0] * product(p.element_strides)).astype(p.numpy_dtype),
np.random.rand(p.total_element_count).astype(p.numpy_dtype),
shape=p.numpy_shape,
strides=p.numpy_strides
) for p in parameters