зеркало из https://github.com/microsoft/hat.git
Initial verify_hat support for runtime arrays and elements (#65)
* support arbitrary pointer levels for declared types * handle logical_type=element * arg values * refactor (func name, arg_info) to FunctionInfo * fix circular imports * . * support dylib * update verify_hat_package * verify -> verify_args * fixes * fix verify_args * formatting * add verify hat test * refactor * handle ndarrays as arguments * cleanup * runtime_array verify test * print output * basic test passing * Update test_create_simple_hat_file.py * Update test_create_simple_hat_file.py * comments * comments * TODOs * nfc * [test] moved creation to workdir * rename * Print output dimension references and clarify HAT schema (#66) * infer shapes from size, add shape order requirement * merged * pretty print using cross references Co-authored-by: Lisa Ong <onglisa@microsoft.com> * revert formatting changes * revert more formatting only changes * simplify * Add input and input/output runtime_array support (#67) * wip * scaffold * scaffold and initial support for input elements and input/output runtime_arrays * . * fixups * don't swallow exceptions * support cargs for non pointer args * cleanup * refactor * support integer-like types when checking constant shapes * nfc Co-authored-by: Lisa Ong <onglisa@microsoft.com> * test coverage for usage type Input and InputOutput * [test] Support windows in verify_hat tests (#69) * wip * build for windows (#68) Co-authored-by: Lisa Ong <onglisa@microsoft.com> * windows tomlkit expects lists * manual CI trigger Co-authored-by: Lisa Ong <onglisa@microsoft.com> * fix windows test * fix logic and add comment * . * verify_args -> verify * args -> arguments * PR feedback Co-authored-by: Lisa Ong <onglisa@microsoft.com>
This commit is contained in:
Родитель
65743f6e83
Коммит
d99d61d970
|
@ -5,6 +5,8 @@ on:
|
|||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
# allow manual triggers
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
|
|
@ -361,4 +361,5 @@ dist/
|
|||
*.egg-info/
|
||||
|
||||
# test
|
||||
test_acccgen
|
||||
test_acccgen
|
||||
test_output
|
|
@ -5,3 +5,4 @@ from .hat_to_dynamic import *
|
|||
from .hat_to_lib import *
|
||||
from .benchmark_hat_package import run_benchmark
|
||||
from .platform_utilities import *
|
||||
from .verify_hat_package import *
|
|
@ -1,42 +1,44 @@
|
|||
import ctypes
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
import sys
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, Tuple, Union
|
||||
|
||||
from . import hat_file
|
||||
|
||||
# hat_declared_type : [ ctype, dtype_str ]
|
||||
# element_type : [ ctype, dtype_str ]
|
||||
ARG_TYPES = {
|
||||
"int8_t*" : [ ctypes.c_int8, "int8" ],
|
||||
"int16_t*" : [ ctypes.c_int16, "int16" ],
|
||||
"int32_t*" : [ ctypes.c_int32, "int32" ],
|
||||
"int64_t*" : [ ctypes.c_int64, "int64" ],
|
||||
"uint8_t*" : [ ctypes.c_uint8, "uint8" ],
|
||||
"uint16_t*" : [ ctypes.c_uint16, "uint16" ],
|
||||
"uint32_t*" : [ ctypes.c_uint32, "uint32" ],
|
||||
"uint64_t*" : [ ctypes.c_uint64, "uint64" ],
|
||||
"float16_t*" : [ ctypes.c_uint16, "float16" ], # same bitwidth as uint16
|
||||
"bfloat16_t*" : [ ctypes.c_uint16, "bfloat16" ],
|
||||
"float*" : [ ctypes.c_float, "float32" ],
|
||||
"double*" : [ ctypes.c_double, "float64" ],
|
||||
"int8_t": [ctypes.c_int8, "int8"],
|
||||
"int16_t": [ctypes.c_int16, "int16"],
|
||||
"int32_t": [ctypes.c_int32, "int32"],
|
||||
"int64_t": [ctypes.c_int64, "int64"],
|
||||
"uint8_t": [ctypes.c_uint8, "uint8"],
|
||||
"uint16_t": [ctypes.c_uint16, "uint16"],
|
||||
"uint32_t": [ctypes.c_uint32, "uint32"],
|
||||
"uint64_t": [ctypes.c_uint64, "uint64"],
|
||||
"float16_t": [ctypes.c_uint16, "float16"], # same bitwidth as uint16
|
||||
"bfloat16_t": [ctypes.c_uint16, "bfloat16"],
|
||||
"float": [ctypes.c_float, "float32"],
|
||||
"double": [ctypes.c_double, "float64"],
|
||||
}
|
||||
CTYPE_ENTRY = 0
|
||||
DTYPE_ENTRY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArgInfo:
|
||||
"""Extracts necessary information from the description of a function argument in a hat file"""
|
||||
name: str
|
||||
hat_declared_type: str
|
||||
numpy_shape: Tuple[int, ...]
|
||||
shape: Tuple[Union[int, str], ...] # int for affine_arrays, str symbols for runtime_arrays
|
||||
numpy_strides: Tuple[int, ...]
|
||||
numpy_dtype: type
|
||||
element_num_bytes: int
|
||||
element_strides: Tuple[int, ...]
|
||||
total_element_count: int
|
||||
total_byte_size: int
|
||||
ctypes_pointer_type: Any
|
||||
total_element_count: Union[int, str]
|
||||
total_byte_size: Union[int, str]
|
||||
ctypes_type: Any
|
||||
pointer_level: int
|
||||
usage: hat_file.UsageType = None
|
||||
|
||||
def _get_type(self, type_str):
|
||||
|
@ -46,66 +48,70 @@ class ArgInfo:
|
|||
|
||||
return np.dtype(type_str)
|
||||
|
||||
def _get_pointer_level(self, declared_type: str):
|
||||
pos = declared_type.find("*")
|
||||
if pos == -1:
|
||||
return 0
|
||||
return declared_type[pos:].count("*")
|
||||
|
||||
def __init__(self, param_description: hat_file.Parameter):
|
||||
self.name = param_description.name
|
||||
self.hat_declared_type = param_description.declared_type
|
||||
self.numpy_shape = tuple(param_description.shape)
|
||||
self.shape = tuple(param_description.shape)
|
||||
self.usage = param_description.usage
|
||||
self.pointer_level = self._get_pointer_level(self.hat_declared_type)
|
||||
element_type = self.hat_declared_type[:(-1 * self.pointer_level)] \
|
||||
if self.pointer_level else self.hat_declared_type
|
||||
|
||||
if not self.hat_declared_type in ARG_TYPES:
|
||||
raise NotImplementedError(f"Unsupported declared_type {self.hat_declared_type} in hat file")
|
||||
if not element_type in ARG_TYPES:
|
||||
raise NotImplementedError(f"Unsupported element_type {element_type} in hat file")
|
||||
|
||||
self.ctypes_pointer_type = ctypes.POINTER(ARG_TYPES[self.hat_declared_type][CTYPE_ENTRY])
|
||||
dtype_entry = ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY]
|
||||
ctypes_type = ARG_TYPES[element_type][CTYPE_ENTRY]
|
||||
dtype_entry = ARG_TYPES[element_type][DTYPE_ENTRY]
|
||||
self.numpy_dtype = self._get_type(dtype_entry)
|
||||
self.element_num_bytes = 2 if dtype_entry == "bfloat16" else self.numpy_dtype.itemsize
|
||||
|
||||
if self.numpy_shape:
|
||||
self.element_strides = param_description.affine_map
|
||||
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
|
||||
if param_description.logical_type == hat_file.ParameterType.AffineArray:
|
||||
self.ctypes_type = ctypes.POINTER(ctypes_type)
|
||||
if self.shape:
|
||||
self.element_strides = param_description.affine_map
|
||||
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
|
||||
|
||||
major_dim = self.element_strides.index(max(self.element_strides))
|
||||
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
|
||||
|
||||
else:
|
||||
self.element_strides = self.numpy_strides = self.numpy_shape = [1]
|
||||
major_dim = self.element_strides.index(max(self.element_strides))
|
||||
self.total_element_count = self.shape[major_dim] * self.element_strides[major_dim]
|
||||
|
||||
else:
|
||||
self.element_strides = self.numpy_strides = self.shape = [1]
|
||||
self.total_element_count = 1
|
||||
self.total_byte_size = self.element_num_bytes * self.total_element_count
|
||||
|
||||
elif param_description.logical_type == hat_file.ParameterType.RuntimeArray:
|
||||
self.ctypes_type = ctypes.POINTER(ctypes_type)
|
||||
self.total_byte_size = f"{self.element_num_bytes} * {param_description.size}"
|
||||
self.total_element_count = param_description.size
|
||||
# assume the sizes are in shape order
|
||||
self.shape = re.split(r"\s?\*\s?", param_description.size)
|
||||
|
||||
elif param_description.logical_type == hat_file.ParameterType.Element:
|
||||
if param_description.usage == hat_file.UsageType.Input:
|
||||
self.ctypes_type = ctypes_type
|
||||
else:
|
||||
self.ctypes_type = ctypes.POINTER(ctypes_type)
|
||||
self.element_strides = self.numpy_strides = self.shape = [1]
|
||||
self.total_element_count = 1
|
||||
self.total_byte_size = self.element_num_bytes * self.total_element_count
|
||||
|
||||
self.total_byte_size = self.element_num_bytes * self.total_element_count
|
||||
else:
|
||||
raise ValueError(f"Unknown logical type {param_description.logical_type} in hat file")
|
||||
|
||||
@property
|
||||
def is_constant_shaped(self):
|
||||
|
||||
# TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name
|
||||
def verify_args(args: List, arg_infos: List[ArgInfo], function_name: str):
|
||||
""" Verifies that a list of arguments matches a list of argument descriptions in a HAT file
|
||||
"""
|
||||
# check number of args
|
||||
if len(args) != len(arg_infos):
|
||||
sys.exit(f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")
|
||||
def integer_like(s: Any):
|
||||
# handle types such as tomlkit.items.Integer
|
||||
try:
|
||||
return int(s) == s
|
||||
except:
|
||||
return False
|
||||
|
||||
# for each arg
|
||||
for i in range(len(args)):
|
||||
arg = args[i]
|
||||
arg_info = arg_infos[i]
|
||||
|
||||
# confirm that the arg is a numpy ndarray
|
||||
if not isinstance(arg, np.ndarray):
|
||||
sys.exit(
|
||||
"Error calling {function_name}(...): expected argument {i} to be <class 'numpy.ndarray'> but received {type(arg)}"
|
||||
)
|
||||
|
||||
# confirm that the arg dtype matches the description in the hat package
|
||||
if arg_info.numpy_dtype != arg.dtype:
|
||||
sys.exit(
|
||||
f"Error calling {function_name}(...): expected argument {i} to have dtype={arg_info.numpy_dtype} but received dtype={arg.dtype}"
|
||||
)
|
||||
|
||||
# confirm that the arg shape is correct
|
||||
if arg_info.numpy_shape != arg.shape:
|
||||
sys.exit(
|
||||
f"Error calling {function_name}(...): expected argument {i} to have shape={arg_info.numpy_shape} but received shape={arg.shape}"
|
||||
)
|
||||
|
||||
# confirm that the arg strides are correct
|
||||
if arg_info.numpy_strides != arg.strides:
|
||||
sys.exit(
|
||||
f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}"
|
||||
)
|
||||
return all(integer_like(s) for s in self.shape)
|
|
@ -0,0 +1,170 @@
|
|||
from typing import Any, List
|
||||
from ctypes import byref
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from .arg_info import ArgInfo
|
||||
from . import hat_file
|
||||
|
||||
|
||||
class ArgValue:
    """An argument containing a scalar, ndarray, or pointer value.

    Used for calling HAT functions from ctypes.

    Attributes:
        arg_info: the argument description this value was created from.
        pointer_level: copied from ``arg_info.pointer_level``.
        ctypes_type: copied from ``arg_info.ctypes_type``.
        value: the wrapped value (scalar, ndarray, or ctypes pointer).
        dim_values: optional list of ArgValues holding the dimensions of this
            value; only consulted by ``__repr__``. ``None`` until assigned.
    """

    def __init__(self, arg_info: "ArgInfo", value: Any = None):
        """Wrap ``value`` for the argument described by ``arg_info``.

        When ``value`` is None and the argument is a pointer, a value is
        allocated via ``allocate()``.

        Raises:
            NotImplementedError: if ``arg_info.pointer_level`` is greater than 2.
            ValueError: if ``value`` is None for a non-pointer argument.
        """
        # TODO: set the free and alloc function symbols here?
        self.arg_info = arg_info
        self.pointer_level = arg_info.pointer_level
        self.ctypes_type = arg_info.ctypes_type
        self.dim_values = None
        self.value = value

        if self.pointer_level > 2:    # punt until we really need this
            raise NotImplementedError("Pointer levels > 2 are not supported")

        if self.value is None:
            if not self.pointer_level:
                raise ValueError("A value is required for non-pointers")
            # no value provided, allocate the pointer
            self.allocate()

    def allocate(self):
        """Assign ``self.value`` for a pointer argument that has no value yet.

        Does nothing for non-pointers or when a value is already present.
        """
        if not self.pointer_level:
            return    # nothing to do
        # Compare against None explicitly: truth-testing a multi-element
        # ndarray raises "truth value of an array ... is ambiguous".
        if self.value is not None:
            return    # value already assigned, nothing to do

        if self.pointer_level == 1:
            # allocate an ndarray with random input values
            random_values = np.random.rand(self.arg_info.total_element_count)
            self.value = np.lib.stride_tricks.as_strided(
                random_values.astype(self.arg_info.numpy_dtype),
                shape=self.arg_info.shape,
                strides=self.arg_info.numpy_strides,
            )
        elif self.pointer_level == 2:
            # allocate a pointer. HAT function will perform the actual allocation.
            self.value = self.ctypes_type()

    def as_carg(self):
        "Return the C interface for this argument"
        if not self.pointer_level:
            # non-pointers are passed by value
            return self.ctypes_type(self.value)
        if isinstance(self.value, np.ndarray):
            # pass the address of the ndarray's buffer
            return self.value.ctypes.data_as(self.ctypes_type)
        return byref(self.value)

    def verify(self, desc: "ArgInfo"):
        """Verifies that this argument matches an argument description.

        Only descriptions with ``pointer_level == 1`` are checked.

        Raises:
            ValueError: if the value is not an ndarray, its dtype differs from
                ``desc.numpy_dtype``, or (for constant-shaped descriptions) its
                shape or strides differ from the description.
        """
        if desc.pointer_level != 1:
            return    # TODO - support other pointer levels

        if not isinstance(self.value, np.ndarray):
            raise ValueError(f"expected argument to be <class 'numpy.ndarray'> but received {type(self.value)}")

        if desc.numpy_dtype != self.value.dtype:
            raise ValueError(
                f"expected argument to have dtype={desc.numpy_dtype} but received dtype={self.value.dtype}"
            )

        if desc.is_constant_shaped:
            # confirm that the arg shape is correct (numpy represents shapes as tuples)
            if tuple(desc.shape) != self.value.shape:
                raise ValueError(
                    f"expected argument to have shape={desc.shape} but received shape={self.value.shape}"
                )

            # confirm that the arg strides are correct (numpy represents strides as tuples)
            if tuple(desc.numpy_strides) != self.value.strides:
                raise ValueError(
                    f"expected argument to have strides={desc.numpy_strides} but received strides={self.value.strides}"
                )

    def __repr__(self):
        """Return a printable form of the value.

        ndarrays print their first 32 flattened elements, comma-separated.
        Other pointer values print their contents, reshaped using
        ``dim_values`` when those are set; a NULL pointer prints as
        ``<pointer repr> nullptr``. Non-pointers print ``repr(value)``.
        """
        if not self.pointer_level:
            return repr(self.value)

        if isinstance(self.value, np.ndarray):
            return ",".join(map(str, self.value.ravel()[:32]))

        try:
            if self.dim_values:
                # cross-reference the dimension output values to pretty print the output
                shape = [d.value[0] for d in self.dim_values]    # stored as single-element ndarrays
                return repr(np.ctypeslib.as_array(self.value, shape))
            return repr(self.value.contents)
        except Exception as e:
            # str(e) is safe even when e.args is empty or holds non-strings,
            # so an unrelated exception is never masked by this check.
            if str(e).startswith("NULL pointer"):
                return f"{repr(self.value)} nullptr"
            raise

    def __del__(self):
        # pointer_level may be missing if __init__ failed before assigning it
        if getattr(self, "pointer_level", 0) == 2:
            pass    # TODO - free the pointer, presumably calling a symbol passed into this ArgValue
|
||||
|
||||
|
||||
def get_dimension_arg_indices(array_arg: "ArgInfo", all_arguments: List["ArgInfo"]) -> List[int]:
    """Return the dimension argument indices, in shape order, for an array argument.

    Each entry of ``array_arg.shape`` is matched by name against
    ``all_arguments``; the index of the first match is used.

    Raises:
        RuntimeError: if a shape entry does not name any argument.
    """

    def position_of(sym_name):
        # limitation: only string shapes are supported
        for position, candidate in enumerate(all_arguments):
            if candidate.name == sym_name:
                return position
        # not found; likely an invalid HAT file
        raise RuntimeError(f"{sym_name} is not an argument to the function")

    return [position_of(sym_name) for sym_name in array_arg.shape]
|
||||
|
||||
|
||||
def generate_arg_values(arguments: List["ArgInfo"]) -> List["ArgValue"]:
    """Generate argument values from argument descriptions

    Input and input/output affine_arrays: initialized with random inputs
    Input and input/output runtime_arrays: initialized with arbitrary dimensions and random inputs
    Output elements and runtime_arrays: pointers are allocated

    Dimension values are resolved in a first pass over every non-output
    runtime array, so a dimension argument receives its generated value
    whether it is declared before or after the array that uses it.

    Returns:
        One ArgValue per entry of ``arguments``, in the same order.
    """

    def generate_dim_value():
        return random.choice([128, 256, 1234])    # example dimension values

    # Pass 1: assign generated shape values to the dimension arguments of each
    # non-output runtime array, and remember each array's resulting shape.
    dim_names_to_values = {}
    runtime_array_shapes = {}    # argument position -> generated shape
    for position, arg in enumerate(arguments):
        if arg.usage == hat_file.UsageType.Output or arg.is_constant_shaped:
            continue

        shape = []
        for dim_index in get_dimension_arg_indices(arg, arguments):
            dim_arg = arguments[dim_index]
            if dim_arg.name not in dim_names_to_values:
                dim_names_to_values[dim_arg.name] = ArgValue(dim_arg, generate_dim_value())
            shape.append(dim_names_to_values[dim_arg.name].value)
        runtime_array_shapes[position] = tuple(shape)

    # Pass 2: materialize the values in argument order.
    values = []
    for position, arg in enumerate(arguments):
        if position in runtime_array_shapes:
            # materialize an array input using the generated shape
            runtime_array_inputs = np.random.random(runtime_array_shapes[position]).astype(arg.numpy_dtype)
            values.append(ArgValue(arg, runtime_array_inputs))
        elif arg.name in dim_names_to_values:
            # input element that is a dimension value (populated in pass 1)
            values.append(dim_names_to_values[arg.name])
        else:
            # everything else is known size or a pointer
            values.append(ArgValue(arg))

    # collect the dimension ArgValues for each output runtime_array ArgValue
    for value in values:
        if value.arg_info.usage == hat_file.UsageType.Output and not value.arg_info.is_constant_shaped:
            dim_values = [values[i] for i in get_dimension_arg_indices(value.arg_info, arguments)]
            assert dim_values, f"Runtime array {value.arg_info.name} has no dimensions"
            value.dim_values = dim_values

    return values
|
|
@ -9,7 +9,7 @@ import traceback
|
|||
|
||||
from .callable_func import CallableFunc
|
||||
from .hat_file import HATFile
|
||||
from .hat import load, generate_input_sets_for_func
|
||||
from .hat import load, generate_arg_sets_for_func
|
||||
|
||||
|
||||
class Benchmark:
|
||||
|
@ -84,13 +84,13 @@ class Benchmark:
|
|||
if not isinstance(benchmark_func, CallableFunc):
|
||||
# generate sufficient input sets to overflow the L3 cache, since we don't know the size of the model
|
||||
# we'll make a guess based on the minimum input set size
|
||||
input_sets = generate_input_sets_for_func(func,
|
||||
input_sets = generate_arg_sets_for_func(func,
|
||||
input_sets_minimum_size_MB,
|
||||
num_additional=10)
|
||||
|
||||
set_size = 0
|
||||
for i in input_sets[0]:
|
||||
set_size += i.size * i.dtype.itemsize
|
||||
set_size += i.value.size * i.value.dtype.itemsize
|
||||
|
||||
if verbose:
|
||||
print(f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes")
|
||||
|
@ -129,11 +129,11 @@ class Benchmark:
|
|||
else:
|
||||
if verbose:
|
||||
print(f"[Benchmarking] Benchmarking device function on gpu {gpu_id}. {batch_size} batches of warming up for {warmup_iterations} and then measuring with {min_timing_iterations} iterations.")
|
||||
input_sets = generate_input_sets_for_func(func)
|
||||
input_sets = generate_arg_sets_for_func(func)
|
||||
|
||||
set_size = 0
|
||||
for i in input_sets:
|
||||
set_size += i.size * i.dtype.itemsize
|
||||
set_size += i.value.size * i.value.dtype.itemsize
|
||||
|
||||
if verbose:
|
||||
print(f"[Benchmarking] Using input of {set_size} bytes")
|
||||
|
|
|
@ -4,8 +4,9 @@ import sys
|
|||
import numpy as np
|
||||
from typing import List
|
||||
from cuda import cuda, nvrtc
|
||||
from .arg_info import ArgInfo, verify_args
|
||||
from .arg_info import ArgInfo
|
||||
from .callable_func import CallableFunc
|
||||
from .function_info import FunctionInfo
|
||||
from .hat_file import Function
|
||||
|
||||
|
||||
|
@ -38,6 +39,7 @@ def _find_cuda_incl_path() -> pathlib.Path:
|
|||
|
||||
return cuda_path
|
||||
|
||||
|
||||
def _get_compute_capability(gpu_id) -> int:
|
||||
err, major = cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, gpu_id)
|
||||
ASSERT_DRV(err)
|
||||
|
@ -47,6 +49,7 @@ def _get_compute_capability(gpu_id) -> int:
|
|||
|
||||
return (major * 10) + minor
|
||||
|
||||
|
||||
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name, gpu_id):
|
||||
src = cuda_src_path.read_text()
|
||||
|
||||
|
@ -55,15 +58,15 @@ def compile_cuda_program(cuda_src_path: pathlib.Path, func_name, gpu_id):
|
|||
raise RuntimeError("Unable to determine CUDA include path. Please set CUDA_PATH environment variable.")
|
||||
|
||||
opts = [
|
||||
# https://docs.nvidia.com/cuda/nvrtc/index.html#group__options
|
||||
# https://docs.nvidia.com/cuda/nvrtc/index.html#group__options
|
||||
f'--gpu-architecture=compute_{_get_compute_capability(gpu_id)}'.encode(),
|
||||
b'--ptxas-options=--warn-on-spills', # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-passing-specific-phase-options-ptxas-options
|
||||
b'--ptxas-options=--warn-on-spills', # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-passing-specific-phase-options-ptxas-options
|
||||
b'-use_fast_math',
|
||||
b'--include-path=' + str(cuda_incl_path).encode(),
|
||||
b'-std=c++17',
|
||||
b'-default-device',
|
||||
#b'--restrict',
|
||||
#b'--device-int128'
|
||||
#b'--restrict',
|
||||
#b'--device-int128'
|
||||
]
|
||||
|
||||
# Create program
|
||||
|
@ -180,10 +183,8 @@ class CudaCallableFunc(CallableFunc):
|
|||
def __init__(self, func: Function, cuda_src_path: str) -> None:
|
||||
super().__init__()
|
||||
self.hat_func = func
|
||||
self.func_name = func.name
|
||||
self.func_info = FunctionInfo(func)
|
||||
self.kernel = None
|
||||
hat_arg_descriptions = func.arguments
|
||||
self.arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
|
||||
self.launch_params = func.launch_parameters
|
||||
self.device_mem = None
|
||||
self.ptrs = None
|
||||
|
@ -198,19 +199,19 @@ class CudaCallableFunc(CallableFunc):
|
|||
|
||||
ptx = _PTX_CACHE.get(self.cuda_src_path)
|
||||
if not ptx:
|
||||
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_name, gpu_id)
|
||||
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_info.name, gpu_id)
|
||||
|
||||
self.kernel = get_func_from_ptx(ptx, self.func_name)
|
||||
self.kernel = get_func_from_ptx(ptx, self.func_info.name)
|
||||
|
||||
def cleanup_runtime(self, benchmark: bool):
|
||||
cuda.cuCtxDestroy(self.context)
|
||||
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int=0):
|
||||
verify_args(args, self.arg_infos, self.func_name)
|
||||
self.device_mem = allocate_cuda_mem(self.arg_infos)
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int = 0):
|
||||
self.func_info.verify(args)
|
||||
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
|
||||
|
||||
if not benchmark:
|
||||
transfer_mem_host_to_cuda(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
|
||||
transfer_mem_host_to_cuda(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
|
||||
|
||||
self.ptrs = device_args_to_ptr_list(self.device_mem)
|
||||
|
||||
|
@ -270,7 +271,7 @@ class CudaCallableFunc(CallableFunc):
|
|||
def cleanup_main(self, benchmark: bool, args=[]):
|
||||
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
|
||||
if not benchmark and self.device_mem:
|
||||
transfer_mem_cuda_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
|
||||
transfer_mem_cuda_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
|
||||
if self.device_mem:
|
||||
free_cuda_mem(self.device_mem)
|
||||
err, = cuda.cuCtxSynchronize()
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
import numpy as np
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List
|
||||
|
||||
from .arg_info import ArgInfo
|
||||
from .arg_value import ArgValue
|
||||
from . import hat_file
|
||||
|
||||
|
||||
@dataclass
class FunctionInfo:
    "Information about a HAT function"
    # function description taken from the HAT file
    desc: "hat_file.Function"
    # one ArgInfo per entry of desc.arguments; rebuilt in __post_init__
    arguments: List["ArgInfo"] = field(default_factory=list)
    # copied from desc.name in __post_init__
    name: str = ""

    def __post_init__(self):
        # derive name and argument descriptions from the function description
        self.name = self.desc.name
        self.arguments = list(map(ArgInfo, self.desc.arguments))

    @staticmethod
    def _as_arg_value(info: "ArgInfo", value: Any) -> "ArgValue":
        "Wraps a raw value (ndarray, scalar, ctypes object) in an ArgValue; ArgValues pass through"
        if isinstance(value, ArgValue):
            return value
        if value is None:
            # ArgValue(info, None) would silently allocate a fresh buffer for
            # pointer arguments, so reject a missing value explicitly.
            raise ValueError("expected a value but received None")
        return ArgValue(info, value)

    def verify(self, args: List[Any]):
        """Verifies that a list of argument values matches the function description.

        Exits the process via ``sys.exit`` with an error message when the
        argument count is wrong or an argument fails verification.
        """
        if len(args) != len(self.arguments):
            sys.exit(
                f"Error calling {self.name}(...): expected {len(self.arguments)} arguments but received {len(args)}"
            )

        for i, (info, value) in enumerate(zip(self.arguments, args)):
            try:
                self._as_arg_value(info, value).verify(info)
            except ValueError as v:
                sys.exit(f"Error calling {self.name}(...): argument {i} failed verification: {v}")

    def as_cargs(self, args: List[Any]):
        """Converts arguments to their C interfaces.

        Raises:
            ValueError: if an argument is None.
        """
        return [self._as_arg_value(info, value).as_carg() for info, value in zip(self.arguments, args)]
|
||||
|
|
@ -24,36 +24,36 @@ For example:
|
|||
# call a package function named 'my_func_698b5e5c'
|
||||
package.my_func_698b5e5c(A, B, D, E)
|
||||
"""
|
||||
import numpy as np
|
||||
from typing import Tuple, Union
|
||||
from functools import reduce
|
||||
|
||||
from . import hat_file
|
||||
from . import hat_package
|
||||
from .arg_info import ArgInfo
|
||||
from .arg_value import generate_arg_values
|
||||
from .function_info import FunctionInfo
|
||||
|
||||
|
||||
def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
|
||||
parameters = list(map(ArgInfo, func.arguments))
|
||||
shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters]
|
||||
set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters))
|
||||
def generate_arg_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
|
||||
func_info = FunctionInfo(func)
|
||||
parameters = func_info.arguments
|
||||
|
||||
# use constant-sized params to estimate the minimum set size
|
||||
const_sized_parameters = list(filter(lambda p: p.is_constant_shaped, parameters))
|
||||
|
||||
shapes_to_sizes = [reduce(lambda x, y: x * y, p.shape) for p in const_sized_parameters]
|
||||
set_size = reduce(
|
||||
lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, const_sized_parameters)
|
||||
)
|
||||
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
|
||||
|
||||
input_sets = [[
|
||||
np.lib.stride_tricks.as_strided(
|
||||
np.random.rand(p.total_element_count).astype(p.numpy_dtype),
|
||||
shape=p.numpy_shape,
|
||||
strides=p.numpy_strides
|
||||
) for p in parameters
|
||||
] for _ in range(num_input_sets)]
|
||||
arg_sets = [generate_arg_values(parameters) for _ in range(num_input_sets)]
|
||||
|
||||
return input_sets[0] if len(input_sets) == 1 else input_sets
|
||||
return arg_sets[0] if len(arg_sets) == 1 else arg_sets
|
||||
|
||||
|
||||
def generate_input_sets_for_hat_file(hat_path):
|
||||
def generate_arg_sets_for_hat_file(hat_path):
|
||||
t = hat_file.HATFile.Deserialize(hat_path)
|
||||
return {func_name: generate_input_sets_for_func(func_desc)
|
||||
return {func_name: generate_arg_sets_for_func(func_desc)
|
||||
for func_name, func_desc in t.function_map.items()}
|
||||
|
||||
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Utility to parse and validate a HAT package
|
||||
|
||||
import ctypes
|
||||
from typing import Any, List, Union
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
from .hat_file import HATFile, Function, Parameter
|
||||
from .arg_info import ArgInfo, verify_args
|
||||
|
||||
import os
|
||||
|
||||
from .hat_file import HATFile, Function
|
||||
from .function_info import FunctionInfo
|
||||
|
||||
|
||||
class HATPackage:
|
||||
|
||||
|
@ -66,17 +61,16 @@ class AttributeDict(OrderedDict):
|
|||
return OrderedDict.__getitem__(key)
|
||||
|
||||
|
||||
def _make_cpu_func(shared_lib: ctypes.CDLL, function_name: str, arg_infos: List[Parameter]):
|
||||
arg_infos = [ArgInfo(d) for d in arg_infos]
|
||||
fn = shared_lib[function_name]
|
||||
def _make_cpu_func(shared_lib: ctypes.CDLL, func: Function):
|
||||
func_info = FunctionInfo(func)
|
||||
fn = shared_lib[func_info.name]
|
||||
|
||||
def f(*args):
|
||||
# verify that the (numpy) input args match the description in
|
||||
# the hat file
|
||||
verify_args(args, arg_infos, function_name)
|
||||
# verify that the args match the description in the hat file
|
||||
func_info.verify(args)
|
||||
|
||||
# prepare the args to the hat package
|
||||
hat_args = [arg.ctypes.data_as(arg_info.ctypes_pointer_type) for arg, arg_info in zip(args, arg_infos)]
|
||||
hat_args = func_info.as_cargs(args)
|
||||
|
||||
# call the function in the hat package
|
||||
fn(*hat_args)
|
||||
|
@ -97,7 +91,7 @@ def _load_pkg_binary_module(hat_pkg: HATPackage):
|
|||
shared_lib = None
|
||||
if os.path.isfile(hat_pkg.link_target_path):
|
||||
|
||||
supported_extensions = [".dll", ".so"]
|
||||
supported_extensions = [".dll", ".so", ".dylib"]
|
||||
_, extension = os.path.splitext(hat_pkg.link_target_path)
|
||||
|
||||
if extension and extension not in supported_extensions:
|
||||
|
@ -149,7 +143,7 @@ def hat_package_to_func_dict(hat_pkg: HATPackage) -> AttributeDict:
|
|||
launches = func_desc.launches
|
||||
if not launches and shared_lib:
|
||||
|
||||
func_dict[func_name] = _make_cpu_func(shared_lib, func_desc.name, func_desc.arguments)
|
||||
func_dict[func_name] = _make_cpu_func(shared_lib, func_desc)
|
||||
else:
|
||||
device_func = hat_pkg.hat_file.device_function_map.get(launches)
|
||||
|
||||
|
|
|
@ -3,8 +3,9 @@ import pathlib
|
|||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
from .arg_info import ArgInfo, verify_args
|
||||
from .arg_info import ArgInfo
|
||||
from .callable_func import CallableFunc
|
||||
from .function_info import FunctionInfo
|
||||
from .hat_file import Function
|
||||
from .gpu_headers import ROCM_HEADER_MAP
|
||||
from .pyhip.hip import *
|
||||
|
@ -37,7 +38,9 @@ def get_func_from_rocm_program(rocm_program, func_name):
|
|||
kernel = hipModuleGetFunction(rocm_module, func_name)
|
||||
return kernel
|
||||
|
||||
cached_mem=[]
|
||||
|
||||
cached_mem = []
|
||||
|
||||
|
||||
def allocate_rocm_mem(benchmark: bool, arg_infos: List[ArgInfo], gpu_id: int):
|
||||
device_mem = []
|
||||
|
@ -90,10 +93,8 @@ class RocmCallableFunc(CallableFunc):
|
|||
def __init__(self, func: Function, rocm_src_path: str) -> None:
|
||||
super().__init__()
|
||||
self.hat_func = func
|
||||
self.func_name = func.name
|
||||
self.func_info = FunctionInfo(func)
|
||||
self.kernel = None
|
||||
hat_arg_descriptions = func.arguments
|
||||
self.arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
|
||||
self.launch_params = func.launch_parameters
|
||||
self.device_mem = None
|
||||
self.ptrs = None
|
||||
|
@ -115,22 +116,23 @@ class RocmCallableFunc(CallableFunc):
|
|||
|
||||
rocm_program = _HSACO_CACHE.get(self.rocm_src_path)
|
||||
if not rocm_program:
|
||||
_HSACO_CACHE[self.rocm_src_path] = rocm_program = compile_rocm_program(self.rocm_src_path, self.func_name)
|
||||
_HSACO_CACHE[self.rocm_src_path
|
||||
] = rocm_program = compile_rocm_program(self.rocm_src_path, self.func_info.name)
|
||||
|
||||
self.kernel = get_func_from_rocm_program(rocm_program, self.func_name)
|
||||
self.kernel = get_func_from_rocm_program(rocm_program, self.func_info.name)
|
||||
|
||||
def cleanup_runtime(self, benchmark: bool):
|
||||
pass
|
||||
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int=0):
|
||||
verify_args(args, self.arg_infos, self.func_name)
|
||||
self.device_mem = allocate_rocm_mem(benchmark, self.arg_infos, gpu_id)
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int = 0):
|
||||
self.func_info.verify(args)
|
||||
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, gpu_id)
|
||||
|
||||
if not benchmark:
|
||||
transfer_mem_host_to_rocm(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
|
||||
transfer_mem_host_to_rocm(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
|
||||
|
||||
class DataStruct(ctypes.Structure):
|
||||
_fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(len(self.arg_infos))]
|
||||
_fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(len(self.func_info.arguments))]
|
||||
|
||||
self.data = DataStruct(*self.device_mem)
|
||||
|
||||
|
@ -175,7 +177,7 @@ class RocmCallableFunc(CallableFunc):
|
|||
def cleanup_main(self, benchmark: bool, args=[]):
|
||||
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
|
||||
if not benchmark and self.device_mem:
|
||||
transfer_mem_rocm_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
|
||||
transfer_mem_rocm_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
|
||||
free_rocm_mem(self.device_mem)
|
||||
hipDeviceSynchronize()
|
||||
|
||||
|
|
|
@ -7,33 +7,32 @@ from . import hat
|
|||
|
||||
def verify_hat_package(hat_path):
|
||||
_, funcs = hat.load(hat_path)
|
||||
inputs = hat.generate_input_sets_for_hat_file(hat_path)
|
||||
args = hat.generate_arg_sets_for_hat_file(hat_path)
|
||||
for name, fn in funcs.items():
|
||||
print(f"\n{'*' * 10}\n")
|
||||
|
||||
print(f"[*] Verifying function {name} --")
|
||||
func_inputs = inputs[name]
|
||||
|
||||
print("[*] Inputs before function call:")
|
||||
for i, func_input in enumerate(func_inputs):
|
||||
print(f"[*]\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")
|
||||
print(f"[*] Verifying function {name} --")
|
||||
func_args = args[name]
|
||||
|
||||
print("[*] Args before function call:")
|
||||
for i, func_arg in enumerate(func_args):
|
||||
print(f"[*]\tArg {i}: {func_arg}")
|
||||
|
||||
try:
|
||||
|
||||
time = fn(*inputs[name])
|
||||
time = fn(*args[name])
|
||||
|
||||
except RuntimeError as e:
|
||||
print(f"[!] Error while running {name}: {e}")
|
||||
continue
|
||||
|
||||
print("Inputs after function call:")
|
||||
for i, func_input in enumerate(func_inputs):
|
||||
print(f"[*]\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")
|
||||
|
||||
print("Args after function call:")
|
||||
for i, func_arg in enumerate(func_args):
|
||||
print(f"[*]\tArg {i}: {func_arg}")
|
||||
|
||||
if time:
|
||||
print(f"[*] Function execution time: {time:4f}ms")
|
||||
|
||||
del inputs[name]
|
||||
|
||||
del args[name]
|
||||
|
||||
else:
|
||||
print(f"\n{'*' * 10}\n")
|
||||
|
|
|
@ -56,8 +56,8 @@ version = "0.0.0.3"
|
|||
optional = true
|
||||
|
||||
# A string describing the number of elements in the buffer for a runtime_array logical type.
|
||||
# Typically expected to reference other parameters in the function.
|
||||
# e.g. "N", "lda * K"
|
||||
# Typically expected to reference other parameters in the function in their shape order.
|
||||
# e.g. "N", "lda * K" for shape (lda, K)
|
||||
[types.paramType.size]
|
||||
type = "string"
|
||||
optional = true
|
||||
|
|
|
@ -28,6 +28,7 @@ void MatMul(const float* A, const float* B, float* C);
|
|||
#ifdef TOML
|
||||
'''
|
||||
|
||||
|
||||
class CreateSimpleHatFile_test(unittest.TestCase):
|
||||
|
||||
def test_create_simple_hat_file(self):
|
||||
|
@ -94,10 +95,13 @@ class CreateSimpleHatFile_test(unittest.TestCase):
|
|||
return_info=return_arg
|
||||
)
|
||||
auxiliary_key_name = "test_auxiliary_key"
|
||||
hat_function.auxiliary[auxiliary_key_name] = { "name" : "matmul" }
|
||||
hat_function.auxiliary[auxiliary_key_name] = { "name": "matmul" }
|
||||
|
||||
workdir = "./test_output"
|
||||
os.makedirs(workdir, exist_ok=True)
|
||||
|
||||
link_target_path = "./fake_link_target.lib"
|
||||
hat_file_path = "./test_simple_hat_path.hat"
|
||||
hat_file_path = f"{workdir}/test_simple_hat_path.hat"
|
||||
new_hat_file = hat.HATFile(
|
||||
name="simple_hat_file",
|
||||
functions=[hat_function],
|
||||
|
@ -128,5 +132,6 @@ class CreateSimpleHatFile_test(unittest.TestCase):
|
|||
# Check that the code strings are equal. Serialization/deserialization doesn't always preserve leading/trailing whitespace so use strip() to normalize
|
||||
self.assertEqual(parsed_hat_file.declaration.code.strip(), SAMPLE_MATMUL_DECL_CODE.strip())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from hatlib import (
|
||||
|
|
|
@ -0,0 +1,405 @@
|
|||
#!/usr/bin/env python3
|
||||
import re
|
||||
import hatlib as hat
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
|
||||
class VerifyHat_test(unittest.TestCase):
    """End-to-end tests for hat.verify_hat_package.

    Each test compiles a small C function into a shared library, writes a
    .hat file describing that function next to the library, and then passes
    the .hat file to hat.verify_hat_package.
    """

    def build(self, impl_code: str, workdir: str, name: str, func_name: str) -> str:
        """Compile impl_code into a shared library inside workdir.

        Dispatches to the Windows or the Linux build depending on what
        hat.get_platform() reports. Returns the path of the built library.
        """
        hat.ensure_compiler_in_path()
        if hat.get_platform() == hat.OperatingSystem.Windows:
            return self.windows_build(impl_code, workdir, name, func_name)
        return self.linux_build(impl_code, workdir, name)

    def _prepare_workdir(self, impl_code: str, workdir: str, name: str, lib_ext: str):
        """Recreate workdir and write impl_code to <workdir>/<name>.c.

        Returns (source_path, lib_path), where lib_path is the library path
        the caller is expected to produce (<workdir>/<name>.<lib_ext>).
        """
        source_path = f"{workdir}/{name}.c"
        lib_path = f"{workdir}/{name}.{lib_ext}"

        # Best-effort wipe: errors are ignored so a leftover file that cannot
        # be deleted does not abort the test before the build is attempted.
        shutil.rmtree(workdir, ignore_errors=True)
        os.makedirs(workdir, exist_ok=True)
        with open(source_path, "w", encoding="utf-8") as f:
            print(impl_code, file=f)
        return source_path, lib_path

    def _remove_stale_library(self, lib_path: str) -> None:
        """Delete a library left over from an earlier run, if one survived the wipe."""
        if os.path.exists(lib_path):
            os.remove(lib_path)

    def windows_build(self, impl_code: str, workdir: str, name: str, func_name: str) -> str:
        """Build <workdir>/<name>.dll with cl.exe, exporting func_name."""
        source_path, lib_path = self._prepare_workdir(impl_code, workdir, name, "dll")

        # A minimal DllMain is compiled alongside the test source.
        dllmain_path = f"{workdir}/dllmain.cpp"
        with open(dllmain_path, "w", encoding="utf-8") as f:
            print("#include <windows.h>\n", file=f)
            print("BOOL APIENTRY DllMain(HMODULE, DWORD, LPVOID) { return TRUE; }\n", file=f)

        self._remove_stale_library(lib_path)

        hat.run_command(
            f'cl.exe "{source_path}" "{dllmain_path}" /nologo /link /DLL /EXPORT:{func_name} /OUT:"{lib_path}"',
            quiet=True
        )
        self.assertTrue(os.path.isfile(lib_path))
        return lib_path

    def linux_build(self, impl_code: str, workdir: str, name: str) -> str:
        """Build <workdir>/<name>.so with gcc."""
        source_path, lib_path = self._prepare_workdir(impl_code, workdir, name, "so")
        self._remove_stale_library(lib_path)

        hat.run_command(f'gcc -shared -fPIC -o "{lib_path}" "{source_path}"', quiet=True)
        self.assertTrue(os.path.isfile(lib_path))
        return lib_path

    def create_hat_file(self, hat_input: hat.HATFile):
        """Serialize hat_input to its own path, replacing any existing file."""
        hat_path = hat_input.path
        if os.path.exists(hat_path):
            os.remove(hat_path)
        hat_input.Serialize(hat_path)
        self.assertTrue(os.path.exists(hat_path))

    def _package_and_verify(self, name: str, func_name: str, arguments, lib_path: str, decl_code: str, hat_path: str):
        """Describe one void StdCall function in a .hat file and verify the package.

        The link target is recorded as the bare file name of lib_path; every
        test writes the .hat file into the same directory as the library.
        """
        hat_function = hat.Function(
            arguments=arguments,
            calling_convention=hat.CallingConventionType.StdCall,
            name=func_name,
            return_info=hat.Parameter.void()
        )
        hat_input = hat.HATFile(
            name=name,
            functions=[hat_function],
            dependencies=hat.Dependencies(link_target=os.path.basename(lib_path)),
            declaration=hat.Declaration(code=decl_code),
            path=hat_path
        )
        self.create_hat_file(hat_input)
        hat.verify_hat_package(hat_path)

    def test_basic(self):
        """Verify a package whose function takes fixed-shape (affine) float arrays."""
        # Generate a HAT package with a C implementation and call verify_hat
        impl_code = '''#include <math.h>
#include <stdint.h>

#ifdef _MSC_VER
#define DLL_EXPORT __declspec( dllexport )
#else
#define DLL_EXPORT
#endif

DLL_EXPORT void Softmax(const float input[2][2], float output[2][2])
{
    /* Softmax 13 (TF, pytorch style)
       axis = 0
    */
    for (uint32_t i1 = 0; i1 < 2; ++i1) {
        float max = -INFINITY;
        for (uint32_t i0 = 0; i0 < 2; ++i0) {
            max = max > input[i0][i1] ? max : input[i0][i1];
        }
        float sum = 0.0;
        for (uint32_t i0 = 0; i0 < 2; ++i0) {
            sum += expf(input[i0][i1] - max);
        }
        for (uint32_t i0 = 0; i0 < 2; ++i0) {
            output[i0][i1] = expf(input[i0][i1] - max) / sum;
        }
    }
}
'''
        decl_code = '''#endif // TOML
#pragma once

#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)

void Softmax(const float input[2][2], float output[2][2]);

#ifndef __Softmax_DEFINED__
#define __Softmax_DEFINED__
void (*Softmax)(float*, float*) = Softmax;
#endif

#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)

#ifdef TOML
'''
        workdir = "./test_output/verify_hat_basic"
        name = "softmax"
        func_name = "Softmax"
        lib_path = self.build(impl_code, workdir, name, func_name)
        hat_path = f"{workdir}/{name}.hat"

        # create the hat file
        shape = [2, 2]
        strides = [shape[1], 1]    # first major
        param_input = hat.Parameter(
            name="input",
            logical_type=hat.ParameterType.AffineArray,
            declared_type="float*",
            element_type="float",
            usage=hat.UsageType.Input,
            shape=shape,
            affine_map=strides
        )
        param_output = hat.Parameter(
            name="output",
            logical_type=hat.ParameterType.AffineArray,
            declared_type="float*",
            element_type="float",
            usage=hat.UsageType.InputOutput,
            shape=shape,
            affine_map=strides
        )
        self._package_and_verify(name, func_name, [param_input, param_output], lib_path, decl_code, hat_path)

    def test_runtime_array(self):
        """Verify a package whose function allocates its output and reports its length."""
        # Generate a HAT package using C and call verify_hat
        # The C fixture prints the uint32_t element count with %u and fills
        # the output buffer exactly once.
        impl_code = '''#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>

#ifndef ALLOC
#define ALLOC(size) ( malloc(size) )
#endif
#ifndef DEALLOC
#define DEALLOC(X) ( free(X) )
#endif

#ifdef _MSC_VER
#define DLL_EXPORT __declspec( dllexport )
#else
#define DLL_EXPORT
#endif

DLL_EXPORT void Range(const int32_t start[1], const int32_t limit[1], const int32_t delta[1], int32_t** output, uint32_t* output_dim)
{
    /* Range */
    /* Ensure we don't crash with random inputs */
    int32_t start0 = start[0];
    int32_t delta0 = delta[0] == 0 ? 1 : delta[0];
    int32_t limit0 = (limit[0] <= start0) ? (start0 + delta0 * 25) : limit[0];

    *output_dim = (limit0 - start0) / delta0;
    *output = (int32_t*)ALLOC(*output_dim * sizeof(int32_t));
    printf("Allocated %u output elements\\n", *output_dim);
    printf("start=%d, limit=%d, delta=%d\\n", start0, limit0, delta0);

    for (uint32_t i = 0; i < *output_dim; ++i) {
        (*output)[i] = start0 + (i * delta0);
    }
}
'''
        decl_code = '''#endif // TOML
#pragma once

#include <stdint.h>

#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)

void Range(const int32_t start[1], const int32_t limit[1], const int32_t delta[1], int32_t** output, uint32_t* output_dim);

#ifndef __Range_DEFINED__
#define __Range_DEFINED__
void (*Range)(int32_t*, int32_t*, int32_t*, int32_t**, uint32_t*) = Range;
#endif

#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)

#ifdef TOML
'''
        workdir = "test_output/verify_hat_runtime_array"
        name = "range"
        func_name = "Range"
        lib_path = self.build(impl_code, workdir, name, func_name)
        hat_path = f"{workdir}/{name}.hat"

        # create the hat file
        # start, limit and delta share one description: int32_t* inputs with an empty shape
        scalar_inputs = [
            hat.Parameter(
                name=scalar_name,
                logical_type=hat.ParameterType.AffineArray,
                declared_type="int32_t*",
                element_type="int32_t",
                usage=hat.UsageType.Input,
                shape=[],
            ) for scalar_name in ("start", "limit", "delta")
        ]
        param_output = hat.Parameter(
            name="output",
            logical_type=hat.ParameterType.RuntimeArray,
            declared_type="int32_t**",
            element_type="int32_t",
            usage=hat.UsageType.Output,
            size="output_dim"
        )
        param_output_dim = hat.Parameter(
            name="output_dim",
            logical_type=hat.ParameterType.Element,
            declared_type="uint32_t*",
            element_type="uint32_t",
            usage=hat.UsageType.Output,
            shape=[]
        )
        self._package_and_verify(
            name, func_name, [*scalar_inputs, param_output, param_output_dim], lib_path, decl_code, hat_path
        )

    def test_input_runtime_arrays(self):
        """Verify runtime arrays passed as Input and as InputOutput arguments."""
        impl_code = '''#include <stdint.h>
#include <stdlib.h>

#ifndef ALLOC
#define ALLOC(size) ( malloc(size) )
#endif
#ifndef DEALLOC
#define DEALLOC(X) ( free(X) )
#endif

#ifdef _MSC_VER
#define DLL_EXPORT __declspec( dllexport )
#else
#define DLL_EXPORT
#endif

DLL_EXPORT void /* Unsqueeze_18 */ Unsqueeze(const float* data, const int64_t data_dim0, float** expanded, int64_t* dim0, int64_t* dim1)
{
    /* Unsqueeze */
    *dim0 = 1;
    *dim1 = data_dim0;
    *expanded = (float*)ALLOC((*dim0) * (*dim1) * sizeof(float));
    float* data_ = (float*)data;
    float* expanded_ = (float*)(*expanded);
    for (int64_t i = 0; i < data_dim0; ++i)
        expanded_[i] = data_[i];
}
'''
        decl_code = '''#endif // TOML
#pragma once

#include <stdint.h>

#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)

void Unsqueeze(const float* data, const int64_t data_dim0, float** expanded, int64_t* dim0, int64_t* dim1);

#ifndef __Unsqueeze_DEFINED__
#define __Unsqueeze_DEFINED__
void (*Unsqueeze_)(float*, int64_t, float**, int64_t*, int64_t*) = Unsqueeze;
#endif

#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)

#ifdef TOML
'''
        for index, usage in enumerate([hat.UsageType.Input, hat.UsageType.InputOutput]):
            # subTest reports each usage variant separately, so a failure in
            # the first variant does not hide the result of the second.
            with self.subTest(usage=usage):
                workdir = "test_output/verify_hat_inout_runtime_arrays"
                name = f"unsqueeze_{index}"    # uniqify for Windows to avoid load conflict
                func_name = "Unsqueeze"
                lib_path = self.build(impl_code, workdir, name, func_name)
                hat_path = f"{workdir}/{name}.hat"

                # create the hat file
                param_data = hat.Parameter(
                    name="data",
                    logical_type=hat.ParameterType.RuntimeArray,
                    declared_type="float*",
                    element_type="float",
                    usage=usage,
                    size="data_dim"
                )
                param_data_dim = hat.Parameter(
                    name="data_dim",
                    logical_type=hat.ParameterType.Element,
                    declared_type="int64_t",
                    element_type="int64_t",
                    usage=hat.UsageType.Input,
                    shape=[]
                )
                param_expanded = hat.Parameter(
                    name="expanded",
                    logical_type=hat.ParameterType.RuntimeArray,
                    declared_type="float**",
                    element_type="float",
                    usage=hat.UsageType.Output,
                    size="dim0*dim1"
                )
                # dim0 and dim1 share one description: int64_t* output elements
                param_dim0, param_dim1 = (
                    hat.Parameter(
                        name=dim_name,
                        logical_type=hat.ParameterType.Element,
                        declared_type="int64_t*",
                        element_type="int64_t",
                        usage=hat.UsageType.Output,
                        shape=[]
                    ) for dim_name in ("dim0", "dim1")
                )
                self._package_and_verify(
                    name,
                    func_name,
                    [param_data, param_data_dim, param_expanded, param_dim0, param_dim1],
                    lib_path,
                    decl_code,
                    hat_path
                )
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче