Mirror of https://github.com/microsoft/hat.git

Add beginnings of support for CUDA device functions (#32)

Adds support for GPU targets, device functions, and launch functions.

This commit is contained in:
Parent: b54fca0ff7
Commit: 615006cecb
@@ -28,11 +28,11 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install -r tools/requirements.txt
+          python -m pip install -r hatlib/requirements.txt
       - name: Unittest
         run: |
-          python -m pip install -r tools/test/requirements.txt
-          python -m unittest discover tools/test
+          python -m pip install -r hatlib/test/requirements.txt
+          python -m unittest discover hatlib/test
       - name: Build whl
         run: |
           python -m pip install build
README.md (26 lines changed)
@@ -8,13 +8,13 @@ HAT is a format for packaging compiled libraries in the C programming language.

A HAT package includes one library file and one or more `.hat` files. The library file can be a static library (a `.a` file on Posix systems or a `.lib` file on Windows systems) or a dynamic library (a `.so` file on Posix systems or a `.dll` file and its accompanying `.lib` import library file on Windows systems). For simplicity of documentation, when discussing the "library file" as a unit in a HAT package we are referring to either the single static library file (`.a` or `.lib`), the single dynamic library file (`.so`) on Posix, or the pair of library files that comprise a dynamic library on Windows (`.dll` and `.lib` import library file). Note that since a HAT package includes only one "library file" it cannot contain both a static and dynamic library, so in the case of a HAT package containing a Windows dynamic library, the `.lib` file in the package is always the import library and there is no ambiguity in which file (or pair of files) is being referenced by "library file".

The library file contains all the compiled object code that implements the functions in the HAT package. Each `.hat` file contains a combination of standard C function declarations (like a typical `.h` file) and metadata in the TOML markup language. The metadata that accompanies each function declaration describes how the function should be called and how it was implemented. The metadata is intended to be both human-readable and machine-readable, providing structured and systematic documentation and allowing downstream tools to examine the package contents.

Each `.hat` file has the convenient property that it is simultaneously a valid h-file and a valid TOML file. In other words, the file is structured such that a C compiler will ignore the TOML metadata, while a TOML parser will understand the entire file as a valid TOML file. We accomplish this using a technique we call *the hat trick*, which is explained below.

# What problem does the HAT format solve?

C is among the most popular programming languages, but it also has serious shortcomings. In particular, C libraries are typically opaque and lack mechanisms for systematic documentation and introspection. This is best explained with an example:

Say that we use C to implement an in-place column-wise normalization of a 10x10 matrix. In other words, this function takes a 10x10 matrix `A` and divides each column by the Euclidean norm of that column. A highly-optimized implementation of this function would be tailored to the target computer's specific hardware properties, such as its cache size, the number of CPU cores, and perhaps even the presence of a GPU. The declaration of this function in an h-file would look something like this:
```
@@ -24,16 +24,16 @@ The accompanying library file would contain the compiled machine code for this f

* What does the pointer `float* A` point to? By convention, it is reasonable to assume that `A` points to the first element of an array that contains the 100 matrix elements, but this is not stated explicitly.
* Does the function expect the matrix elements to appear in row-major order, column-major order, Z-order, or something else?
* What is the size of the array `A`? We may have auxiliary knowledge that the array is 100 elements long, but this information is not stated explicitly.
* We can see that `A` is not `const`, so we know that its elements can be changed by the function, but is it an "output-only" array (its initial values are overwritten) or is it an "input/output" array? We have auxiliary knowledge that `A` is both an input and an output, but this is not stated explicitly.
* Is this function compiled for Windows or Linux?
* Does it need to be linked to a C runtime library or any other library?
* For which instruction set is the function compiled? Does it rely on SSE extensions? AVX? AVX512?
* Is this a multi-threaded implementation? Does the implementation assume a fixed number of CPU cores?
* Does the implementation rely on GPU hardware?
* Who created this library? Does it have a version number? Is it distributed under an open-source license?

Some of the questions above can be answered by reading the human-readable documentation provided in h-file comments, in `README.txt` or `LICENSE.txt` files, or in a web page that describes the library. Some of the information may be implied by the library name or the function name (e.g., imagine that the function was named "normalize_10x10_singlecore") or by common sense (e.g., if a GPU is not mentioned anywhere, the function probably doesn't require one). Nevertheless, C does not have a schematized systematic way to express all of this important information. Moreover, human-readable documentation does not expose this information to downstream programming tools. For example, imagine a downstream tool that examines a library and automatically creates tests that measure the performance of each function.

The HAT package format attempts to replace this opacity with transparency, by annotating each declared function with descriptive metadata in TOML.
@@ -56,11 +56,11 @@ As mentioned above, the `.hat` file is simultaneously a valid h-file and a valid
#endif // TOML
```

What does a C compiler see? Assuming that the `TOML` macro is not defined, the preprocessor ignores everything that appears between `#ifdef TOML` and `#endif`. This leaves whatever appears instead of `// Add C declarations here`.

What does a TOML parser see? First note that `#` is a comment escape character in TOML, so the `#ifdef` and `#endif` lines are ignored as comments. Any TOML code that appears instead of `// Add TOML here` is parsed normally. Finally, a special TOML table named `[declaration]` is defined, and inside it a key named `code` with all of the C declarations as a multiline string.
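To make the trick concrete, here is a minimal sketch of the layout (the metadata keys and the declaration are illustrative, not taken from a real package):

```
#ifdef TOML
[description]
author = "..."        # any TOML metadata

[declaration]
    code = '''
#endif // TOML

void my_function(float* A);   // hypothetical C declaration

#ifdef TOML
    '''
#endif // TOML
```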
Why is it important for the TOML and the C declarations to live in the same file? Why not put the TOML metadata in a separate file? The fact that C already splits the package code between library files and h-files is a concern, because the user has to worry about distributing a `.h` file with an incorrect version of the library file. We don't want to make things worse by adding yet another separate file. Keeping the metadata in the same file as the function declaration ensures that each declaration is never separated from its metadata.

# Multiple `.hat` files
@@ -82,7 +82,7 @@ Requirements: Python 3.7 and above.
pip install hatlib
```

-[Documentation](https://github.com/microsoft/hat/tree/main/tools#readme)
+[Documentation](https://github.com/microsoft/hat/tree/main/hatlib#readme)

You can also clone this repository and build a package locally:
@@ -137,14 +137,14 @@ mean_duration_in_sec = 1.5953456437541567e-06

This repository contains unit tests, authored with the Python `unittest` library. To set up and run all tests:

```shell
-pip install -r <path_to_repo>/tools/test/requirements.txt
-python -m unittest discover <path_to_repo>/tools/test
+pip install -r <path_to_repo>/hatlib/test/requirements.txt
+python -m unittest discover <path_to_repo>/hatlib/test
```

To run a test case:

```shell
-python -m unittest discover -k "test_file_basic_serialize" <path_to_repo>/tools/test
+python -m unittest discover -k "test_file_basic_serialize" <path_to_repo>/hatlib/test
```

Note that some tests will require a C++ compiler (e.g. MSVC for Windows, gcc for Linux) in `PATH`.
@@ -0,0 +1,108 @@
import ctypes
import sys
from dataclasses import dataclass
from functools import reduce
from typing import Any, List, Tuple

import numpy as np


@dataclass
class ArgInfo:
    """Extracts necessary information from the description of a function argument in a hat file"""
    hat_declared_type: str
    numpy_shape: Tuple[int]
    numpy_strides: Tuple[int]
    numpy_dtype: type
    element_num_bytes: int
    ctypes_pointer_type: Any
    usage: str = ""

    def __init__(self, param_description):
        self.hat_declared_type = param_description["declared_type"]
        self.numpy_shape = tuple(param_description["shape"])
        self.usage = param_description["usage"]
        if self.hat_declared_type == "float16_t*":
            self.numpy_dtype = np.float16
            self.element_num_bytes = 2
            self.ctypes_pointer_type = ctypes.POINTER(
                ctypes.c_uint16)  # same bitwidth as float16
        elif self.hat_declared_type == "float*":
            self.numpy_dtype = np.float32
            self.element_num_bytes = 4
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_float)
        elif self.hat_declared_type == "double*":
            self.numpy_dtype = np.float64
            self.element_num_bytes = 8
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_double)
        elif self.hat_declared_type == "int64_t*":
            self.numpy_dtype = np.int64
            self.element_num_bytes = 8
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int64)
        elif self.hat_declared_type == "int32_t*":
            self.numpy_dtype = np.int32
            self.element_num_bytes = 4
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int32)
        elif self.hat_declared_type == "int16_t*":
            self.numpy_dtype = np.int16
            self.element_num_bytes = 2
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int16)
        elif self.hat_declared_type == "int8_t*":
            self.numpy_dtype = np.int8
            self.element_num_bytes = 1
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int8)
        else:
            raise NotImplementedError(
                f"Unsupported declared_type {self.hat_declared_type} in hat file")

        self.numpy_strides = tuple(
            [self.element_num_bytes * x for x in param_description["affine_map"]])


def verify_args(args, arg_infos, function_name):
    """Verifies that a list of arguments matches a list of argument descriptions in a HAT file"""
    # check number of args
    if len(args) != len(arg_infos):
        sys.exit(
            f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")

    # for each arg
    for i in range(len(args)):
        arg = args[i]
        arg_info = arg_infos[i]

        # confirm that the arg is a numpy ndarray
        if not isinstance(arg, np.ndarray):
            sys.exit(
                f"Error calling {function_name}(...): expected argument {i} to be <class 'numpy.ndarray'> but received {type(arg)}")

        # confirm that the arg dtype matches the description in the hat package
        if arg_info.numpy_dtype != arg.dtype:
            sys.exit(
                f"Error calling {function_name}(...): expected argument {i} to have dtype={arg_info.numpy_dtype} but received dtype={arg.dtype}")

        # confirm that the arg shape is correct
        if arg_info.numpy_shape != arg.shape:
            sys.exit(
                f"Error calling {function_name}(...): expected argument {i} to have shape={arg_info.numpy_shape} but received shape={arg.shape}")

        # confirm that the arg strides are correct
        if arg_info.numpy_strides != arg.strides:
            sys.exit(
                f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}")


def generate_input_sets(parameters: List[ArgInfo], input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
    shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape)
                       for p in parameters]
    set_size = reduce(
        lambda x, y: x + y, [size * p.element_num_bytes for size, p in zip(shapes_to_sizes, parameters)])

    num_input_sets = (input_sets_minimum_size_MB * 1024 *
                      1024 // set_size) + 1 + num_additional
    input_sets = [[np.random.random(p.numpy_shape).astype(
        p.numpy_dtype) for p in parameters] for _ in range(num_input_sets)]

    return input_sets[0] if len(input_sets) == 1 else input_sets
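As a sketch of how these helpers compose (the argument description below is hypothetical, but it uses the same keys the `ArgInfo` constructor reads):

```python
# Hypothetical argument description, shaped like a HAT "arguments" entry
param = {
    "declared_type": "float*",
    "shape": [256, 32],
    "affine_map": [32, 1],      # row-major strides, in elements
    "usage": "input_output",
}
info = ArgInfo(param)
inputs = generate_input_sets([info])    # one input set: a list of numpy arrays
verify_args(inputs, [info], "my_func")  # exits with a message on any mismatch
```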
@@ -7,18 +7,17 @@ import sys
import time
import toml
import traceback
from functools import reduce
from pathlib import Path
from typing import List

if __package__:
    from .hat_file import HATFile
    from .hat_to_dynamic import create_dynamic_package
-    from .hat import load, ArgInfo
+    from .hat import load, ArgInfo, generate_input_sets
else:
    from hat_file import HATFile
    from hat_to_dynamic import create_dynamic_package
-    from hat import load, ArgInfo
+    from hat import load, ArgInfo, generate_input_sets


class Benchmark:
    """A basic python-based benchmark.

@@ -27,6 +26,7 @@ class Benchmark:
    Requirements:
        A compilation toolchain in your PATH: cl.exe & link.exe (Windows), gcc (Linux), or clang (macOS)
    """

    def __init__(self, hat_path):
        self.hat_path = Path(hat_path)

@@ -36,14 +36,15 @@ class Benchmark:
        # create dictionary of function descriptions defined in the hat file
        t = toml.load(self.hat_path)
        function_descriptions = t["functions"]
-        self.hat_arg_descriptions = {key : [ArgInfo(d) for d in val["arguments"]] for key, val in function_descriptions.items()}
+        self.hat_arg_descriptions = {key: [ArgInfo(
+            d) for d in val["arguments"]] for key, val in function_descriptions.items()}

    def run(self,
            function_name: str,
            warmup_iterations: int = 10,
            min_timing_iterations: int = 100,
            min_time_in_sec: int = 10,
-            input_sets_minimum_size_MB = 50) -> float:
+            input_sets_minimum_size_MB=50) -> float:
        """Runs benchmarking for a function.
        Multiple inputs are run through the function until both minimum time and minimum iterations have been reached.
        The mean duration is then calculated as mean_duration = total_time_elapsed / total_iterations_performed.

@@ -62,8 +63,10 @@ class Benchmark:

        # TODO: support packing and unpacking functions

-        mean_elapsed_time, batch_timings = self._profile(function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB)
-        print(f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s")
+        mean_elapsed_time, batch_timings = self._profile(
+            function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB)
+        print(
+            f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s")

        return mean_elapsed_time, batch_timings

@@ -78,29 +81,30 @@ class Benchmark:
            perf_counter_scale = 1
            return perf_counter, perf_counter_scale

-        def generate_input_sets(parameters: List[ArgInfo], input_sets_minimum_size_MB: int, num_additional: int = 10):
-            shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters]
-            set_size = reduce(lambda x, y: x + y, [size * p.element_num_bytes for size, p in zip(shapes_to_sizes, parameters)])
-
-            num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
-            print(f"[Benchmarking] Using {num_input_sets} input sets, each {set_size} bytes")
-
-            return [[np.random.random(p.numpy_shape).astype(p.numpy_dtype) for p in parameters] for _ in range(num_input_sets)]
-
        parameters = self.hat_arg_descriptions[function_name]

        # generate sufficient input sets to overflow the L3 cache, since we don't know the size of the model
        # we'll make a guess based on the minimum input set size
-        input_sets = generate_input_sets(parameters, input_sets_minimum_size_MB)
+        input_sets = generate_input_sets(
+            parameters, input_sets_minimum_size_MB, num_additional=10)
+
+        set_size = 0
+        for i in input_sets[0]:
+            set_size += i.size * i.dtype.itemsize
+
+        print(
+            f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes")

        perf_counter, perf_counter_scale = get_perf_counter()
-        print(f"[Benchmarking] Warming up for {warmup_iterations} iterations...")
+        print(
+            f"[Benchmarking] Warming up for {warmup_iterations} iterations...")

        for _ in range(warmup_iterations):
            for calling_args in input_sets:
                self.hat_package[function_name](*calling_args)

-        print(f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations...")
+        print(
+            f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations...")
        start_time = perf_counter()
        end_time = perf_counter()

@@ -115,7 +119,8 @@ class Benchmark:
                i = iterations % i_max
                iterations += 1
            end_time = perf_counter()
-            batch_timings.append((end_time - batch_start_time) / perf_counter_scale)
+            batch_timings.append(
+                (end_time - batch_start_time) / perf_counter_scale)

        elapsed_time = ((end_time - start_time) / perf_counter_scale)
        mean_elapsed_time = elapsed_time / iterations

@@ -135,8 +140,8 @@ def write_runtime_to_hat_file(hat_path, function_name, mean_time_secs):
    # Workaround to remove extra empty lines
    with open(hat_path, "r") as f:
        lines = f.readlines()
-        lines = [lines[i] for i in range(len(lines)) if not(lines[i] == "\n" \
-            and i < len(lines)-1 and lines[i+1] == "\n")]
+        lines = [lines[i] for i in range(len(lines)) if not(lines[i] == "\n"
+                                                            and i < len(lines)-1 and lines[i+1] == "\n")]
    with open(hat_path, "w") as f:
        f.writelines(lines)

@@ -148,27 +153,29 @@ def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=1
    functions = benchmark.hat_functions
    for function_name in functions:
        print(f"\nBenchmarking function: {function_name}")
-        if "Initialize" in function_name or "_debug_check_allclose" in function_name : # Skip init and debug functions
+        if "Initialize" in function_name or "_debug_check_allclose" in function_name:  # Skip init and debug functions
            continue

        try:
            _, batch_timings = benchmark.run(function_name,
-                warmup_iterations=batch_size,
-                min_timing_iterations=batch_size,
-                min_time_in_sec=min_time_in_sec,
-                input_sets_minimum_size_MB=input_sets_minimum_size_MB)
+                                             warmup_iterations=batch_size,
+                                             min_timing_iterations=batch_size,
+                                             min_time_in_sec=min_time_in_sec,
+                                             input_sets_minimum_size_MB=input_sets_minimum_size_MB)

            sorted_batch_means = np.array(sorted(batch_timings)) / batch_size
            num_batches = len(batch_timings)

            mean_of_means = sorted_batch_means.mean()
            median_of_means = sorted_batch_means[num_batches//2]
-            mean_of_small_means = sorted_batch_means[0 : num_batches//2].mean()
-            robust_mean_of_means = sorted_batch_means[num_batches//5 : -num_batches//5].mean()
+            mean_of_small_means = sorted_batch_means[0: num_batches//2].mean()
+            robust_mean_of_means = sorted_batch_means[num_batches //
                                                      5: -num_batches//5].mean()
            min_of_means = sorted_batch_means[0]

            if store_in_hat:
-                write_runtime_to_hat_file(hat_path, function_name, mean_of_means)
+                write_runtime_to_hat_file(
+                    hat_path, function_name, mean_of_means)
            results.append({"function_name": function_name,
                            "mean": mean_of_means,
                            "median_of_means": median_of_means,

@@ -178,9 +185,11 @@ def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=1
            })
        except Exception as e:
            exc_type, exc_val, exc_tb = sys.exc_info()
-            traceback.print_exception(exc_type, exc_val, exc_tb, file=sys.stderr)
+            traceback.print_exception(
+                exc_type, exc_val, exc_tb, file=sys.stderr)
            print("\nException message: ", e)
-            print(f"WARNING: Failed to run function {function_name}, skipping this benchmark.")
+            print(
+                f"WARNING: Failed to run function {function_name}, skipping this benchmark.")
    return results
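For intuition, `robust_mean_of_means` above discards the fastest and slowest fifth of the sorted batch means before averaging, which keeps a single outlier batch from skewing the result. A toy illustration with hypothetical timings:

```python
import numpy as np

batch_timings = [1.0, 1.1, 0.9, 1.0, 5.0, 1.0, 1.2, 0.8, 1.0, 1.1]  # seconds, hypothetical
batch_size = 10                                           # function calls per batch
sorted_batch_means = np.array(sorted(batch_timings)) / batch_size
n = len(sorted_batch_means)
robust_mean = sorted_batch_means[n // 5: -n // 5].mean()  # drops the 2 lowest and 2 highest
print(f"{robust_mean:.4f}s per call")                     # the 5.0s outlier no longer dominates
```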
@@ -191,27 +200,28 @@ def main(argv):
                            " hatlib.benchmark_hat_package <hat_path>\n")

    arg_parser.add_argument("hat_path",
-        help="Path to the HAT file",
-        default=None)
+                            help="Path to the HAT file",
+                            default=None)
    arg_parser.add_argument("--store_in_hat",
-        help="If set, will write the duration as meta-data back into the hat file",
-        action='store_true')
+                            help="If set, will write the duration as meta-data back into the hat file",
+                            action='store_true')
    arg_parser.add_argument("--results_file",
-        help="Full path where the results will be written",
-        default="results.csv")
+                            help="Full path where the results will be written",
+                            default="results.csv")
    arg_parser.add_argument("--batch_size",
-        help="The number of function calls in each batch (at least one full batch is executed)",
-        default=10)
+                            help="The number of function calls in each batch (at least one full batch is executed)",
+                            default=10)
    arg_parser.add_argument("--min_time_in_sec",
-        help="Minimum number of seconds to run the benchmark for",
-        default=30)
+                            help="Minimum number of seconds to run the benchmark for",
+                            default=30)
    arg_parser.add_argument("--input_sets_minimum_size_MB",
-        help="Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on an desktop CPU)",
-        default=50)
+                            help="Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on a desktop CPU)",
+                            default=50)

    args = vars(arg_parser.parse_args(argv))

-    results = run_benchmark(args["hat_path"], args["store_in_hat"], batch_size=int(args["batch_size"]), min_time_in_sec=int(args["min_time_in_sec"]), input_sets_minimum_size_MB=int(args["input_sets_minimum_size_MB"]))
+    results = run_benchmark(args["hat_path"], args["store_in_hat"], batch_size=int(args["batch_size"]), min_time_in_sec=int(
+        args["min_time_in_sec"]), input_sets_minimum_size_MB=int(args["input_sets_minimum_size_MB"]))
    df = pd.DataFrame(results)
    df.to_csv(args["results_file"], index=False)
    pd.options.display.float_format = '{:8.8f}'.format
@@ -219,8 +229,10 @@ def main(argv):

    print(f"Results saved to {args['results_file']}")


def main_command():
-    main(sys.argv[1:]) # drop the first argument (program name)
+    main(sys.argv[1:])  # drop the first argument (program name)


if __name__ == "__main__":
    main_command()
@@ -0,0 +1,389 @@
import os
import pathlib
import sys
import numpy as np
from functools import reduce
from typing import Dict, List

# CUDA stuff
# TODO: move from the pynvrtc module to cuda entirely to reduce dependencies
from pynvrtc.compiler import Program
from cuda import cuda, nvrtc

try:
    from .arg_info import ArgInfo, verify_args, generate_input_sets
except:
    from arg_info import ArgInfo, verify_args, generate_input_sets

# lifted from https://github.com/NVIDIA/jitify/blob/master/jitify.hpp
HEADER_MAP: Dict[str, str] = {
    'float.h': """
#pragma once
#define FLT_RADIX 2
#define FLT_MANT_DIG 24
#define DBL_MANT_DIG 53
#define FLT_DIG 6
#define DBL_DIG 15
#define FLT_MIN_EXP -125
#define DBL_MIN_EXP -1021
#define FLT_MIN_10_EXP -37
#define DBL_MIN_10_EXP -307
#define FLT_MAX_EXP 128
#define DBL_MAX_EXP 1024
#define FLT_MAX_10_EXP 38
#define DBL_MAX_10_EXP 308
#define FLT_MAX 3.4028234e38f
#define DBL_MAX 1.7976931348623157e308
#define FLT_EPSILON 1.19209289e-7f
#define DBL_EPSILON 2.220440492503130e-16
#define FLT_MIN 1.1754943e-38f
#define DBL_MIN 2.2250738585072013e-308
#define FLT_ROUNDS 1
#if defined __cplusplus && __cplusplus >= 201103L
#define FLT_EVAL_METHOD 0
#define DECIMAL_DIG 21
#endif
""",
    'limits.h': """
#pragma once
#if defined _WIN32 || defined _WIN64
#define __WORDSIZE 32
#else
#if defined __x86_64__ && !defined __ILP32__
#define __WORDSIZE 64
#else
#define __WORDSIZE 32
#endif
#endif
#define MB_LEN_MAX 16
#define CHAR_BIT 8
#define SCHAR_MIN (-128)
#define SCHAR_MAX 127
#define UCHAR_MAX 255
enum {
  _JITIFY_CHAR_IS_UNSIGNED = (char)-1 >= 0,
  CHAR_MIN = _JITIFY_CHAR_IS_UNSIGNED ? 0 : SCHAR_MIN,
  CHAR_MAX = _JITIFY_CHAR_IS_UNSIGNED ? UCHAR_MAX : SCHAR_MAX,
};
#define SHRT_MIN (-32768)
#define SHRT_MAX 32767
#define USHRT_MAX 65535
#define INT_MIN (-INT_MAX - 1)
#define INT_MAX 2147483647
#define UINT_MAX 4294967295U
#if __WORDSIZE == 64
# define LONG_MAX 9223372036854775807L
#else
# define LONG_MAX 2147483647L
#endif
#define LONG_MIN (-LONG_MAX - 1L)
#if __WORDSIZE == 64
#define ULONG_MAX 18446744073709551615UL
#else
#define ULONG_MAX 4294967295UL
#endif
#define LLONG_MAX 9223372036854775807LL
#define LLONG_MIN (-LLONG_MAX - 1LL)
#define ULLONG_MAX 18446744073709551615ULL
""",
    'stdint.h': """
#pragma once
#include <climits>
namespace __jitify_stdint_ns {
typedef signed char int8_t;
typedef signed short int16_t;
typedef signed int int32_t;
typedef signed long long int64_t;
typedef signed char int_fast8_t;
typedef signed short int_fast16_t;
typedef signed int int_fast32_t;
typedef signed long long int_fast64_t;
typedef signed char int_least8_t;
typedef signed short int_least16_t;
typedef signed int int_least32_t;
typedef signed long long int_least64_t;
typedef signed long long intmax_t;
typedef signed long intptr_t; //optional
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef unsigned char uint_fast8_t;
typedef unsigned short uint_fast16_t;
typedef unsigned int uint_fast32_t;
typedef unsigned long long uint_fast64_t;
typedef unsigned char uint_least8_t;
typedef unsigned short uint_least16_t;
typedef unsigned int uint_least32_t;
typedef unsigned long long uint_least64_t;
typedef unsigned long long uintmax_t;
#define INT8_MIN SCHAR_MIN
#define INT16_MIN SHRT_MIN
#if defined _WIN32 || defined _WIN64
#define WCHAR_MIN 0
#define WCHAR_MAX USHRT_MAX
typedef unsigned long long uintptr_t; //optional
#else
#define WCHAR_MIN INT_MIN
#define WCHAR_MAX INT_MAX
typedef unsigned long uintptr_t; //optional
#endif
#define INT32_MIN INT_MIN
#define INT64_MIN LLONG_MIN
#define INT8_MAX SCHAR_MAX
#define INT16_MAX SHRT_MAX
#define INT32_MAX INT_MAX
#define INT64_MAX LLONG_MAX
#define UINT8_MAX UCHAR_MAX
#define UINT16_MAX USHRT_MAX
#define UINT32_MAX UINT_MAX
#define UINT64_MAX ULLONG_MAX
#define INTPTR_MIN LONG_MIN
#define INTMAX_MIN LLONG_MIN
#define INTPTR_MAX LONG_MAX
#define INTMAX_MAX LLONG_MAX
#define UINTPTR_MAX ULONG_MAX
#define UINTMAX_MAX ULLONG_MAX
#define PTRDIFF_MIN INTPTR_MIN
#define PTRDIFF_MAX INTPTR_MAX
#define SIZE_MAX UINT64_MAX
} // namespace __jitify_stdint_ns
namespace std { using namespace __jitify_stdint_ns; }
using namespace __jitify_stdint_ns;
""",
    'math.h': """
#pragma once
namespace __jitify_math_ns {
#if __cplusplus >= 201103L
#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\
  inline double f(double x) { return ::f(x); } \\
  inline float f##f(float x) { return ::f(x); } \\
  /*inline long double f##l(long double x) { return ::f(x); }*/ \\
  inline float f(float x) { return ::f(x); } \\
  /*inline long double f(long double x) { return ::f(x); }*/
#else
#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\
  inline double f(double x) { return ::f(x); } \\
  inline float f##f(float x) { return ::f(x); } \\
  /*inline long double f##l(long double x) { return ::f(x); }*/
#endif
DEFINE_MATH_UNARY_FUNC_WRAPPER(cos)
DEFINE_MATH_UNARY_FUNC_WRAPPER(sin)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tan)
DEFINE_MATH_UNARY_FUNC_WRAPPER(acos)
DEFINE_MATH_UNARY_FUNC_WRAPPER(asin)
DEFINE_MATH_UNARY_FUNC_WRAPPER(atan)
template<typename T> inline T atan2(T y, T x) { return ::atan2(y, x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(cosh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(sinh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tanh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(exp)
template<typename T> inline T frexp(T x, int* exp) { return ::frexp(x, exp); }
template<typename T> inline T ldexp(T x, int exp) { return ::ldexp(x, exp); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(log)
DEFINE_MATH_UNARY_FUNC_WRAPPER(log10)
template<typename T> inline T modf(T x, T* intpart) { return ::modf(x, intpart); }
template<typename T> inline T pow(T x, T y) { return ::pow(x, y); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(sqrt)
DEFINE_MATH_UNARY_FUNC_WRAPPER(ceil)
DEFINE_MATH_UNARY_FUNC_WRAPPER(floor)
template<typename T> inline T fmod(T n, T d) { return ::fmod(n, d); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(fabs)
template<typename T> inline T abs(T x) { return ::abs(x); }
#if __cplusplus >= 201103L
DEFINE_MATH_UNARY_FUNC_WRAPPER(acosh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(asinh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(atanh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(exp2)
DEFINE_MATH_UNARY_FUNC_WRAPPER(expm1)
template<typename T> inline int ilogb(T x) { return ::ilogb(x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(log1p)
DEFINE_MATH_UNARY_FUNC_WRAPPER(log2)
DEFINE_MATH_UNARY_FUNC_WRAPPER(logb)
template<typename T> inline T scalbn (T x, int n) { return ::scalbn(x, n); }
template<typename T> inline T scalbln(T x, long n) { return ::scalbn(x, n); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(cbrt)
template<typename T> inline T hypot(T x, T y) { return ::hypot(x, y); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(erf)
DEFINE_MATH_UNARY_FUNC_WRAPPER(erfc)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tgamma)
DEFINE_MATH_UNARY_FUNC_WRAPPER(lgamma)
DEFINE_MATH_UNARY_FUNC_WRAPPER(trunc)
DEFINE_MATH_UNARY_FUNC_WRAPPER(round)
template<typename T> inline long lround(T x) { return ::lround(x); }
template<typename T> inline long long llround(T x) { return ::llround(x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(rint)
template<typename T> inline long lrint(T x) { return ::lrint(x); }
template<typename T> inline long long llrint(T x) { return ::llrint(x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(nearbyint)
// TODO: remainder, remquo, copysign, nan, nextafter, nexttoward, fdim,
// fmax, fmin, fma
#endif
#undef DEFINE_MATH_UNARY_FUNC_WRAPPER
} // namespace __jitify_math_ns
namespace std { using namespace __jitify_math_ns; }
#define M_PI 3.14159265358979323846
// Note: Global namespace already includes CUDA math funcs
//using namespace __jitify_math_ns;
""",
    'cuda_fp16.h': "",
}

HEADER_MAP['climits'] = HEADER_MAP['limits.h']


def ASSERT_DRV(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(
                cuda.cuGetErrorString(err)[1].decode('utf-8')))
    elif isinstance(err, nvrtc.nvrtcResult):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError("Nvrtc Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))
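
# Note: the CUDA Python driver bindings return a status code as the first
# element of each result tuple instead of raising exceptions, which is why
# every call below is paired with ASSERT_DRV. For example (a sketch; assumes
# cuda.cuInit has already been called):
#
#     err, device_count = cuda.cuDeviceGetCount()
#     ASSERT_DRV(err)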

def _find_cuda_incl_path() -> pathlib.Path:
    "Tries to find the CUDA include path."
    cuda_path = os.getenv("CUDA_PATH")
    if not cuda_path:
        if sys.platform == 'linux':
            cuda_path = pathlib.Path("/usr/local/cuda/include")
            if not (cuda_path.exists() and cuda_path.is_dir()):
                cuda_path = None
        elif sys.platform == 'win32':
            ...
        elif sys.platform == 'darwin':
            ...
    else:
        # CUDA_PATH from the environment is a string; wrap it before
        # appending the include subdirectory
        cuda_path = pathlib.Path(cuda_path) / "include"

    return cuda_path


def compile_cuda_program(cuda_src_path: pathlib.Path, func_name):
    src = cuda_src_path.read_text()

    prog = Program(src=src, name=func_name,
                   headers=HEADER_MAP.values(), include_names=HEADER_MAP.keys())
    ptx = prog.compile([
        '-use_fast_math',
        '-default-device',
        '-std=c++11',
        '-arch=sm_52',  # TODO: is this needed?
    ])

    return ptx


def initialize_cuda():
    # Initialize CUDA Driver API
    err, = cuda.cuInit(0)
    ASSERT_DRV(err)

    # Retrieve handle for device 0
    # TODO: add support for multiple CUDA devices?
    err, cuDevice = cuda.cuDeviceGet(0)
    ASSERT_DRV(err)

    # Create context
    err, context = cuda.cuCtxCreate(0, cuDevice)
    ASSERT_DRV(err)


def get_func_from_ptx(ptx, func_name):
    # Note: Incompatible --gpu-architecture would be detected here
    err, ptx_mod = cuda.cuModuleLoadData(ptx.encode('utf-8'))
    ASSERT_DRV(err)
    err, kernel = cuda.cuModuleGetFunction(ptx_mod, func_name.encode('utf-8'))
    ASSERT_DRV(err)

    return kernel


def _arg_size(arg_info: ArgInfo):
    return arg_info.element_num_bytes * reduce(lambda x, y: x*y, arg_info.numpy_shape)


def transfer_mem_host_to_cuda(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
    for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
        if 'input' in arg_info.usage:
            err, = cuda.cuMemcpyHtoD(
                device_arg, host_arg.ctypes.data, _arg_size(arg_info))
            ASSERT_DRV(err)


def transfer_mem_cuda_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
    for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
        if 'output' in arg_info.usage:
            err, = cuda.cuMemcpyDtoH(
                host_arg.ctypes.data, device_arg, _arg_size(arg_info))
            ASSERT_DRV(err)


def allocate_cuda_mem(arg_infos: List[ArgInfo]):
    device_mem = []
    for arg in arg_infos:
        err, mem = cuda.cuMemAlloc(_arg_size(arg))
        ASSERT_DRV(err)
        device_mem.append(mem)

    return device_mem


def device_args_to_ptr_list(device_args: List):
    # CUDA python example says this is subject to change
    ptrs = [
        np.array([int(d_arg)], dtype=np.uint64) for d_arg in device_args
    ]
    ptrs = np.array([ptr.ctypes.data for ptr in ptrs], dtype=np.uint64)

    return ptrs


def create_loader_for_device_function(device_func, hat_details):
    hat_path: pathlib.Path = hat_details.path
    cuda_src_path: pathlib.Path = hat_path.parent / device_func["provider"]
    func_name = device_func["name"]

    ptx = compile_cuda_program(cuda_src_path, func_name)

    initialize_cuda()

    kernel = get_func_from_ptx(ptx, func_name)

    hat_arg_descriptions = device_func["arguments"]
    arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
    launch_parameters = device_func["launch_parameters"]

    def f(*args):
        verify_args(args, arg_infos, func_name)
        device_mem = allocate_cuda_mem(arg_infos)
        transfer_mem_host_to_cuda(
            device_args=device_mem, host_args=args, arg_infos=arg_infos)
        ptrs = device_args_to_ptr_list(device_mem)

        err, stream = cuda.cuStreamCreate(0)
        ASSERT_DRV(err)

        err, = cuda.cuLaunchKernel(
            kernel,
            *launch_parameters,  # [ grid[x-z], block[x-z] ]
            0,  # dynamic shared memory
            stream,  # stream
            ptrs.ctypes.data,  # kernel arguments
            0,  # extra (ignore)
        )
        ASSERT_DRV(err)
        err, = cuda.cuStreamSynchronize(stream)
        ASSERT_DRV(err)

        transfer_mem_cuda_to_host(
            device_args=device_mem, host_args=args, arg_infos=arg_infos)

    return f
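For reference, a sketch of the `.hat` metadata this loader consumes (all values hypothetical; the keys are the ones read above and in the package loader below):

```
[functions.my_func]
name = "my_func"
launches = "my_kernel"                       # key into [device_functions]
runtime = "CUDA"

[device_functions.my_kernel]
name = "my_kernel"                           # kernel symbol in the .cu source
provider = "my_kernel.cu"                    # CUDA source next to the .hat file
launch_parameters = [16, 1, 1, 256, 1, 1]    # grid x, y, z then block x, y, z
```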
@@ -0,0 +1,165 @@
"""Loads a dynamically-linked HAT package in Python

Call 'load' to load a HAT package in Python. After loading, call the HAT
functions using numpy arrays as arguments. The shape, element type, and order
of each numpy array should exactly match the requirements of the HAT function.

For example:
    import numpy as np
    import hatlib as hat

    # load the package
    package = hat.load("my_package.hat")

    # print the function names
    for name in package.names:
        print(name)

    # create numpy arguments with the correct shape, dtype, and order
    A = np.ones([256,32], dtype=np.float32, order="C")
    B = np.ones([32,256], dtype=np.float32, order="C")
    D = np.ones([256,32], dtype=np.float32, order="C")
    E = np.ones([256,32], dtype=np.float32, order="C")

    # call a package function named 'my_func_698b5e5c'
    package.my_func_698b5e5c(A, B, D, E)
"""

import ctypes
import numpy as np
import pathlib
import sys
import toml
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Tuple

try:
    from . import hat_file
    from .arg_info import ArgInfo, verify_args, generate_input_sets
except:
    import hat_file
    from arg_info import ArgInfo, verify_args, generate_input_sets

try:
    try:
        from . import cuda_loader
    except ModuleNotFoundError:
        import cuda_loader
except:
    CUDA_AVAILABLE = False
else:
    CUDA_AVAILABLE = True

# Remove when ROCM is finally available
ROCM_AVAILABLE = False


def generate_input_sets_for_hat_file(hat_path):
    hat_path = pathlib.Path(hat_path).absolute()
    t: hat_file.HATFile = toml.load(hat_path)
    return {
        func_name:
        generate_input_sets(list(map(ArgInfo, func_desc["arguments"])))
        for func_name, func_desc in t["functions"].items()
    }


class AttributeDict(OrderedDict):
    """ Dictionary that allows entries to be accessed like attributes
    """
    __getattr__ = OrderedDict.__getitem__

    @property
    def names(self):
        return list(self.keys())

    def __getitem__(self, key):
        # allow lookup by name prefix, e.g. "my_func" for "my_func_698b5e5c"
        for k, v in self.items():
            if k.startswith(key):
                return v
        return OrderedDict.__getitem__(self, key)


def hat_description_to_python_function(hat_description: hat_file.HATFile,
                                       hat_details: AttributeDict):
    """ Creates a callable function based on a function description in a HAT
    package
    """

    for func_name, func_desc in hat_description["functions"].items():

        func_desc: hat_file.Function
        func_name: str

        launches = func_desc.get("launches")
        if not launches:

            hat_arg_descriptions = func_desc["arguments"]
            function_name = func_desc["name"]
            hat_library: ctypes.CDLL = hat_details.shared_lib

            # the keyword defaults bind the current loop values, so each
            # yielded function keeps its own description
            def f(*args,
                  hat_arg_descriptions=hat_arg_descriptions,
                  function_name=function_name,
                  hat_library=hat_library):
                # verify that the (numpy) input args match the description in
                # the hat file
                arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
                verify_args(args, arg_infos, function_name)

                # prepare the args to the hat package
                hat_args = [
                    arg.ctypes.data_as(arg_info.ctypes_pointer_type)
                    for arg, arg_info in zip(args, arg_infos)
                ]

                # call the function in the hat package
                hat_library[function_name](*hat_args)

            yield func_name, f

        else:
            device_func = hat_description.get("device_functions",
                                              {}).get(launches)

            func_runtime = func_desc.get("runtime")
            if not device_func:
                raise RuntimeError(
                    f"Couldn't find device function for loader: {launches}")
            if not func_runtime:
                raise RuntimeError(
                    f"Couldn't find runtime for loader: {launches}")
            if func_runtime == "CUDA" and CUDA_AVAILABLE:
                yield func_name, cuda_loader.create_loader_for_device_function(
                    device_func, hat_details)
            elif func_runtime == "ROCM" and ROCM_AVAILABLE:
                # rocm_loader does not exist yet; unreachable while
                # ROCM_AVAILABLE is False
                yield func_name, rocm_loader.create_loader_for_device_function(
                    device_func, hat_details)


def load(hat_path):
    """ Creates a class with static functions based on the function
    descriptions in a HAT package
    """
    # load the function descriptions from the hat file
    hat_path = pathlib.Path(hat_path).absolute()
    t: hat_file.HATFile = toml.load(hat_path)
    hat_details = AttributeDict({"path": hat_path})

    # function_descriptions = t["functions"]
    hat_binary_filename = t["dependencies"]["link_target"]
    hat_binary_path = hat_path.parent / hat_binary_filename

    # check that the HAT library has a supported file extension
    supported_extensions = [".dll", ".so"]
    extension = hat_binary_path.suffix
    if extension and extension not in supported_extensions:
        sys.exit(f"Unsupported HAT library extension: {extension}")

    # load the hat_library:
    hat_library = ctypes.cdll.LoadLibrary(
        str(hat_binary_path)) if extension else None
    hat_details["shared_lib"] = hat_library

    # create dictionary of functions defined in the hat file
    function_dict = AttributeDict(
        dict(hat_description_to_python_function(t, hat_details)))
    return function_dict
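One convenience worth noting: `AttributeDict.__getitem__` matches keys by prefix, so a name-decorated function can be looked up by its stem, while attribute access requires the exact name. A sketch:

```python
package = load("my_package.hat")   # hypothetical package
f = package["my_func"]             # prefix match, e.g. resolves "my_func_698b5e5c"
g = package.my_func_698b5e5c       # attribute access needs the exact key
```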
@@ -10,6 +10,7 @@ import tomlkit

# TODO : type-checking on leaf node values


def _read_toml_file(filepath):
    path = os.path.abspath(filepath)
    toml_doc = None

@@ -18,11 +19,13 @@ def _read_toml_file(filepath):
        toml_doc = tomlkit.parse(file_contents)
    return toml_doc


def _check_required_table_entry(table, key):
    if key not in table:
        # TODO : add more context to this error message
        raise ValueError(f"Invalid HAT file: missing required key {key}")


def _check_required_table_entries(table, keys):
    for key in keys:
        _check_required_table_entry(table, key)

@@ -34,26 +37,32 @@ class ParameterType(Enum):
    Element = "element"
    Void = "void"


class UsageType(Enum):
    Input = "input"
    Output = "output"
    InputOutput = "input_output"


class CallingConventionType(Enum):
    StdCall = "stdcall"
    CDecl = "cdecl"
    FastCall = "fastcall"
    VectorCall = "vectorcall"
+    Device = "devicecall"


class TargetType(Enum):
    CPU = "CPU"
+    GPU = "GPU"


class OperatingSystem(Enum):
    Windows = "windows"
    MacOS = "macos"
    Linux = "linux"


@dataclass
class AuxiliarySupportedTable:
    AuxiliaryKey = "auxiliary"

@@ -70,6 +79,7 @@ class AuxiliarySupportedTable:
        else:
            return {}


@dataclass
class Description(AuxiliarySupportedTable):
    TableName: str = "description"

@@ -91,10 +101,12 @@ class Description(AuxiliarySupportedTable):

    @staticmethod
    def parse_from_table(table):
-        return Description(author=table["author"],
-                           version=table["version"],
-                           license_url=table["license_url"],
-                           auxiliary=AuxiliarySupportedTable.parse_auxiliary(table))
+        return Description(
+            author=table["author"],
+            version=table["version"],
+            license_url=table["license_url"],
+            auxiliary=AuxiliarySupportedTable.parse_auxiliary(table))


@dataclass
class Parameter:

@@ -135,9 +147,14 @@ class Parameter:
    # TODO : change "usage" to "role" in schema
    @staticmethod
    def parse_from_table(param_table):
-        required_table_entries = ["name", "description", "logical_type", "declared_type", "element_type", "usage"]
+        required_table_entries = [
+            "name", "description", "logical_type", "declared_type",
+            "element_type", "usage"
+        ]
        _check_required_table_entries(param_table, required_table_entries)
-        affine_array_required_table_entries = ["shape", "affine_map", "affine_offset"]
+        affine_array_required_table_entries = [
+            "shape", "affine_map", "affine_offset"
+        ]
        runtime_array_required_table_entries = ["size"]

        name = param_table["name"]

@@ -147,14 +164,23 @@ class Parameter:
        element_type = param_table["element_type"]
        usage = UsageType(param_table["usage"])

-        param = Parameter(name=name, description=description, logical_type=logical_type, declared_type=declared_type, element_type=element_type, usage=usage)
+        param = Parameter(name=name,
+                          description=description,
+                          logical_type=logical_type,
+                          declared_type=declared_type,
+                          element_type=element_type,
+                          usage=usage)

        if logical_type == ParameterType.AffineArray:
-            _check_required_table_entries(param_table, affine_array_required_table_entries)
+            _check_required_table_entries(param_table,
+                                          affine_array_required_table_entries)
            param.shape = param_table["shape"]
            param.affine_map = param_table["affine_map"]
            param.affine_offset = param_table["affine_offset"]

        elif logical_type == ParameterType.RuntimeArray:
-            _check_required_table_entries(param_table, runtime_array_required_table_entries)
+            _check_required_table_entries(
+                param_table, runtime_array_required_table_entries)
            param.size = param_table["size"]

        return param

@@ -162,13 +188,20 @@ class Parameter:

@dataclass
class Function(AuxiliarySupportedTable):
-    name: str = ""
-    description: str = ""
-    calling_convention: CallingConventionType = None
+    # required
    arguments: list = field(default_factory=list)
-    return_info: Parameter = None
+    calling_convention: CallingConventionType = None
+    description: str = ""
    hat_file: any = None
    link_target: Path = None
+    name: str = ""
+    return_info: Parameter = None
+
+    # optional
+    launch_parameters: list = field(default_factory=list)
+    launches: str = ""
+    provider: str = ""
+    runtime: str = ""

    def to_table(self):
        table = tomlkit.table()

@@ -179,7 +212,22 @@ class Function(AuxiliarySupportedTable):
        arg_array = tomlkit.array()
        for arg_table in arg_tables:
            arg_array.append(arg_table)
-        table.add("arguments", arg_array) # TODO : figure out why this isn't indenting after serialization in some cases
+        table.add(
+            "arguments", arg_array
+        )  # TODO : figure out why this isn't indenting after serialization in some cases
+
+        if self.launch_parameters:
+            table.add("launch_parameters", self.launch_parameters)
+
+        if self.launches:
+            table.add("launches", self.launches)
+
+        if self.provider:
+            table.add("provider", self.provider)
+
+        if self.runtime:
+            table.add("runtime", self.runtime)

        table.add("return", self.return_info.to_table())

        self.add_auxiliary_table(table)
@@ -188,72 +236,153 @@ class Function(AuxiliarySupportedTable):
     @staticmethod
     def parse_from_table(function_table):
-        required_table_entries = ["name", "description", "calling_convention", "arguments", "return"]
+        required_table_entries = [
+            "name", "description", "calling_convention", "arguments", "return"
+        ]
         _check_required_table_entries(function_table, required_table_entries)
-        arguments = [Parameter.parse_from_table(param_table) for param_table in function_table["arguments"]]
+        arguments = [
+            Parameter.parse_from_table(param_table)
+            for param_table in function_table["arguments"]
+        ]
+
+        launch_parameters = function_table[
+            "launch_parameters"] if "launch_parameters" in function_table else []
+
+        launches = function_table[
+            "launches"] if "launches" in function_table else ""
+
+        provider = function_table[
+            "provider"] if "provider" in function_table else ""
+
+        runtime = function_table[
+            "runtime"] if "runtime" in function_table else ""
+
         return_info = Parameter.parse_from_table(function_table["return"])
-        return Function(name=function_table["name"],
-                        description=function_table["description"],
-                        calling_convention=CallingConventionType(function_table["calling_convention"]),
-                        arguments=arguments,
-                        return_info=return_info,
-                        auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table))
+        return Function(
+            name=function_table["name"],
+            description=function_table["description"],
+            calling_convention=CallingConventionType(
+                function_table["calling_convention"]),
+            arguments=arguments,
+            return_info=return_info,
+            launch_parameters=launch_parameters,
+            launches=launches,
+            provider=provider,
+            runtime=runtime,
+            auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table))


-class FunctionTable:
-    TableName = "functions"
+class FunctionTableCommon:
     def __init__(self, function_map):
         self.function_map = function_map
         self.functions = self.function_map.values()

     def to_table(self):
-        serialized_map = { function_key : self.function_map[function_key].to_table() for function_key in self.function_map }
         func_table = tomlkit.table()
         for function_key in self.function_map:
-            func_table.add(function_key, self.function_map[function_key].to_table())
+            func_table.add(function_key,
+                           self.function_map[function_key].to_table())
         return func_table

-    @staticmethod
-    def parse_from_table(all_functions_table):
-        function_map = {function_key: Function.parse_from_table(all_functions_table[function_key]) for function_key in all_functions_table}
-        return FunctionTable(function_map)
+    @classmethod
+    def parse_from_table(cls, all_functions_table):
+        function_map = {
+            function_key:
+            Function.parse_from_table(all_functions_table[function_key])
+            for function_key in all_functions_table
+        }
+        return cls(function_map)
+
+
+class FunctionTable(FunctionTableCommon):
+    TableName = "functions"
+
+
+class DeviceFunctionTable(FunctionTableCommon):
+    TableName = "device_functions"

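The move from `@staticmethod` to `@classmethod` is what lets `FunctionTable` and `DeviceFunctionTable` share one parser: `cls` binds to whichever subclass the call goes through. A stripped-down sketch of the dispatch (class names invented for illustration):

```python
class TableCommon:
    def __init__(self, function_map):
        self.function_map = function_map

    @classmethod
    def parse_from_table(cls, raw_table):
        # cls is the subclass used at the call site, so each table
        # type round-trips back to its own class.
        return cls(dict(raw_table))


class HostTable(TableCommon):
    TableName = "functions"


class DeviceTable(TableCommon):
    TableName = "device_functions"


assert isinstance(HostTable.parse_from_table({}), HostTable)
assert isinstance(DeviceTable.parse_from_table({}), DeviceTable)
```
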
 @dataclass
 class Target:

     @dataclass
     class Required:

         @dataclass
         class CPU:
             TableName = TargetType.CPU.value

             # required
             architecture: str = ""
             extensions: list = field(default_factory=list)

+            # optional
+            runtime: str = ""
+
             def to_table(self):
                 table = tomlkit.table()
                 table.add("architecture", self.architecture)
                 table.add("extensions", self.extensions)
+
+                if self.runtime:
+                    table.add("runtime", self.runtime)
+
                 return table

             @staticmethod
             def parse_from_table(table):
                 required_table_entries = ["architecture", "extensions"]
                 _check_required_table_entries(table, required_table_entries)
-                return Target.Required.CPU(architecture=table["architecture"], extensions=table["extensions"])
-
-        # TODO : support GPU
+                runtime = table.get("runtime", "")
+
+                return Target.Required.CPU(
+                    architecture=table["architecture"],
+                    extensions=table["extensions"],
+                    runtime=runtime)

         @dataclass
         class GPU:
-            TableName = TargetType.CPU.value
+            TableName = TargetType.GPU.value
+            blocks: int = 0
+            instruction_set_version: str = ""
+            min_threads: int = 0
+            min_global_memory_KB: int = 0
+            min_shared_memory_KB: int = 0
+            min_texture_memory_KB: int = 0
+            model: str = ""
+            runtime: str = ""

             def to_table(self):
-                return tomlkit.table()
+                table = tomlkit.table()
+                table.add("model", self.model)
+                table.add("runtime", self.runtime)
+                table.add("blocks", self.blocks)
+                table.add("instruction_set_version",
+                          self.instruction_set_version)
+                table.add("min_threads", self.min_threads)
+                table.add("min_global_memory_KB", self.min_global_memory_KB)
+                table.add("min_shared_memory_KB", self.min_shared_memory_KB)
+                table.add("min_texture_memory_KB", self.min_texture_memory_KB)
+
+                return table

             @staticmethod
             def parse_from_table(table):
-                pass
+                required_table_entries = [
+                    "runtime",
+                    "model",
+                ]
+                _check_required_table_entries(table, required_table_entries)
+
+                return Target.Required.GPU(
+                    runtime=table["runtime"],
+                    model=table["model"],
+                    blocks=table["blocks"],
+                    instruction_set_version=table["instruction_set_version"],
+                    min_threads=table["min_threads"],
+                    min_global_memory_KB=table["min_global_memory_KB"],
+                    min_shared_memory_KB=table["min_shared_memory_KB"],
+                    min_texture_memory_KB=table["min_texture_memory_KB"])

         TableName = "required"
         os: OperatingSystem = None

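As a rough usage sketch (assuming `hat_file` is importable, as in the tests below; the device values here are invented), the fleshed-out GPU dataclass now serializes to a real TOML table:

```python
import tomlkit
from hat_file import Target

gpu = Target.Required.GPU(model="example-device",       # placeholder model name
                          runtime="CUDA",
                          blocks=16,
                          instruction_set_version="7.0",
                          min_threads=64,
                          min_global_memory_KB=1024,
                          min_shared_memory_KB=48,
                          min_texture_memory_KB=0)
print(tomlkit.dumps(gpu.to_table()))  # emits model/runtime/blocks/... keys
```
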
@@ -264,7 +393,7 @@ class Target:
             table = tomlkit.table()
             table.add("os", self.os.value)
             table.add(Target.Required.CPU.TableName, self.cpu.to_table())
-            if self.gpu is not None:
+            if self.gpu and self.gpu.runtime:
                 table.add(Target.Required.GPU.TableName, self.gpu.to_table())
             return table

@@ -272,9 +401,11 @@ class Target:
         @staticmethod
         def parse_from_table(table):
             required_table_entries = ["os", Target.Required.CPU.TableName]
             _check_required_table_entries(table, required_table_entries)
-            cpu_info = Target.Required.CPU.parse_from_table(table[Target.Required.CPU.TableName])
+            cpu_info = Target.Required.CPU.parse_from_table(
+                table[Target.Required.CPU.TableName])
             if Target.Required.GPU.TableName in table:
-                gpu_info = Target.Required.GPU.parse_from_table(table[Target.Required.GPU.TableName])
+                gpu_info = Target.Required.GPU.parse_from_table(
+                    table[Target.Required.GPU.TableName])
             else:
                 gpu_info = Target.Required.GPU()
             return Target.Required(os=table["os"], cpu=cpu_info, gpu=gpu_info)

@@ -298,20 +429,24 @@ class Target:
         table = tomlkit.table()
         table.add(Target.Required.TableName, self.required.to_table())
         if self.optimized_for is not None:
-            table.add(Target.OptimizedFor.TableName, self.optimized_for.to_table())
+            table.add(Target.OptimizedFor.TableName,
+                      self.optimized_for.to_table())
         return table

     @staticmethod
     def parse_from_table(target_table):
         required_table_entries = [Target.Required.TableName]
         _check_required_table_entries(target_table, required_table_entries)
-        required_data = Target.Required.parse_from_table(target_table[Target.Required.TableName])
+        required_data = Target.Required.parse_from_table(
+            target_table[Target.Required.TableName])
         if Target.OptimizedFor.TableName in target_table:
-            optimized_for_data = Target.OptimizedFor.parse_from_table(target_table[Target.OptimizedFor.TableName])
+            optimized_for_data = Target.OptimizedFor.parse_from_table(
+                target_table[Target.OptimizedFor.TableName])
         else:
             optimized_for_data = Target.OptimizedFor()
         return Target(required=required_data, optimized_for=optimized_for_data)


 @dataclass
 class LibraryReference:
     name: str = ""

@@ -328,8 +463,8 @@ class LibraryReference:
     @staticmethod
     def parse_from_table(table):
         return LibraryReference(name=table["name"],
                                 version=table["version"],
                                 target_file=table["target_file"])


 @dataclass

@@ -355,12 +490,18 @@ class Dependencies(AuxiliarySupportedTable):
     @staticmethod
     def parse_from_table(dependencies_table):
         required_table_entries = ["link_target", "deploy_files", "dynamic"]
-        _check_required_table_entries(dependencies_table, required_table_entries)
-        dynamic = [LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in dependencies_table["dynamic"]]
+        _check_required_table_entries(dependencies_table,
+                                      required_table_entries)
+        dynamic = [
+            LibraryReference.parse_from_table(lib_ref_table)
+            for lib_ref_table in dependencies_table["dynamic"]
+        ]
         return Dependencies(link_target=dependencies_table["link_target"],
                             deploy_files=dependencies_table["deploy_files"],
                             dynamic=dynamic,
-                            auxiliary=AuxiliarySupportedTable.parse_auxiliary(dependencies_table))
+                            auxiliary=AuxiliarySupportedTable.parse_auxiliary(
+                                dependencies_table))


 @dataclass
 class CompiledWith:

@@ -386,13 +527,18 @@ class CompiledWith:
     @staticmethod
     def parse_from_table(compiled_with_table):
         required_table_entries = ["compiler", "flags", "crt", "libraries"]
-        _check_required_table_entries(compiled_with_table, required_table_entries)
-        libraries = [LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in compiled_with_table["libraries"]]
+        _check_required_table_entries(compiled_with_table,
+                                      required_table_entries)
+        libraries = [
+            LibraryReference.parse_from_table(lib_ref_table)
+            for lib_ref_table in compiled_with_table["libraries"]
+        ]
         return CompiledWith(compiler=compiled_with_table["compiler"],
                             flags=compiled_with_table["flags"],
                             crt=compiled_with_table["crt"],
                             libraries=libraries)


 @dataclass
 class Declaration:
     TableName = "declaration"

@@ -406,14 +552,16 @@ class Declaration:
     @staticmethod
     def parse_from_table(declaration_table):
         required_table_entries = ["code"]
-        _check_required_table_entries(declaration_table, required_table_entries)
+        _check_required_table_entries(declaration_table,
+                                      required_table_entries)
         return Declaration(code=declaration_table["code"])


 @dataclass
 class HATFile:
     """Encapsulates a HAT file. An instance of this class can be created by calling the
     Deserialize class method e.g.:
         some_hat_file = Deserialize('someFile.hat')
     Similarly, HAT files can be serialized by creating/modifying a HATFile instance
     and then calling Serialize e.g.:
         some_hat_file.name = 'some new name'

@@ -422,8 +570,11 @@ class HATFile:
     name: str = ""
     description: Description = None
     _function_table: FunctionTable = None
+    _device_function_table: DeviceFunctionTable = None
     functions: list = field(default_factory=list)
+    device_functions: list = field(default_factory=list)
     function_map: dict = field(default_factory=dict)
+    device_function_map: list = field(default_factory=list)
     target: Target = None
     dependencies: Dependencies = None
     compiled_with: CompiledWith = None

@@ -438,7 +589,13 @@ class HATFile:
         self.function_map = self._function_table.function_map
         for func in self.functions:
             func.hat_file = self
-            func.link_target = Path(self.path).resolve().parent / self.dependencies.link_target
+            func.link_target = Path(self.path).resolve(
+            ).parent / self.dependencies.link_target
+
+        if not self._device_function_table:
+            self._device_function_table = DeviceFunctionTable({})
+        self.device_function_map = self._device_function_table.function_map
+        self.device_functions = self._device_function_table.functions

     def Serialize(self, filepath=None):
         """Serializes the HATFile to disk using the file location specified by `filepath`.

@@ -447,7 +604,11 @@ class HATFile:
             filepath = self.path
         root_table = tomlkit.table()
         root_table.add(Description.TableName, self.description.to_table())
-        root_table.add(FunctionTable.TableName, self._function_table.to_table())
+        root_table.add(FunctionTable.TableName,
+                       self._function_table.to_table())
+        if self.device_function_map:
+            root_table.add(DeviceFunctionTable.TableName,
+                           self._device_function_table.to_table())
         root_table.add(Target.TableName, self.target.to_table())
         root_table.add(Dependencies.TableName, self.dependencies.to_table())
         root_table.add(CompiledWith.TableName, self.compiled_with.to_table())

@@ -460,23 +621,34 @@ class HATFile:
             out_file.write(self.HATEpilogue.format(name))

     @staticmethod
-    def Deserialize(filepath):
+    def Deserialize(filepath) -> "HATFile":
         """Creates an instance of the HATFile class by deserializing the contents of the file at `filepath`"""
         hat_toml = _read_toml_file(filepath)
         name = os.path.splitext(os.path.basename(filepath))[0]
-        required_entries = [Description.TableName,
-                            FunctionTable.TableName,
-                            Target.TableName,
-                            Dependencies.TableName,
-                            CompiledWith.TableName,
-                            Declaration.TableName]
+        required_entries = [
+            Description.TableName, FunctionTable.TableName, Target.TableName,
+            Dependencies.TableName, CompiledWith.TableName,
+            Declaration.TableName
+        ]
         _check_required_table_entries(hat_toml, required_entries)
-        hat_file = HATFile(name=name,
-                           description=Description.parse_from_table(hat_toml[Description.TableName]),
-                           _function_table=FunctionTable.parse_from_table(hat_toml[FunctionTable.TableName]),
-                           target=Target.parse_from_table(hat_toml[Target.TableName]),
-                           dependencies=Dependencies.parse_from_table(hat_toml[Dependencies.TableName]),
-                           compiled_with=CompiledWith.parse_from_table(hat_toml[CompiledWith.TableName]),
-                           declaration=Declaration.parse_from_table(hat_toml[Declaration.TableName]),
-                           path=Path(filepath).resolve())
+        device_function_table = None
+        if DeviceFunctionTable.TableName in hat_toml:
+            device_function_table = DeviceFunctionTable.parse_from_table(
+                hat_toml[DeviceFunctionTable.TableName])
+        hat_file = HATFile(
+            name=name,
+            description=Description.parse_from_table(
+                hat_toml[Description.TableName]),
+            _function_table=FunctionTable.parse_from_table(
+                hat_toml[FunctionTable.TableName]),
+            _device_function_table=device_function_table,
+            target=Target.parse_from_table(
+                hat_toml[Target.TableName]),
+            dependencies=Dependencies.parse_from_table(
+                hat_toml[Dependencies.TableName]),
+            compiled_with=CompiledWith.parse_from_table(
+                hat_toml[CompiledWith.TableName]),
+            declaration=Declaration.parse_from_table(
+                hat_toml[Declaration.TableName]),
+            path=Path(filepath).resolve())
         return hat_file

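A small usage sketch of the reworked entry point ("my_package.hat" is a placeholder path, not a file in this repo): packages without a `device_functions` table simply yield an empty map.

```python
from hat_file import HATFile

hat_file = HATFile.Deserialize("my_package.hat")  # placeholder path
print(list(hat_file.function_map))          # host functions
print(list(hat_file.device_function_map))   # device functions ({} if none)
```
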
@@ -1,4 +1,3 @@
-
 #!/usr/bin/env python3
 import unittest
 import sys, os

@@ -8,6 +7,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 from benchmark_hat_package import run_benchmark


 class BenchmarkHATPackage_test(unittest.TestCase):
     def test_benchmark(self):
         A = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))

@@ -23,9 +23,16 @@ class BenchmarkHATPackage_test(unittest.TestCase):

         package = acc.Package()
         package.add(nest, args=(A, B, C), base_name="test_function")
-        package.build(name="BenchmarkHATPackage_test_benchmark", output_dir="test_acccgen", format=acc.Package.Format.HAT_DYNAMIC)
+        package.build(name="BenchmarkHATPackage_test_benchmark",
+                      output_dir="test_acccgen",
+                      format=acc.Package.Format.HAT_DYNAMIC)

-        run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat", store_in_hat=False, batch_size=2, min_time_in_sec=1, input_sets_minimum_size_MB=1)
+        run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
+                      store_in_hat=False,
+                      batch_size=2,
+                      min_time_in_sec=1,
+                      input_sets_minimum_size_MB=1)


 if __name__ == '__main__':
     unittest.main()

@@ -10,8 +10,8 @@ from hat import load
 from hat_to_dynamic import create_dynamic_package
 from hat_to_lib import create_static_package


 class HAT_test(unittest.TestCase):
     def test_load(self):

         # Generate a HAT package

@@ -30,9 +30,12 @@ class HAT_test(unittest.TestCase):

         for mode in [acc.Package.Mode.RELEASE, acc.Package.Mode.DEBUG]:
             package_name = f"HAT_test_load_{mode.value}"
-            package.build(name=package_name, output_dir="test_acccgen", mode=mode)
+            package.build(name=package_name,
+                          output_dir="test_acccgen",
+                          mode=mode)

-            create_dynamic_package(f"test_acccgen/{package_name}.hat", f"test_acccgen/{package_name}.dyn.hat")
+            create_dynamic_package(f"test_acccgen/{package_name}.hat",
+                                   f"test_acccgen/{package_name}.dyn.hat")

             hat_package = load(f"test_acccgen/{package_name}.dyn.hat")

@@ -40,7 +43,7 @@ class HAT_test(unittest.TestCase):
             print(name)

         # create numpy arguments with the correct shape and dtype
         A = np.random.rand(16, 16).astype(np.float32)
         B = np.random.rand(16, 16).astype(np.float32)
         B_ref = B + A

@@ -0,0 +1,92 @@
#!/usr/bin/env python3

from pathlib import Path
import unittest

import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from hat_file import (CallingConventionType, CompiledWith, Declaration,
                      Dependencies, Description, Function, FunctionTable,
                      HATFile, OperatingSystem, Parameter, ParameterType,
                      Target, UsageType)


class HATFile_test(unittest.TestCase):
    def test_file_basic_serialize(self):
        # Construct a HAT file from scratch
        # Start with a function definition
        my_function = Function(
            name="my_function",
            description="Some description",
            calling_convention=CallingConventionType.StdCall,
            return_info=Parameter(logical_type=ParameterType.RuntimeArray,
                                  declared_type="float*",
                                  element_type="float",
                                  usage=UsageType.Input,
                                  shape="[16, 16]",
                                  affine_map=[16, 1],
                                  size="16 * 16 * sizeof(float)"))
        # Create the function table
        functions = FunctionTable({"my_function": my_function})
        # Create the HATFile object
        hat_file1 = HATFile(
            name="test_file",
            description=Description(
                version="0.0.1",
                author="me",
                license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
            ),
            _function_table=functions,
            target=Target(required=Target.Required(os=OperatingSystem.Windows,
                                                   cpu=Target.Required.CPU(
                                                       architecture="Haswell",
                                                       extensions=["AVX2"]),
                                                   gpu=None),
                          optimized_for=Target.OptimizedFor()),
            dependencies=Dependencies(link_target="my_lib.lib"),
            compiled_with=CompiledWith(compiler="VC++"),
            declaration=Declaration(),
            path=Path(".").resolve())
        # Serialize it to disk
        test_file_name = "test_file_serialize.hat"

        try:
            hat_file1.Serialize(test_file_name)
            # Deserialize it and verify it has what we expect
            hat_file2 = HATFile.Deserialize(test_file_name)
        finally:
            # Remove the file
            os.remove(test_file_name)

        # Do basic verification that the deserialized HatFile contains what we specified
        # when we created the HATFile directly
        self.assertEqual(hat_file1.description, hat_file2.description)
        self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
        self.assertEqual(hat_file1.compiled_with.to_table(),
                         hat_file2.compiled_with.to_table())
        self.assertTrue("my_function" in hat_file2.function_map)

    def test_file_basic_deserialize(self):
        # Load a HAT file from the samples directory
        hat_file1 = HATFile.Deserialize(
            os.path.join(os.path.dirname(__file__), "..", "..", "samples",
                         "sample_gemm_library.hat"))
        description = {
            "author": "John Doe",
            "version": "1.2.3.5",
            "license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
        }

        # Do basic verification of known values in the file
        # Verify the description has entries we expect
        self.assertLessEqual(description.items(),
                             hat_file1.description.to_table().items())
        # Verify the list of functions
        self.assertTrue(len(hat_file1.functions) == 2)
        self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
        self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)


if __name__ == '__main__':
    unittest.main()

@@ -1,6 +1,6 @@
 # HAT TOML Schema
 [toml-schema]
-version = "0.0.0.2"
+version = "0.0.0.3"

 # Types to be used elsewhere in this schema
 [types]

@ -66,6 +66,10 @@ version = "0.0.0.2"
|
|||
[types.functionType]
|
||||
type = "table"
|
||||
|
||||
##########
|
||||
# Required
|
||||
##########
|
||||
|
||||
# The name of the function
|
||||
[types.functionType.name]
|
||||
type = "string"
|
||||
|
@ -77,8 +81,8 @@ version = "0.0.0.2"
|
|||
# The calling convention for this function
|
||||
[types.functionType.calling_convention]
|
||||
type = "string"
|
||||
allowedvalues = [ "stdcall", "cdecl", "fastcall", "vectorcall" ]
|
||||
|
||||
allowedvalues = [ "stdcall", "cdecl", "fastcall", "vectorcall", "device" ]
|
||||
|
||||
# An array of arguments to the function
|
||||
[types.functionType.arguments]
|
||||
type = "array"
|
||||
|
@ -88,6 +92,30 @@ version = "0.0.0.2"
|
|||
[types.functionType.return]
|
||||
typeof = "paramType"
|
||||
|
||||
##########
|
||||
# Optional
|
||||
##########
|
||||
|
||||
# The parameters needed to launch this function, if applicable
|
||||
[types.functionType.launch_parameters]
|
||||
type = "array"
|
||||
optional = true
|
||||
|
||||
# The function that is launched by this function
|
||||
[types.functionType.launches]
|
||||
type = "string"
|
||||
optional = true
|
||||
|
||||
# The provider of this function, if any
|
||||
[types.functionType.provider]
|
||||
type = "string"
|
||||
optional = true
|
||||
|
||||
# The runtime used by the function
|
||||
[types.functionType.runtime]
|
||||
type = "string"
|
||||
optional = true
|
||||
|
||||
# Optional additional usage-specific information about the function that isn't part of this schema
|
||||
[types.functionType.auxiliary]
|
||||
type = "table"
|
||||
|
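An invented fragment exercising the four new optional keys (not from a real package): a host-side launcher can record its launch parameters and name the device function it launches.

```python
import tomlkit

fragment = tomlkit.parse('''
[functions.gemm_launcher]
launch_parameters = [ 16, 16, 1, 32, 32, 1 ]   # e.g. grid and block dimensions
launches = "gemm_kernel"
provider = ""
runtime = "CUDA"
''')
print(fragment["functions"]["gemm_launcher"]["launches"])  # -> gemm_kernel
```
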
@ -136,12 +164,18 @@ version = "0.0.0.2"
|
|||
type = "table"
|
||||
optional = true
|
||||
|
||||
# Collection of functions declared within the HAT file and their metadata
|
||||
# Collection of host functions declared within the HAT file and their metadata
|
||||
# The keys in a collection are not prescribed by the schema, and in this case are the names of the functions as the HAT format does not support function overloading.
|
||||
[elements.functions]
|
||||
type = "collection"
|
||||
typeof = "functionType"
|
||||
|
||||
# Collection of device functions declared within the HAT file and their metadata
|
||||
# The keys in a collection are not prescribed by the schema, and in this case are the names of the functions as the HAT format does not support function overloading.
|
||||
[elements.device_functions]
|
||||
type = "collection"
|
||||
typeof = "functionType"
|
||||
|
||||
# Table of information about the target device the functions described in this HAT file are intended to be used with
|
||||
[elements.target]
|
||||
type = "table"
|
||||
|
@ -168,11 +202,17 @@ version = "0.0.0.2"
|
|||
type = "array"
|
||||
arraytype = "string"
|
||||
|
||||
# Optional CPU runtime library
|
||||
[elements.target.required.CPU.runtime]
|
||||
type = "string"
|
||||
allowedvalues = [ "openmp" ]
|
||||
optional = true
|
||||
|
||||
# Optional additional information not defined in this schema
|
||||
[elements.target.required.CPU.auxiliary]
|
||||
type = "table"
|
||||
optional = true
|
||||
|
||||
|
||||
# Required GPU characteristics if there are GPU functions in this HAT package
|
||||
[elements.target.required.GPU]
|
||||
type = "table"
|
||||
|
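For the new optional CPU runtime key, a target section gains a single line; `openmp` is the only value the schema currently allows (fragment invented for illustration).

```python
import tomlkit

target_fragment = tomlkit.parse('''
[target.required.CPU]
architecture = "x86_64"
extensions = [ "AVX2" ]
runtime = "openmp"
''')
print(target_fragment["target"]["required"]["CPU"]["runtime"])  # -> openmp
```
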
@ -194,7 +234,7 @@ version = "0.0.0.2"
|
|||
# Minimum global memory in KB that will be allocated
|
||||
[elements.target.required.GPU.min_global_memory_KB]
|
||||
type = "integer"
|
||||
|
||||
|
||||
# Minimum shared memory in KB that will be allocated
|
||||
[elements.target.required.GPU.min_shared_memory_KB]
|
||||
type = "integer"
|
||||
|
|
|
@@ -29,7 +29,7 @@ install_requires =
     tomlkit
     vswhere; sys_platform == "win32"
 package_dir =
-    hatlib = tools
+    hatlib = hatlib

 [options.entry_points]
 console_scripts =

173  tools/hat.py
@@ -1,173 +0,0 @@
#!/usr/bin/env python3

"""Loads a dynamically-linked HAT package in Python

Call 'load' to load a HAT package in Python. After loading, call the HAT functions using numpy
arrays as arguments. The shape, element type, and order of each numpy array should exactly match
the requirements of the HAT function.

For example:
    import numpy as np
    import hatlib as hat

    # load the package
    package = hat.load("my_package.hat")

    # print the function names
    for name in package.names:
        print(name)

    # create numpy arguments with the correct shape, dtype, and order
    A = np.ones([256,32], dtype=np.float32, order="C")
    B = np.ones([32,256], dtype=np.float32, order="C")
    D = np.ones([256,32], dtype=np.float32, order="C")
    E = np.ones([256,32], dtype=np.float32, order="C")

    # call a package function named 'my_func_698b5e5c'
    package.my_func_698b5e5c(A, B, D, E)
"""

import sys
import toml
import ctypes
import os
import numpy as np
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ArgInfo:
    """Extracts necessary information from the description of a function argument in a hat file"""
    hat_declared_type: str
    numpy_shape: Tuple[int]
    numpy_strides: Tuple[int]
    numpy_dtype: type
    element_num_bytes: int
    ctypes_pointer_type: Any

    def __init__(self, param_description):
        self.hat_declared_type = param_description["declared_type"]
        self.numpy_shape = tuple(param_description["shape"])
        if self.hat_declared_type == "float16_t*":
            self.numpy_dtype = np.float16
            self.element_num_bytes = 2
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_uint16)    # same bitwidth as float16
        elif self.hat_declared_type == "float*":
            self.numpy_dtype = np.float32
            self.element_num_bytes = 4
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_float)
        elif self.hat_declared_type == "double*":
            self.numpy_dtype = np.float64
            self.element_num_bytes = 8
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_double)
        elif self.hat_declared_type == "int64_t*":
            self.numpy_dtype = np.int64
            self.element_num_bytes = 8
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int64)
        elif self.hat_declared_type == "int32_t*":
            self.numpy_dtype = np.int32
            self.element_num_bytes = 4
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int32)
        elif self.hat_declared_type == "int16_t*":
            self.numpy_dtype = np.int16
            self.element_num_bytes = 2
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int16)
        elif self.hat_declared_type == "int8_t*":
            self.numpy_dtype = np.int8
            self.element_num_bytes = 1
            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int8)
        else:
            raise NotImplementedError(f"Unsupported declared_type {self.hat_declared_type} in hat file")

        self.numpy_strides = tuple([self.element_num_bytes * x for x in param_description["affine_map"]])


def verify_args(args, arg_infos, function_name):
    """ Verifies that a list of arguments matches a list of argument descriptions in a HAT file
    """
    # check number of args
    if len(args) != len(arg_infos):
        sys.exit(f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")

    # for each arg
    for i in range(len(args)):
        arg = args[i]
        arg_info = arg_infos[i]

        # confirm that the arg is a numpy ndarray
        if not isinstance(arg, np.ndarray):
            sys.exit(f"Error calling {function_name}(...): expected argument {i} to be <class 'numpy.ndarray'> but received {type(arg)}")

        # confirm that the arg dtype matches the description in the hat package
        if arg_info.numpy_dtype != arg.dtype:
            sys.exit(f"Error calling {function_name}(...): expected argument {i} to have dtype={arg_info.numpy_dtype} but received dtype={arg.dtype}")

        # confirm that the arg shape is correct
        if arg_info.numpy_shape != arg.shape:
            sys.exit(f"Error calling {function_name}(...): expected argument {i} to have shape={arg_info.numpy_shape} but received shape={arg.shape}")

        # confirm that the arg strides are correct
        if arg_info.numpy_strides != arg.strides:
            sys.exit(f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}")


def hat_description_to_python_function(hat_description, hat_library):
    """ Creates a callable function based on a function description in a HAT package
    """
    hat_arg_descriptions = hat_description["arguments"]
    function_name = hat_description["name"]

    def f(*args):
        # verify that the (numpy) input args match the description in the hat file
        arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
        verify_args(args, arg_infos, function_name)

        # prepare the args to the hat package
        hat_args = [arg.ctypes.data_as(arg_info.ctypes_pointer_type) for arg, arg_info in zip(args, arg_infos)]

        # call the function in the hat package
        hat_library[function_name](*hat_args)

    return f


class AttributeDict(OrderedDict):
    """ Dictionary that allows entries to be accessed like attributes
    """
    __getattr__ = OrderedDict.__getitem__

    @property
    def names(self):
        return list(self.keys())

    def __getitem__(self, key):
        for k, v in self.items():
            if k.startswith(key):
                return v
        return OrderedDict.__getitem__(self, key)


def load(hat_path):
    """ Creates a class with static functions based on the function descriptions in a HAT package
    """
    # load the function descriptions from the hat file
    hat_path = os.path.abspath(hat_path)
    t = toml.load(hat_path)

    function_descriptions = t["functions"]
    hat_binary_filename = t["dependencies"]["link_target"]
    hat_binary_path = os.path.join(os.path.dirname(hat_path), hat_binary_filename)

    # check that the HAT library has a supported file extension
    supported_extensions = [".dll", ".so"]
    _, extension = os.path.splitext(hat_binary_path)
    if extension not in supported_extensions:
        sys.exit(f"Unsupported HAT library extension: {extension}")

    # load the hat_library:
    hat_library = ctypes.cdll.LoadLibrary(hat_binary_path)

    # create dictionary of functions defined in the hat file
    function_dict = AttributeDict({key: hat_description_to_python_function(val, hat_library) for key, val in function_descriptions.items()})
    return function_dict

@@ -1,79 +0,0 @@
#!/usr/bin/env python3

from pathlib import Path
import unittest

import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from hat_file import CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile, OperatingSystem, Parameter, ParameterType, Target, UsageType


class HATFile_test(unittest.TestCase):
    def test_file_basic_serialize(self):
        # Construct a HAT file from scratch
        # Start with a function definition
        my_function = Function(name="my_function",
                               description="Some description",
                               calling_convention=CallingConventionType.StdCall,
                               return_info=Parameter(logical_type=ParameterType.RuntimeArray,
                                                     declared_type="float*",
                                                     element_type="float",
                                                     usage=UsageType.Input,
                                                     shape="[16, 16]",
                                                     affine_map=[16, 1],
                                                     size="16 * 16 * sizeof(float)"))
        # Create the function table
        functions = FunctionTable({"my_function" : my_function})
        # Create the HATFile object
        hat_file1 = HATFile(name="test_file",
                            description=Description(version="0.0.1",
                                                    author="me",
                                                    license_url="https://www.apache.org/licenses/LICENSE-2.0.html"),
                            _function_table=functions,
                            target=Target(
                                required=Target.Required(os=OperatingSystem.Windows,
                                                         cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
                                                         gpu=None),
                                optimized_for=Target.OptimizedFor()),
                            dependencies=Dependencies(link_target="my_lib.lib"),
                            compiled_with=CompiledWith(compiler="VC++"),
                            declaration=Declaration(),
                            path=Path(".").resolve())
        # Serialize it to disk
        test_file_name = "test_file_serialize.hat"

        try:
            hat_file1.Serialize(test_file_name)
            # Deserialize it and verify it has what we expect
            hat_file2 = HATFile.Deserialize(test_file_name)
        finally:
            # Remove the file
            os.remove(test_file_name)

        # Do basic verification that the deserialized HatFile contains what we specified
        # when we created the HATFile directly
        self.assertEqual(hat_file1.description, hat_file2.description)
        self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
        self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
        self.assertTrue("my_function" in hat_file2.function_map)

    def test_file_basic_deserialize(self):
        # Load a HAT file from the samples directory
        hat_file1 = HATFile.Deserialize(os.path.join(os.path.dirname(__file__), "..", "..", "samples", "sample_gemm_library.hat"))
        description = {
            "author": "John Doe",
            "version": "1.2.3.5",
            "license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
        }

        # Do basic verification of known values in the file
        # Verify the description has entries we expect
        self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
        # Verify the list of functions
        self.assertTrue(len(hat_file1.functions) == 2)
        self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
        self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)


if __name__ == '__main__':
    unittest.main()