From a23a2be45be4609149594ce9650156f94389e4c9 Mon Sep 17 00:00:00 2001 From: Mason Remy Date: Tue, 10 May 2022 18:18:45 -0700 Subject: [PATCH] Fix array volume computation for arrays larger than 2-D We've been taking the product of all strides and multiplying it by the size of the major dimension, but we should just take the largest stride and multiply it by the size of the major dimension, since the largest stride already factors in the other strides in the array shape. --- hatlib/arg_info.py | 7 +++---- hatlib/hat.py | 5 +---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/hatlib/arg_info.py b/hatlib/arg_info.py index 384aa1a..e685da7 100644 --- a/hatlib/arg_info.py +++ b/hatlib/arg_info.py @@ -33,6 +33,7 @@ class ArgInfo: numpy_dtype: type element_num_bytes: int element_strides: Tuple[int, ...] + total_element_count: int total_byte_size: int ctypes_pointer_type: Any usage: hat_file.UsageType = None @@ -51,11 +52,9 @@ class ArgInfo: self.element_strides = param_description.affine_map self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides]) - def product(l): - return reduce(lambda x1, x2: x1*x2, l) - major_dim = self.element_strides.index(max(self.element_strides)) - self.total_byte_size = self.element_num_bytes * self.numpy_shape[major_dim] * product(self.element_strides) + self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim] + self.total_byte_size = self.element_num_bytes * self.total_element_count # TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name diff --git a/hatlib/hat.py b/hatlib/hat.py index 1a58ab0..8b7e793 100644 --- a/hatlib/hat.py +++ b/hatlib/hat.py @@ -40,12 +40,9 @@ def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_siz num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional - def product(l): - return reduce(lambda x1, x2: x1 * x2, l) - input_sets = [[ np.lib.stride_tricks.as_strided( - 
np.random.rand(p.numpy_shape[0] * product(p.element_strides)).astype(p.numpy_dtype), + np.random.rand(p.total_element_count).astype(p.numpy_dtype), shape=p.numpy_shape, strides=p.numpy_strides ) for p in parameters