From f90ac1fb86550d8c2f39cc1abac40e51a55c3465 Mon Sep 17 00:00:00 2001 From: Kern Handa Date: Thu, 3 Nov 2022 17:54:29 -0700 Subject: [PATCH] Fixes dynamic shaped benchmark size usage (#80) Fixes dynamic shaped benchmark size usage This change sets the symbolic dimension values using the shapes provided by the dynamic shape callback function. It also fixes incorrectly setting the strides in ArgInfo. --- hatlib/arg_info.py | 6 ++- hatlib/arg_value.py | 18 +++++---- hatlib/benchmark_hat_package.py | 3 +- hatlib/hat.py | 70 +++++++++++++++++++++++++-------- 4 files changed, 71 insertions(+), 26 deletions(-) diff --git a/hatlib/arg_info.py b/hatlib/arg_info.py index 6c3052d..0c269d9 100644 --- a/hatlib/arg_info.py +++ b/hatlib/arg_info.py @@ -81,7 +81,8 @@ class ArgInfo: self.total_element_count = self.shape[major_dim] * self.element_strides[major_dim] else: - self.element_strides = self.numpy_strides = self.shape = [1] + self.shape = [1] + self.element_strides = self.numpy_strides = [self.element_num_bytes] self.total_element_count = 1 self.total_byte_size = self.element_num_bytes * self.total_element_count @@ -97,7 +98,8 @@ class ArgInfo: self.ctypes_type = ctypes_type else: self.ctypes_type = ctypes.POINTER(ctypes_type) - self.element_strides = self.numpy_strides = self.shape = [1] + self.shape = [1] + self.element_strides = self.numpy_strides = [self.element_num_bytes] self.total_element_count = 1 self.total_byte_size = self.element_num_bytes * self.total_element_count diff --git a/hatlib/arg_value.py b/hatlib/arg_value.py index 3efe4ff..a31bbfe 100644 --- a/hatlib/arg_value.py +++ b/hatlib/arg_value.py @@ -82,11 +82,15 @@ class ArgValue: raise ValueError( f"expected argument to have strides={desc_numpy_strides} but received strides={self.value.strides}" ) - elif self.value.size != desc.total_element_count: + else: + # Will raise ValueError if total_element_count can't be converted to int + desc.total_element_count = int(desc.total_element_count) + # special casing for size=1 arrays - raise ValueError( - f"expected argument to have size={desc.total_element_count} but received shape={self.value.size}" - ) + if self.value.size != desc.total_element_count: + raise ValueError( + f"expected argument to have size={desc.total_element_count} but received shape={self.value.size}" + ) else: pass # TODO - support other pointer levels @@ -130,7 +134,7 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo]) return indices -def generate_arg_values(arguments: List[ArgInfo]) -> List[ArgValue]: +def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values = {}) -> List[ArgValue]: """Generate argument values from argument descriptions Input and input/output affine_arrays: initialized with random inputs Input and input/output runtime_arrays: initialized with arbitrary dimensions and random inputs @@ -140,7 +144,6 @@ def generate_arg_values(arguments: List[ArgInfo]) -> List[ArgValue]: def generate_dim_value(): return random.choice([2, 3, 4]) # example dimension values - dim_names_to_values = {} values = [] for arg in arguments: @@ -165,7 +168,8 @@ def generate_arg_values(arguments: List[ArgInfo]) -> List[ArgValue]: shape.append(generate_dim_value()) dim_names_to_values[d] = ArgValue(dim_args[d], shape[-1]) else: - shape.append(dim_names_to_values[d].value) + v = dim_names_to_values[d].value + shape.append(v if type(v) == int else v[0]) # materialize an array input using the generated shape runtime_array_inputs = np.random.random(tuple(shape)).astype(arg.numpy_dtype) diff --git a/hatlib/benchmark_hat_package.py b/hatlib/benchmark_hat_package.py index 231de30..50497c4 100644 --- a/hatlib/benchmark_hat_package.py +++ b/hatlib/benchmark_hat_package.py @@ -105,7 +105,8 @@ class Benchmark: set_size = 0 for i in input_sets[0]: - set_size += i.value.size * i.value.dtype.itemsize + if not i.dim_values: + set_size += i.value.size * i.value.dtype.itemsize if verbose: print(f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes") diff --git a/hatlib/hat.py b/hatlib/hat.py index b9a98fc..0815b59 100644 --- a/hatlib/hat.py +++ b/hatlib/hat.py @@ -24,38 +24,76 @@ For example: # call a package function named 'my_func_698b5e5c' package.my_func_698b5e5c(A, B, D, E) """ +import numpy as np + from typing import Callable, List, Tuple, Union from functools import reduce from . import hat_file from . import hat_package -from .arg_value import generate_arg_values +from .arg_value import generate_arg_values, ArgValue from .arg_info import integer_like from .function_info import FunctionInfo +PLACEHOLDER_SIZE = 128 # arbitrary, to be replaced with a better way to estimate size for runtime arrays -PLACEHOLDER_SIZE = 128 # arbitrary, to be replaced with a better way to estimate size for runtime arrays -def generate_arg_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0, dyn_func_shape_fn: Callable[[FunctionInfo], List[List[int]]]=None): - def default_dyn_func_shape_fn(func: hat_file.Function) -> List[List[int]]: - return [[int(d) if integer_like(d) else PLACEHOLDER_SIZE for d in p.shape] for p in func.arguments] +def generate_arg_sets_for_func( + func: hat_file.Function, + input_sets_minimum_size_MB: int = 0, + num_additional: int = 0, + dyn_func_shape_fn: Callable[[FunctionInfo], List[List[int]]] = None +): + + def default_dyn_func_shape_fn(func_info: FunctionInfo) -> List[List[int]]: + return [[int(d) if integer_like(d) else PLACEHOLDER_SIZE for d in p.shape] for p in func_info.arguments + if p.pointer_level == 1 and p.usage != hat_file.UsageType.Output] func_info = FunctionInfo(func) - - # plug in values for non-constant dimensions in an attempt to estimate the minimum set size - if dyn_func_shape_fn is None: - dyn_func_shape_fn = default_dyn_func_shape_fn - numerical_shapes = dyn_func_shape_fn(func_info) - parameters = func_info.arguments - shapes_to_sizes = [reduce(lambda x, y: x * y, shape) for shape in numerical_shapes] - set_size = reduce( - lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters) - ) + # plug in values for non-constant dimensions in an attempt to estimate the minimum set size + dim_names_to_values = {} + if any(map(lambda p: not p.is_constant_shaped, func_info.arguments)): + if dyn_func_shape_fn is None: + dyn_func_shape_fn = default_dyn_func_shape_fn + numerical_shapes = dyn_func_shape_fn(func_info) + # TODO: We really need to be able to distinguish between args that are dimensions vs. just scalars + shape_idx = 0 + param_idx = 0 + while param_idx < len(parameters): + p = parameters[param_idx] + if p.is_constant_shaped: + if p.pointer_level and p.usage != hat_file.UsageType.Output: + shape_idx += 1 + + param_idx += 1 + continue + + if p.usage == hat_file.UsageType.Output: + param_idx += 1 + continue + + numerical_shape = numerical_shapes[shape_idx] + + # parameter is NOT constant shaped + dyn_dimensions = filter(lambda idx_d: not integer_like(idx_d[1]), enumerate(p.shape)) + for dyn_dim_idx, dyn_dim in dyn_dimensions: + if dyn_dim not in dim_names_to_values: + param_idx += 1 + dim_names_to_values[dyn_dim] = ArgValue(parameters[param_idx], numerical_shape[dyn_dim_idx]) + + param_idx += 1 + shape_idx += 1 + + else: + numerical_shapes = [p.shape if p.shape else [int(p.size)] for p in func.arguments] + + shapes_to_sizes = [reduce(lambda x, y: x * y, shape, 1) for shape in numerical_shapes] + set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters)) num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional - arg_sets = [generate_arg_values(parameters) for _ in range(num_input_sets)] + arg_sets = [generate_arg_values(parameters, dim_names_to_values) for _ in range(num_input_sets)] return arg_sets[0] if len(arg_sets) == 1 else arg_sets