Initial verify_hat support for runtime arrays and elements (#65)

* support arbirtary pointer levels for declared types

* handle logical_type=element

* arg values

* refactor (func name, arg_info) to FunctionInfo

* fix circular imports

* .

* support dylib

* update verify_hat_package

* verify -> verify_args

* fixes

* fix verify_args

* formatting

* add verify hat test

* refactor

* handle ndarrays as arguments

* cleanup

* runtime_array verify test

* print output

* basic test passing

* Update test_create_simple_hat_file.py

* Update test_create_simple_hat_file.py

* comments

* comments

* TODOs

* nfc

* [test] moved creation to workdir

* rename

* Print output dimension references and clarify HAT schema (#66)

* infer shapes from size, add shape order requirement

* merged

* pretty print using cross references

Co-authored-by: Lisa Ong <onglisa@microsoft.com>

* revert formatting changes

* revert more formatting only changes

* simplify

* Add input and input/output runtime_array support (#67)

* wip

* scaffold

* scaffold and initial support for input elements and input/output runtime_arrays

* .

* fixups

* don't swallow exceptions

* support cargs for non pointer args

* cleanup

* refactor

* support integer-like types when checking constant shapes

* nfc

Co-authored-by: Lisa Ong <onglisa@microsoft.com>

* test coverage for  usage type Input and InputOutput

* [test] Support windows in verify_hat tests (#69)

* wip

* build for windows (#68)

Co-authored-by: Lisa Ong <onglisa@microsoft.com>

* windows tomlkit expects lists

* manual CI trigger

Co-authored-by: Lisa Ong <onglisa@microsoft.com>

* fix windows test

* fix logic and add comment

* .

* verify_args -> verify

* args -> arguments

* PR feedback

Co-authored-by: Lisa Ong <onglisa@microsoft.com>
This commit is contained in:
Lisa Ong 2022-08-23 07:09:25 +08:00 коммит произвёл GitHub
Родитель 65743f6e83
Коммит d99d61d970
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
16 изменённых файлов: 788 добавлений и 157 удалений

2
.github/workflows/ci.yml поставляемый
Просмотреть файл

@ -5,6 +5,8 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
# allow manual triggers
workflow_dispatch:
jobs:
build:

1
.gitignore поставляемый
Просмотреть файл

@ -362,3 +362,4 @@ dist/
# test
test_acccgen
test_output

Просмотреть файл

@ -5,3 +5,4 @@ from .hat_to_dynamic import *
from .hat_to_lib import *
from .benchmark_hat_package import run_benchmark
from .platform_utilities import *
from .verify_hat_package import *

Просмотреть файл

@ -1,42 +1,44 @@
import ctypes
from functools import reduce
import numpy as np
import sys
import re
from dataclasses import dataclass
from typing import Any, List, Tuple
from typing import Any, Tuple, Union
from . import hat_file
# hat_declared_type : [ ctype, dtype_str ]
# element_type : [ ctype, dtype_str ]
ARG_TYPES = {
"int8_t*" : [ ctypes.c_int8, "int8" ],
"int16_t*" : [ ctypes.c_int16, "int16" ],
"int32_t*" : [ ctypes.c_int32, "int32" ],
"int64_t*" : [ ctypes.c_int64, "int64" ],
"uint8_t*" : [ ctypes.c_uint8, "uint8" ],
"uint16_t*" : [ ctypes.c_uint16, "uint16" ],
"uint32_t*" : [ ctypes.c_uint32, "uint32" ],
"uint64_t*" : [ ctypes.c_uint64, "uint64" ],
"float16_t*" : [ ctypes.c_uint16, "float16" ], # same bitwidth as uint16
"bfloat16_t*" : [ ctypes.c_uint16, "bfloat16" ],
"float*" : [ ctypes.c_float, "float32" ],
"double*" : [ ctypes.c_double, "float64" ],
"int8_t": [ctypes.c_int8, "int8"],
"int16_t": [ctypes.c_int16, "int16"],
"int32_t": [ctypes.c_int32, "int32"],
"int64_t": [ctypes.c_int64, "int64"],
"uint8_t": [ctypes.c_uint8, "uint8"],
"uint16_t": [ctypes.c_uint16, "uint16"],
"uint32_t": [ctypes.c_uint32, "uint32"],
"uint64_t": [ctypes.c_uint64, "uint64"],
"float16_t": [ctypes.c_uint16, "float16"], # same bitwidth as uint16
"bfloat16_t": [ctypes.c_uint16, "bfloat16"],
"float": [ctypes.c_float, "float32"],
"double": [ctypes.c_double, "float64"],
}
CTYPE_ENTRY = 0
DTYPE_ENTRY = 1
@dataclass
class ArgInfo:
"""Extracts necessary information from the description of a function argument in a hat file"""
name: str
hat_declared_type: str
numpy_shape: Tuple[int, ...]
shape: Tuple[Union[int, str], ...] # int for affine_arrays, str symbols for runtime_arrays
numpy_strides: Tuple[int, ...]
numpy_dtype: type
element_num_bytes: int
element_strides: Tuple[int, ...]
total_element_count: int
total_byte_size: int
ctypes_pointer_type: Any
total_element_count: Union[int, str]
total_byte_size: Union[int, str]
ctypes_type: Any
pointer_level: int
usage: hat_file.UsageType = None
def _get_type(self, type_str):
@ -46,66 +48,70 @@ class ArgInfo:
return np.dtype(type_str)
def _get_pointer_level(self, declared_type: str):
pos = declared_type.find("*")
if pos == -1:
return 0
return declared_type[pos:].count("*")
def __init__(self, param_description: hat_file.Parameter):
self.name = param_description.name
self.hat_declared_type = param_description.declared_type
self.numpy_shape = tuple(param_description.shape)
self.shape = tuple(param_description.shape)
self.usage = param_description.usage
self.pointer_level = self._get_pointer_level(self.hat_declared_type)
element_type = self.hat_declared_type[:(-1 * self.pointer_level)] \
if self.pointer_level else self.hat_declared_type
if not self.hat_declared_type in ARG_TYPES:
raise NotImplementedError(f"Unsupported declared_type {self.hat_declared_type} in hat file")
if not element_type in ARG_TYPES:
raise NotImplementedError(f"Unsupported element_type {element_type} in hat file")
self.ctypes_pointer_type = ctypes.POINTER(ARG_TYPES[self.hat_declared_type][CTYPE_ENTRY])
dtype_entry = ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY]
ctypes_type = ARG_TYPES[element_type][CTYPE_ENTRY]
dtype_entry = ARG_TYPES[element_type][DTYPE_ENTRY]
self.numpy_dtype = self._get_type(dtype_entry)
self.element_num_bytes = 2 if dtype_entry == "bfloat16" else self.numpy_dtype.itemsize
if self.numpy_shape:
if param_description.logical_type == hat_file.ParameterType.AffineArray:
self.ctypes_type = ctypes.POINTER(ctypes_type)
if self.shape:
self.element_strides = param_description.affine_map
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
major_dim = self.element_strides.index(max(self.element_strides))
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
self.total_element_count = self.shape[major_dim] * self.element_strides[major_dim]
else:
self.element_strides = self.numpy_strides = self.numpy_shape = [1]
self.element_strides = self.numpy_strides = self.shape = [1]
self.total_element_count = 1
self.total_byte_size = self.element_num_bytes * self.total_element_count
elif param_description.logical_type == hat_file.ParameterType.RuntimeArray:
self.ctypes_type = ctypes.POINTER(ctypes_type)
self.total_byte_size = f"{self.element_num_bytes} * {param_description.size}"
self.total_element_count = param_description.size
# assume the sizes are in shape order
self.shape = re.split(r"\s?\*\s?", param_description.size)
# TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name
def verify_args(args: List, arg_infos: List[ArgInfo], function_name: str):
""" Verifies that a list of arguments matches a list of argument descriptions in a HAT file
"""
# check number of args
if len(args) != len(arg_infos):
sys.exit(f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")
elif param_description.logical_type == hat_file.ParameterType.Element:
if param_description.usage == hat_file.UsageType.Input:
self.ctypes_type = ctypes_type
else:
self.ctypes_type = ctypes.POINTER(ctypes_type)
self.element_strides = self.numpy_strides = self.shape = [1]
self.total_element_count = 1
self.total_byte_size = self.element_num_bytes * self.total_element_count
# for each arg
for i in range(len(args)):
arg = args[i]
arg_info = arg_infos[i]
else:
raise ValueError(f"Unknown logical type {param_description.logical_type} in hat file")
# confirm that the arg is a numpy ndarray
if not isinstance(arg, np.ndarray):
sys.exit(
"Error calling {function_name}(...): expected argument {i} to be <class 'numpy.ndarray'> but received {type(arg)}"
)
@property
def is_constant_shaped(self):
# confirm that the arg dtype matches the dexcription in the hat package
if arg_info.numpy_dtype != arg.dtype:
sys.exit(
f"Error calling {function_name}(...): expected argument {i} to have dtype={arg_info.numpy_dtype} but received dtype={arg.dtype}"
)
def integer_like(s: Any):
# handle types such as tomlkit.items.Integer
try:
return int(s) == s
except:
return False
# confirm that the arg shape is correct
if arg_info.numpy_shape != arg.shape:
sys.exit(
f"Error calling {function_name}(...): expected argument {i} to have shape={arg_info.numpy_shape} but received shape={arg.shape}"
)
# confirm that the arg strides are correct
if arg_info.numpy_strides != arg.strides:
sys.exit(
f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}"
)
return all(integer_like(s) for s in self.shape)

170
hatlib/arg_value.py Normal file
Просмотреть файл

@ -0,0 +1,170 @@
from typing import Any, List
from ctypes import byref
import numpy as np
import random
from .arg_info import ArgInfo
from . import hat_file
class ArgValue:
"""An argument containing a scalar, ndarray, or pointer value.
Used for calling HAT functions from ctypes"""
def __init__(self, arg_info: ArgInfo, value: Any = None):
# TODO: set the free and alloc function symbols here?
self.arg_info = arg_info
self.pointer_level = arg_info.pointer_level
self.ctypes_type = arg_info.ctypes_type
if self.pointer_level > 2: # punt until we really need this
raise NotImplementedError("Pointer levels > 2 are not supported")
self.value = value
if self.value is None:
if not self.pointer_level:
raise ValueError("A value is required for non-pointers")
else:
# no value provided, allocate the pointer
self.allocate()
self.dim_values = None
def allocate(self):
if not self.pointer_level:
return # nothing to do
if self.value:
return # value already assigned, nothing to do
if self.pointer_level == 1:
# allocate an ndarray with random input values
self.value = np.lib.stride_tricks.as_strided(
np.random.rand(self.arg_info.total_element_count).astype(self.arg_info.numpy_dtype),
shape=self.arg_info.shape,
strides=self.arg_info.numpy_strides
)
elif self.pointer_level == 2:
# allocate a pointer. HAT function will perform the actual allocation.
self.value = self.ctypes_type()
def as_carg(self):
"Return the C interface for this argument"
if self.pointer_level:
if isinstance(self.value, np.ndarray):
return self.value.ctypes.data_as(self.ctypes_type)
else:
return byref(self.value)
else:
return self.ctypes_type(self.value)
def verify(self, desc: ArgInfo):
"Verifies that this argument matches an argument description"
if desc.pointer_level == 1:
if not isinstance(self.value, np.ndarray):
raise ValueError(f"expected argument to be <class 'numpy.ndarray'> but received {type(self.value)}")
if desc.numpy_dtype != self.value.dtype:
raise ValueError(
f"expected argument to have dtype={desc.numpy_dtype} but received dtype={self.value.dtype}"
)
if desc.is_constant_shaped:
# confirm that the arg shape is correct (numpy represents shapes as tuples)
if tuple(desc.shape) != self.value.shape:
raise ValueError(
f"expected argument to have shape={desc.shape} but received shape={self.value.shape}"
)
# confirm that the arg strides are correct (numpy represents strides as tuples)
if tuple(desc.numpy_strides) != self.value.strides:
raise ValueError(
f"expected argument to have strides={desc.numpy_strides} but received strides={self.value.strides}"
)
else:
pass # TODO - support other pointer levels
def __repr__(self):
if self.pointer_level:
if isinstance(self.value, np.ndarray):
return ",".join(map(str, self.value.ravel()[:32]))
else:
try:
if self.dim_values:
# cross-reference the dimension output values to pretty print the output
shape = [d.value[0] for d in self.dim_values] # stored as single-element ndarrays
s = repr(np.ctypeslib.as_array(self.value, shape))
else:
s = repr(self.value.contents)
except Exception as e:
if e.args[0].startswith("NULL pointer"):
s = f"{repr(self.value)} nullptr"
else:
raise (e)
return s
else:
return repr(self.value)
def __del__(self):
if self.pointer_level == 2:
pass # TODO - free the pointer, presumably calling a symbol passed into this ArgValue
def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo]) -> List[int]:
# Returns the dimension argument indices in shape order for an array argument
indices = []
for sym_name in array_arg.shape:
for i, info in enumerate(all_arguments):
if info.name == sym_name: # limitation: only string shapes are supported
indices.append(i)
break
else:
# not found
raise RuntimeError(f"{sym_name} is not an argument to the function") # likely an invalid HAT file
return indices
def generate_arg_values(arguments: List[ArgInfo]) -> List[ArgValue]:
"""Generate argument values from argument descriptions
Input and input/output affine_arrays: initialized with random inputs
Input and input/output runtime_arrays: initialized with arbitrary dimensions and random inputs
Output elements and runtime_arrays: pointers are allocated
"""
def generate_dim_value():
return random.choice([128, 256, 1234]) # example dimension values
dim_names_to_values = {}
values = []
for arg in arguments:
if arg.usage != hat_file.UsageType.Output and not arg.is_constant_shaped:
# input runtime arrays
dim_args = [arguments[i] for i in get_dimension_arg_indices(arg, arguments)]
# assign generated shape values to the corresponding dimension arguments
shape = []
for d in dim_args:
if d.name not in dim_names_to_values:
shape.append(generate_dim_value())
dim_names_to_values[d.name] = ArgValue(d, shape[-1])
else:
shape.append(dim_names_to_values[d.name].value)
# materialize an array input using the generated shape
runtime_array_inputs = np.random.random(tuple(shape)).astype(arg.numpy_dtype)
values.append(ArgValue(arg, runtime_array_inputs))
elif arg.name in dim_names_to_values:
# input element that is a dimension value (populated when its input runtime array is created)
values.append(dim_names_to_values[arg.name])
else:
# everything else is known size or a pointer
values.append(ArgValue(arg))
# collect the dimension ArgValues for each output runtime_array ArgValue
for value in values:
if value.arg_info.usage == hat_file.UsageType.Output and not value.arg_info.is_constant_shaped:
dim_values = [values[i] for i in get_dimension_arg_indices(value.arg_info, arguments)]
assert dim_values, f"Runtime array {value.arg_info.name} has no dimensions"
value.dim_values = dim_values
return values

Просмотреть файл

@ -9,7 +9,7 @@ import traceback
from .callable_func import CallableFunc
from .hat_file import HATFile
from .hat import load, generate_input_sets_for_func
from .hat import load, generate_arg_sets_for_func
class Benchmark:
@ -84,13 +84,13 @@ class Benchmark:
if not isinstance(benchmark_func, CallableFunc):
# generate sufficient input sets to overflow the L3 cache, since we don't know the size of the model
# we'll make a guess based on the minimum input set size
input_sets = generate_input_sets_for_func(func,
input_sets = generate_arg_sets_for_func(func,
input_sets_minimum_size_MB,
num_additional=10)
set_size = 0
for i in input_sets[0]:
set_size += i.size * i.dtype.itemsize
set_size += i.value.size * i.value.dtype.itemsize
if verbose:
print(f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes")
@ -129,11 +129,11 @@ class Benchmark:
else:
if verbose:
print(f"[Benchmarking] Benchmarking device function on gpu {gpu_id}. {batch_size} batches of warming up for {warmup_iterations} and then measuring with {min_timing_iterations} iterations.")
input_sets = generate_input_sets_for_func(func)
input_sets = generate_arg_sets_for_func(func)
set_size = 0
for i in input_sets:
set_size += i.size * i.dtype.itemsize
set_size += i.value.size * i.value.dtype.itemsize
if verbose:
print(f"[Benchmarking] Using input of {set_size} bytes")

Просмотреть файл

@ -4,8 +4,9 @@ import sys
import numpy as np
from typing import List
from cuda import cuda, nvrtc
from .arg_info import ArgInfo, verify_args
from .arg_info import ArgInfo
from .callable_func import CallableFunc
from .function_info import FunctionInfo
from .hat_file import Function
@ -38,6 +39,7 @@ def _find_cuda_incl_path() -> pathlib.Path:
return cuda_path
def _get_compute_capability(gpu_id) -> int:
err, major = cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, gpu_id)
ASSERT_DRV(err)
@ -47,6 +49,7 @@ def _get_compute_capability(gpu_id) -> int:
return (major * 10) + minor
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name, gpu_id):
src = cuda_src_path.read_text()
@ -180,10 +183,8 @@ class CudaCallableFunc(CallableFunc):
def __init__(self, func: Function, cuda_src_path: str) -> None:
super().__init__()
self.hat_func = func
self.func_name = func.name
self.func_info = FunctionInfo(func)
self.kernel = None
hat_arg_descriptions = func.arguments
self.arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
self.launch_params = func.launch_parameters
self.device_mem = None
self.ptrs = None
@ -198,19 +199,19 @@ class CudaCallableFunc(CallableFunc):
ptx = _PTX_CACHE.get(self.cuda_src_path)
if not ptx:
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_name, gpu_id)
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_info.name, gpu_id)
self.kernel = get_func_from_ptx(ptx, self.func_name)
self.kernel = get_func_from_ptx(ptx, self.func_info.name)
def cleanup_runtime(self, benchmark: bool):
cuda.cuCtxDestroy(self.context)
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int=0):
verify_args(args, self.arg_infos, self.func_name)
self.device_mem = allocate_cuda_mem(self.arg_infos)
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int = 0):
self.func_info.verify(args)
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
if not benchmark:
transfer_mem_host_to_cuda(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
transfer_mem_host_to_cuda(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
self.ptrs = device_args_to_ptr_list(self.device_mem)
@ -270,7 +271,7 @@ class CudaCallableFunc(CallableFunc):
def cleanup_main(self, benchmark: bool, args=[]):
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
if not benchmark and self.device_mem:
transfer_mem_cuda_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
transfer_mem_cuda_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
if self.device_mem:
free_cuda_mem(self.device_mem)
err, = cuda.cuCtxSynchronize()

46
hatlib/function_info.py Normal file
Просмотреть файл

@ -0,0 +1,46 @@
import numpy as np
import sys
from dataclasses import dataclass, field
from typing import Any, List
from .arg_info import ArgInfo
from .arg_value import ArgValue
from . import hat_file
@dataclass
class FunctionInfo:
"Information about a HAT function"
desc: hat_file.Function
arguments: List[ArgInfo] = field(default_factory=list)
name: str = ""
def __post_init__(self):
self.name = self.desc.name
self.arguments = list(map(ArgInfo, self.desc.arguments))
def verify(self, args: List[Any]):
"Verifies that a list of argument values matches the function description"
if len(args) != len(self.arguments):
sys.exit(
f"Error calling {self.name}(...): expected {len(self.arguments)} arguments but received {len(args)}"
)
for i, (info, value) in enumerate(zip(self.arguments, args)):
try:
if isinstance(value, np.ndarray):
value = ArgValue(info, value)
value.verify(info)
except ValueError as v:
sys.exit(f"Error calling {self.name}(...): argument {i} failed verification: {v}")
def as_cargs(self, args: List[Any]):
"Converts arguments to their C interfaces"
arg_values = [
ArgValue(info, value) if isinstance(value, np.ndarray) else value
for info, value in zip(self.arguments, args)
]
return [value.as_carg() for value in arg_values]

Просмотреть файл

@ -24,36 +24,36 @@ For example:
# call a package function named 'my_func_698b5e5c'
package.my_func_698b5e5c(A, B, D, E)
"""
import numpy as np
from typing import Tuple, Union
from functools import reduce
from . import hat_file
from . import hat_package
from .arg_info import ArgInfo
from .arg_value import generate_arg_values
from .function_info import FunctionInfo
def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
parameters = list(map(ArgInfo, func.arguments))
shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters]
set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters))
def generate_arg_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
func_info = FunctionInfo(func)
parameters = func_info.arguments
# use constant-sized params to estimate the minimum set size
const_sized_parameters = list(filter(lambda p: p.is_constant_shaped, parameters))
shapes_to_sizes = [reduce(lambda x, y: x * y, p.shape) for p in const_sized_parameters]
set_size = reduce(
lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, const_sized_parameters)
)
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
input_sets = [[
np.lib.stride_tricks.as_strided(
np.random.rand(p.total_element_count).astype(p.numpy_dtype),
shape=p.numpy_shape,
strides=p.numpy_strides
) for p in parameters
] for _ in range(num_input_sets)]
arg_sets = [generate_arg_values(parameters) for _ in range(num_input_sets)]
return input_sets[0] if len(input_sets) == 1 else input_sets
return arg_sets[0] if len(arg_sets) == 1 else arg_sets
def generate_input_sets_for_hat_file(hat_path):
def generate_arg_sets_for_hat_file(hat_path):
t = hat_file.HATFile.Deserialize(hat_path)
return {func_name: generate_input_sets_for_func(func_desc)
return {func_name: generate_arg_sets_for_func(func_desc)
for func_name, func_desc in t.function_map.items()}

Просмотреть файл

@ -1,17 +1,12 @@
#!/usr/bin/env python3
# Utility to parse and validate a HAT package
import ctypes
from typing import Any, List, Union
from collections import OrderedDict
from functools import partial
from .hat_file import HATFile, Function, Parameter
from .arg_info import ArgInfo, verify_args
import os
from .hat_file import HATFile, Function
from .function_info import FunctionInfo
class HATPackage:
@ -66,17 +61,16 @@ class AttributeDict(OrderedDict):
return OrderedDict.__getitem__(key)
def _make_cpu_func(shared_lib: ctypes.CDLL, function_name: str, arg_infos: List[Parameter]):
arg_infos = [ArgInfo(d) for d in arg_infos]
fn = shared_lib[function_name]
def _make_cpu_func(shared_lib: ctypes.CDLL, func: Function):
func_info = FunctionInfo(func)
fn = shared_lib[func_info.name]
def f(*args):
# verify that the (numpy) input args match the description in
# the hat file
verify_args(args, arg_infos, function_name)
# verify that the args match the description in the hat file
func_info.verify(args)
# prepare the args to the hat package
hat_args = [arg.ctypes.data_as(arg_info.ctypes_pointer_type) for arg, arg_info in zip(args, arg_infos)]
hat_args = func_info.as_cargs(args)
# call the function in the hat package
fn(*hat_args)
@ -97,7 +91,7 @@ def _load_pkg_binary_module(hat_pkg: HATPackage):
shared_lib = None
if os.path.isfile(hat_pkg.link_target_path):
supported_extensions = [".dll", ".so"]
supported_extensions = [".dll", ".so", ".dylib"]
_, extension = os.path.splitext(hat_pkg.link_target_path)
if extension and extension not in supported_extensions:
@ -149,7 +143,7 @@ def hat_package_to_func_dict(hat_pkg: HATPackage) -> AttributeDict:
launches = func_desc.launches
if not launches and shared_lib:
func_dict[func_name] = _make_cpu_func(shared_lib, func_desc.name, func_desc.arguments)
func_dict[func_name] = _make_cpu_func(shared_lib, func_desc)
else:
device_func = hat_pkg.hat_file.device_function_map.get(launches)

Просмотреть файл

@ -3,8 +3,9 @@ import pathlib
import numpy as np
from typing import List
from .arg_info import ArgInfo, verify_args
from .arg_info import ArgInfo
from .callable_func import CallableFunc
from .function_info import FunctionInfo
from .hat_file import Function
from .gpu_headers import ROCM_HEADER_MAP
from .pyhip.hip import *
@ -37,7 +38,9 @@ def get_func_from_rocm_program(rocm_program, func_name):
kernel = hipModuleGetFunction(rocm_module, func_name)
return kernel
cached_mem=[]
cached_mem = []
def allocate_rocm_mem(benchmark: bool, arg_infos: List[ArgInfo], gpu_id: int):
device_mem = []
@ -90,10 +93,8 @@ class RocmCallableFunc(CallableFunc):
def __init__(self, func: Function, rocm_src_path: str) -> None:
super().__init__()
self.hat_func = func
self.func_name = func.name
self.func_info = FunctionInfo(func)
self.kernel = None
hat_arg_descriptions = func.arguments
self.arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
self.launch_params = func.launch_parameters
self.device_mem = None
self.ptrs = None
@ -115,22 +116,23 @@ class RocmCallableFunc(CallableFunc):
rocm_program = _HSACO_CACHE.get(self.rocm_src_path)
if not rocm_program:
_HSACO_CACHE[self.rocm_src_path] = rocm_program = compile_rocm_program(self.rocm_src_path, self.func_name)
_HSACO_CACHE[self.rocm_src_path
] = rocm_program = compile_rocm_program(self.rocm_src_path, self.func_info.name)
self.kernel = get_func_from_rocm_program(rocm_program, self.func_name)
self.kernel = get_func_from_rocm_program(rocm_program, self.func_info.name)
def cleanup_runtime(self, benchmark: bool):
pass
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int=0):
verify_args(args, self.arg_infos, self.func_name)
self.device_mem = allocate_rocm_mem(benchmark, self.arg_infos, gpu_id)
def init_main(self, benchmark: bool, warmup_iters=0, args=[], gpu_id: int = 0):
self.func_info.verify(args)
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, gpu_id)
if not benchmark:
transfer_mem_host_to_rocm(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
transfer_mem_host_to_rocm(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
class DataStruct(ctypes.Structure):
_fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(len(self.arg_infos))]
_fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(len(self.func_info.arguments))]
self.data = DataStruct(*self.device_mem)
@ -175,7 +177,7 @@ class RocmCallableFunc(CallableFunc):
def cleanup_main(self, benchmark: bool, args=[]):
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
if not benchmark and self.device_mem:
transfer_mem_rocm_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.arg_infos)
transfer_mem_rocm_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
free_rocm_mem(self.device_mem)
hipDeviceSynchronize()

Просмотреть файл

@ -7,33 +7,32 @@ from . import hat
def verify_hat_package(hat_path):
_, funcs = hat.load(hat_path)
inputs = hat.generate_input_sets_for_hat_file(hat_path)
args = hat.generate_arg_sets_for_hat_file(hat_path)
for name, fn in funcs.items():
print(f"\n{'*' * 10}\n")
print(f"[*] Verifying function {name} --")
func_inputs = inputs[name]
func_args = args[name]
print("[*] Inputs before function call:")
for i, func_input in enumerate(func_inputs):
print(f"[*]\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")
print("[*] Args before function call:")
for i, func_arg in enumerate(func_args):
print(f"[*]\tArg {i}: {func_arg}")
try:
time = fn(*inputs[name])
time = fn(*args[name])
except RuntimeError as e:
print(f"[!] Error while running {name}: {e}")
continue
print("Inputs after function call:")
for i, func_input in enumerate(func_inputs):
print(f"[*]\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")
print("Args after function call:")
for i, func_arg in enumerate(func_args):
print(f"[*]\tArg {i}: {func_arg}")
if time:
print(f"[*] Function execution time: {time:4f}ms")
del inputs[name]
del args[name]
else:
print(f"\n{'*' * 10}\n")

Просмотреть файл

@ -56,8 +56,8 @@ version = "0.0.0.3"
optional = true
# A string describing the number of elements in the buffer for a runtime_array logical type.
# Typically expected to reference other parameters in the function.
# e.g. "N", "lda * K"
# Typically expected to reference other parameters in the function in their shape order.
# e.g. "N", "lda * K" for shape (lda, K)
[types.paramType.size]
type = "string"
optional = true

Просмотреть файл

@ -28,6 +28,7 @@ void MatMul(const float* A, const float* B, float* C);
#ifdef TOML
'''
class CreateSimpleHatFile_test(unittest.TestCase):
def test_create_simple_hat_file(self):
@ -94,10 +95,13 @@ class CreateSimpleHatFile_test(unittest.TestCase):
return_info=return_arg
)
auxiliary_key_name = "test_auxiliary_key"
hat_function.auxiliary[auxiliary_key_name] = { "name" : "matmul" }
hat_function.auxiliary[auxiliary_key_name] = { "name": "matmul" }
workdir = "./test_output"
os.makedirs(workdir, exist_ok=True)
link_target_path = "./fake_link_target.lib"
hat_file_path = "./test_simple_hat_path.hat"
hat_file_path = f"{workdir}/test_simple_hat_path.hat"
new_hat_file = hat.HATFile(
name="simple_hat_file",
functions=[hat_function],
@ -128,5 +132,6 @@ class CreateSimpleHatFile_test(unittest.TestCase):
# Check that the code strings are equal. Serialization/deserialization doesn't always preserve leading/trailing whitespace so use strip() to normalize
self.assertEqual(parsed_hat_file.declaration.code.strip(), SAMPLE_MATMUL_DECL_CODE.strip())
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -1,7 +1,6 @@
#!/usr/bin/env python3
import os
import sys
import unittest
from pathlib import Path
from hatlib import (

405
test/test_verify_hat.py Normal file
Просмотреть файл

@ -0,0 +1,405 @@
#!/usr/bin/env python3
import re
import hatlib as hat
import os
import shutil
import unittest
class VerifyHat_test(unittest.TestCase):
def build(self, impl_code: str, workdir: str, name: str, func_name: str) -> str:
hat.ensure_compiler_in_path()
if hat.get_platform() == hat.OperatingSystem.Windows:
return self.windows_build(impl_code, workdir, name, func_name)
else:
return self.linux_build(impl_code, workdir, name)
def windows_build(self, impl_code: str, workdir: str, name: str, func_name: str) -> str:
source_path = f"{workdir}/{name}.c"
lib_path = f"{workdir}/{name}.dll"
shutil.rmtree(workdir, ignore_errors=True)
os.makedirs(workdir, exist_ok=True)
with open(source_path, "w") as f:
print(impl_code, file=f)
dllmain_path = f"{workdir}/dllmain.cpp"
with open(dllmain_path, "w") as f:
print("#include <windows.h>\n", file=f)
print("BOOL APIENTRY DllMain(HMODULE, DWORD, LPVOID) { return TRUE; }\n", file=f)
if os.path.exists(lib_path):
os.remove(lib_path)
hat.run_command(
f'cl.exe "{source_path}" "{dllmain_path}" /nologo /link /DLL /EXPORT:{func_name} /OUT:"{lib_path}"',
quiet=True
)
self.assertTrue(os.path.isfile(lib_path))
return lib_path
def linux_build(self, impl_code: str, workdir: str, name: str) -> str:
source_path = f"{workdir}/{name}.c"
lib_path = f"{workdir}/{name}.so"
shutil.rmtree(workdir, ignore_errors=True)
os.makedirs(workdir, exist_ok=True)
with open(source_path, "w") as f:
print(impl_code, file=f)
if os.path.exists(lib_path):
os.remove(lib_path)
hat.run_command(f'gcc -shared -fPIC -o "{lib_path}" "{source_path}"', quiet=True)
self.assertTrue(os.path.isfile(lib_path))
return lib_path
def create_hat_file(self, hat_input: hat.HATFile):
hat_path = hat_input.path
if os.path.exists(hat_path):
os.remove(hat_path)
hat_input.Serialize(hat_path)
self.assertTrue(os.path.exists(hat_path))
def test_basic(self):
# Generate a HAT package with a C implementation and call verify_hat
impl_code = '''#include <math.h>
#include <stdint.h>
#ifdef _MSC_VER
#define DLL_EXPORT __declspec( dllexport )
#else
#define DLL_EXPORT
#endif
DLL_EXPORT void Softmax(const float input[2][2], float output[2][2])
{
/* Softmax 13 (TF, pytorch style)
axis = 0
*/
for (uint32_t i1 = 0; i1 < 2; ++i1) {
float max = -INFINITY;
for (uint32_t i0 = 0; i0 < 2; ++i0) {
max = max > input[i0][i1] ? max : input[i0][i1];
}
float sum = 0.0;
for (uint32_t i0 = 0; i0 < 2; ++i0) {
sum += expf(input[i0][i1] - max);
}
for (uint32_t i0 = 0; i0 < 2; ++i0) {
output[i0][i1] = expf(input[i0][i1] - max) / sum;
}
}
}
'''
decl_code = '''#endif // TOML
#pragma once
#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)
void Softmax(const float input[2][2], float output[2][2]);
#ifndef __Softmax_DEFINED__
#define __Softmax_DEFINED__
void (*Softmax)(float*, float*) = Softmax;
#endif
#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)
#ifdef TOML
'''
workdir = "./test_output/verify_hat_basic"
name = "softmax"
func_name = "Softmax"
lib_path = self.build(impl_code, workdir, name, func_name)
hat_path = f"{workdir}/{name}.hat"
# create the hat file
shape = [2, 2]
strides = [shape[1], 1] # first major
param_input = hat.Parameter(
name="input",
logical_type=hat.ParameterType.AffineArray,
declared_type="float*",
element_type="float",
usage=hat.UsageType.Input,
shape=shape,
affine_map=strides
)
param_output = hat.Parameter(
name="output",
logical_type=hat.ParameterType.AffineArray,
declared_type="float*",
element_type="float",
usage=hat.UsageType.InputOutput,
shape=shape,
affine_map=strides
)
hat_function = hat.Function(
arguments=[param_input, param_output],
calling_convention=hat.CallingConventionType.StdCall,
name=func_name,
return_info=hat.Parameter.void()
)
hat_input = hat.HATFile(
name=name,
functions=[hat_function],
dependencies=hat.Dependencies(link_target=os.path.basename(lib_path)),
declaration=hat.Declaration(code=decl_code),
path=hat_path
)
self.create_hat_file(hat_input)
hat.verify_hat_package(hat_path)
def test_runtime_array(self):
# Generate a HAT package using C and call verify_hat
impl_code = '''#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#ifndef ALLOC
#define ALLOC(size) ( malloc(size) )
#endif
#ifndef DEALLOC
#define DEALLOC(X) ( free(X) )
#endif
#ifdef _MSC_VER
#define DLL_EXPORT __declspec( dllexport )
#else
#define DLL_EXPORT
#endif
DLL_EXPORT void Range(const int32_t start[1], const int32_t limit[1], const int32_t delta[1], int32_t** output, uint32_t* output_dim)
{
/* Range */
/* Ensure we don't crash with random inputs */
int32_t start0 = start[0];
int32_t delta0 = delta[0] == 0 ? 1 : delta[0];
int32_t limit0 = (limit[0] <= start0) ? (start0 + delta0 * 25) : limit[0];
*output_dim = (limit0 - start0) / delta0;
*output = (int32_t*)ALLOC(*output_dim * sizeof(int32_t));
printf(\"Allocated %d output elements\\n\", *output_dim);
printf(\"start=%d, limit=%d, delta=%d\\n\", start0, limit0, delta0);
for (uint32_t i = 0; i < *output_dim; ++i) {
(*output)[i] = start0 + (i * delta0);
}
for (uint32_t i = 0; i < *output_dim; ++i) {
(*output)[i] = start0 + (i * delta0);
}
}
'''
decl_code = '''#endif // TOML
#pragma once
#include <stdint.h>
#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)
void Range(const int32_t start[1], const int32_t limit[1], const int32_t delta[1], int32_t** output, uint32_t* output_dim);
#ifndef __Range_DEFINED__
#define __Range_DEFINED__
void (*Range)(int32_t*, int32_t*, int32_t*, int32_t**, uint32_t*) = Range;
#endif
#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)
#ifdef TOML
'''
workdir = "test_output/verify_hat_runtime_array"
name = "range"
func_name = "Range"
lib_path = self.build(impl_code, workdir, name, func_name)
hat_path = f"{workdir}/{name}.hat"
# create the hat file
param_start = hat.Parameter(
name="start",
logical_type=hat.ParameterType.AffineArray,
declared_type="int32_t*",
element_type="int32_t",
usage=hat.UsageType.Input,
shape=[],
)
param_limit = hat.Parameter(
name="limit",
logical_type=hat.ParameterType.AffineArray,
declared_type="int32_t*",
element_type="int32_t",
usage=hat.UsageType.Input,
shape=[],
)
param_delta = hat.Parameter(
name="delta",
logical_type=hat.ParameterType.AffineArray,
declared_type="int32_t*",
element_type="int32_t",
usage=hat.UsageType.Input,
shape=[],
)
param_output = hat.Parameter(
name="output",
logical_type=hat.ParameterType.RuntimeArray,
declared_type="int32_t**",
element_type="int32_t",
usage=hat.UsageType.Output,
size="output_dim"
)
param_output_dim = hat.Parameter(
name="output_dim",
logical_type=hat.ParameterType.Element,
declared_type="uint32_t*",
element_type="uint32_t",
usage=hat.UsageType.Output,
shape=[]
)
hat_function = hat.Function(
arguments=[param_start, param_limit, param_delta, param_output, param_output_dim],
calling_convention=hat.CallingConventionType.StdCall,
name=func_name,
return_info=hat.Parameter.void()
)
hat_input = hat.HATFile(
name=name,
functions=[hat_function],
dependencies=hat.Dependencies(link_target=os.path.basename(lib_path)),
declaration=hat.Declaration(code=decl_code),
path=hat_path
)
self.create_hat_file(hat_input)
hat.verify_hat_package(hat_path)
def test_input_runtime_arrays(self):
impl_code = '''#include <stdint.h>
#include <stdlib.h>
#ifndef ALLOC
#define ALLOC(size) ( malloc(size) )
#endif
#ifndef DEALLOC
#define DEALLOC(X) ( free(X) )
#endif
#ifdef _MSC_VER
#define DLL_EXPORT __declspec( dllexport )
#else
#define DLL_EXPORT
#endif
DLL_EXPORT void /* Unsqueeze_18 */ Unsqueeze(const float* data, const int64_t data_dim0, float** expanded, int64_t* dim0, int64_t* dim1)
{
/* Unsqueeze */
*dim0 = 1;
*dim1 = data_dim0;
*expanded = (float*)ALLOC((*dim0) * (*dim1) * sizeof(float));
float* data_ = (float*)data;
float* expanded_ = (float*)(*expanded);
for (int64_t i = 0; i < data_dim0; ++i)
expanded_[i] = data_[i];
}
'''
decl_code = '''#endif // TOML
#pragma once
#include <stdint.h>
#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)
void Unsqueeze(const float* data, const int64_t data_dim0, float** expanded, int64_t* dim0, int64_t* dim1);
#ifndef __Unsqueeze_DEFINED__
#define __Unsqueeze_DEFINED__
void (*Unsqueeze_)(float*, int64_t, float**, int64_t*, int64_t*) = Unsqueeze;
#endif
#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)
#ifdef TOML
'''
for id, usage in enumerate([hat.UsageType.Input, hat.UsageType.InputOutput]):
workdir = "test_output/verify_hat_inout_runtime_arrays"
name = f"unsqueeze_{id}" # uniqify for Windows to avoid load conflict
func_name = "Unsqueeze"
lib_path = self.build(impl_code, workdir, name, func_name)
hat_path = f"{workdir}/{name}.hat"
# create the hat file
param_data = hat.Parameter(
name="data",
logical_type=hat.ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=usage,
size="data_dim"
)
param_data_dim = hat.Parameter(
name="data_dim",
logical_type=hat.ParameterType.Element,
declared_type="int64_t",
element_type="int64_t",
usage=hat.UsageType.Input,
shape=[]
)
param_expanded = hat.Parameter(
name="expanded",
logical_type=hat.ParameterType.RuntimeArray,
declared_type="float**",
element_type="float",
usage=hat.UsageType.Output,
size="dim0*dim1"
)
param_dim0 = hat.Parameter(
name="dim0",
logical_type=hat.ParameterType.Element,
declared_type="int64_t*",
element_type="int64_t",
usage=hat.UsageType.Output,
shape=[]
)
param_dim1 = hat.Parameter(
name="dim1",
logical_type=hat.ParameterType.Element,
declared_type="int64_t*",
element_type="int64_t",
usage=hat.UsageType.Output,
shape=[]
)
hat_function = hat.Function(
arguments=[param_data, param_data_dim, param_expanded, param_dim0, param_dim1],
calling_convention=hat.CallingConventionType.StdCall,
name=func_name,
return_info=hat.Parameter.void()
)
hat_input = hat.HATFile(
name=name,
functions=[hat_function],
dependencies=hat.Dependencies(link_target=os.path.basename(lib_path)),
declaration=hat.Declaration(code=decl_code),
path=hat_path
)
self.create_hat_file(hat_input)
hat.verify_hat_package(hat_path)
if __name__ == '__main__':
unittest.main()