зеркало из https://github.com/microsoft/hat.git
Merge branch 'main' into dev/ritdas/optimizations
This commit is contained in:
Коммит
52c1c4046f
|
@ -33,6 +33,7 @@ class ArgInfo:
|
|||
numpy_dtype: type
|
||||
element_num_bytes: int
|
||||
element_strides: Tuple[int, ...]
|
||||
total_element_count: int
|
||||
total_byte_size: int
|
||||
ctypes_pointer_type: Any
|
||||
usage: hat_file.UsageType = None
|
||||
|
@ -51,11 +52,9 @@ class ArgInfo:
|
|||
self.element_strides = param_description.affine_map
|
||||
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
|
||||
|
||||
def product(l):
|
||||
return reduce(lambda x1, x2: x1*x2, l)
|
||||
|
||||
major_dim = self.element_strides.index(max(self.element_strides))
|
||||
self.total_byte_size = self.element_num_bytes * self.numpy_shape[major_dim] * product(self.element_strides)
|
||||
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
|
||||
self.total_byte_size = self.element_num_bytes * self.total_element_count
|
||||
|
||||
|
||||
# TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name
|
||||
|
|
|
@ -40,12 +40,9 @@ def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_siz
|
|||
|
||||
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
|
||||
|
||||
def product(l):
|
||||
return reduce(lambda x1, x2: x1 * x2, l)
|
||||
|
||||
input_sets = [[
|
||||
np.lib.stride_tricks.as_strided(
|
||||
np.random.rand(p.numpy_shape[0] * product(p.element_strides)).astype(p.numpy_dtype),
|
||||
np.random.rand(p.total_element_count).astype(p.numpy_dtype),
|
||||
shape=p.numpy_shape,
|
||||
strides=p.numpy_strides
|
||||
) for p in parameters
|
||||
|
|
Загрузка…
Ссылка в новой задаче