Add beginnings of support for CUDA device functions (#32)

Adds support for GPU, device functions, and launch functions
This commit is contained in:
Kern Handa 2022-03-14 23:20:41 -07:00 коммит произвёл GitHub
Родитель b54fca0ff7
Коммит 615006cecb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
25 изменённых файлов: 1232 добавлений и 483 удалений

6
.github/workflows/ci.yml поставляемый
Просмотреть файл

@ -28,11 +28,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r tools/requirements.txt
python -m pip install -r hatlib/requirements.txt
- name: Unittest
run: |
python -m pip install -r tools/test/requirements.txt
python -m unittest discover tools/test
python -m pip install -r hatlib/test/requirements.txt
python -m unittest discover hatlib/test
- name: Build whl
run: |
python -m pip install build

Просмотреть файл

@ -8,13 +8,13 @@ HAT is a format for packaging compiled libraries in the C programming language.
A HAT package includes one library file and one or more `.hat` files. The library file can be a static library (a `.a` file on Posix systems or a `.lib` file on Windows systems) or a dynamic library (a `.so` file on Posix systems or a `.dll` file and its accompanying `.lib` import library file on Windows systems). For simplicity of documentation, when discussing the "library file" as a unit in a HAT package we are referring to either the single static library file (`.a` or `.lib`), the single dynamic library file (`.so`) on Posix, or the pair of library files that comprise a dynamic library on Windows (`.dll` and `.lib` import library file). Note that since a HAT package includes only one "library file" it cannot contain both a static and dynamic library, so in the case of a HAT package containing a Windows dynamic library, the `.lib` file in the package is always the import library and there is no ambiguity in which file (or pair of files) is being referenced by "library file".
The library file contains all the compiled object code that implements the functions in the HAT package. Each `.hat` file contains a combination of standard C function declarations (like a typical `.h` file) and metadata in the TOML markup language. The metadata that accompanies each function declaration describes how the function should be called and how it was implemented. The metadata is intended to be both human-readable and machine-readable, providing structured and systematic documentation and allowing downstream tools to examine the package contents.
The library file contains all the compiled object code that implements the functions in the HAT package. Each `.hat` file contains a combination of standard C function declarations (like a typical `.h` file) and metadata in the TOML markup language. The metadata that accompanies each function declaration describes how the function should be called and how it was implemented. The metadata is intended to be both human-readable and machine-readable, providing structured and systematic documentation and allowing downstream tools to examine the package contents.
Each `.hat` file has the convenient property that it is simultaneously a valid h-file and a valid TOML file. In other words, the file is structured such that a C compiler will ignore the TOML metadata, while a TOML parser will understand the entire file as a valid TOML file. We accomplish this using a technique we call *the hat trick*, which is explained below.
Each `.hat` file has the convenient property that it is simultaneously a valid h-file and a valid TOML file. In other words, the file is structured such that a C compiler will ignore the TOML metadata, while a TOML parser will understand the entire file as a valid TOML file. We accomplish this using a technique we call *the hat trick*, which is explained below.
# What problem does the HAT format solve?
# What problem does the HAT format solve?
C is among the most popular programming languages, but it also has serious shortcomings. In particular, C libraries are typically opaque and lack mechanisms for systematic documentation and introspection. This is best explained with an example:
C is among the most popular programming languages, but it also has serious shortcomings. In particular, C libraries are typically opaque and lack mechanisms for systematic documentation and introspection. This is best explained with an example:
Say that we use C to implement an in-place column-wise normalization of a 10x10 matrix. In other words, this function takes a 10x10 matrix `A` and divides each column by the Euclidean norm of that column. A highly-optimized implementation of this function would be tailored to the target computer's specific hardware properties, such as its cache size, the number of CPU cores, and perhaps even the presence of a GPU. The declaration of this function in an h-file would look something like this:
```
@ -24,16 +24,16 @@ The accompanying library file would contain the compiled machine code for this f
* What does the pointer `float* A` point to? By convention, it is reasonable to assume that `A` points to the first element of an array that contains the 100 matrix elements, but this is not stated explicitly.
* Does the function expect the matrix elements to appear in row-major order, column-major order, Z-order, or something else?
* What is the size of the array `A`? We may have auxiliary knowledge that the array is 100 elements long, but this information is not stated explicitly.
* We can see that `A` is not `const`, so we know that its elements can be changed by the function, but is it an "output-only" array (its initial values are overwritten) or is it an "input/output" array? We have auxiliary knowledge that `A` is both an input and an output, but this is not stated explicitly.
* Is this function compiled for Windows or Linux?
* What is the size of the array `A`? We may have auxiliary knowledge that the array is 100 elements long, but this information is not stated explicitly.
* We can see that `A` is not `const`, so we know that its elements can be changed by the function, but is it an "output-only" array (its initial values are overwritten) or is it an "input/output" array? We have auxiliary knowledge that `A` is both an input and an output, but this is not stated explicitly.
* Is this function compiled for Windows or Linux?
* Does it need to be linked to a C runtime library or any other library?
* For which instruction set is the function compiled? Does it rely on SSE extensions? AVX? AVX512?
* Is this a multi-threaded implementation? Does the implementation assume a fixed number of CPU cores?
* Does the implementation rely on GPU hardware?
* Is this a multi-threaded implementation? Does the implementation assume a fixed number of CPU cores?
* Does the implementation rely on GPU hardware?
* Who created this library? Does it have a version number? Is it distributed under an open-source license?
Some of the questions above can be answered by reading the human-readable documentation provided in h-file comments, in `README.txt` or `LICENSE.txt` files, or in a web page that describes the library. Some of the information may be implied by the library name or the function name (e.g., imagine that the function was named "normalize_10x10_singlecore") or by common sense (e.g., if a GPU is not mentioned anywhere, the function probably doesn't require one). Nevertheless, C does not have a schematized systematic way to express all of this important information. Moreover, human-readable documentation does not expose this information to downstream programming tools. For example, imagine a downstream tool that examines a library and automatically creates tests that measure the performance of each function.
Some of the questions above can be answered by reading the human-readable documentation provided in h-file comments, in `README.txt` or `LICENSE.txt` files, or in a web page that describes the library. Some of the information may be implied by the library name or the function name (e.g., imagine that the function was named "normalize_10x10_singlecore") or by common sense (e.g., if a GPU is not mentioned anywhere, the function probably doesn't require one). Nevertheless, C does not have a schematized systematic way to express all of this important information. Moreover, human-readable documentation does not expose this information to downstream programming tools. For example, imagine a downstream tool that examines a library and automatically creates tests that measure the performance of each function.
The HAT package format attempts to replace this opacity with transparency, by annotating each declared function with descriptive metadata in TOML.
@ -56,11 +56,11 @@ As mentioned above, the `.hat` file is simultaneously a valid h-file and a valid
#endif // TOML
```
What does a C compiler see? Assuming that the `TOML` macro is not defined, the parser ignores everything that appears between `#ifdef TOML` and `#endif`. This leaves whatever appears instead of `// Add C declarations here`.
What does a C compiler see? Assuming that the `TOML` macro is not defined, the parser ignores everything that appears between `#ifdef TOML` and `#endif`. This leaves whatever appears instead of `// Add C declarations here`.
What does a TOML parser see? First note that `#` is a comment escape character in TOML, so the `#ifdef` and `#endif` lines are ignored as comments. Any TOML code that appears instead of `// Add TOML here` is parsed normally. Finally, a special TOML table named `[declaration]` is defined, and inside it a key named `code` with all of the C declarations as a multiline string.
Why is it important for the TOML and the C declarations to live in the same file? Why not put the TOML metadata in a separate file? The fact that C already splits the package code between library files and h-files is already a concern, because the user has to worry about distributing a `.h` file with an incorrect version of the library file. We don't want to make things worse by adding yet another separate file. Keeping the metadata in the same file as the function declaration ensures that each declaration is never separated from its metadata.
Why is it important for the TOML and the C declarations to live in the same file? Why not put the TOML metadata in a separate file? The fact that C already splits the package code between library files and h-files is already a concern, because the user has to worry about distributing a `.h` file with an incorrect version of the library file. We don't want to make things worse by adding yet another separate file. Keeping the metadata in the same file as the function declaration ensures that each declaration is never separated from its metadata.
# Multiple `.hat` files
@ -82,7 +82,7 @@ Requirements: Python 3.7 and above.
pip install hatlib
```
[Documentation](https://github.com/microsoft/hat/tree/main/tools#readme)
[Documentation](https://github.com/microsoft/hat/tree/main/hatlib#readme)
You can also clone this repository and build a package locally:

Просмотреть файл

@ -137,14 +137,14 @@ mean_duration_in_sec = 1.5953456437541567e-06
This repository contains unit tests, authored with the Python `unittest` library. To setup and run all tests:
```shell
pip install -r <path_to_repo>/tools/test/requirements.txt
python -m unittest discover <path_to_repo>/tools/test
pip install -r <path_to_repo>/hatlib/test/requirements.txt
python -m unittest discover <path_to_repo>/hatlib/test
```
To run a test case:
```shell
python -m unittest discover -k "test_file_basic_serialize" <path_to_repo>/tools/test
python -m unittest discover -k "test_file_basic_serialize" <path_to_repo>/hatlib/test
```
Note that some tests will require a C++ compiler (e.g. MSVC for windows, gcc for linux) in `PATH`.
Note that some tests will require a C++ compiler (e.g. MSVC for windows, gcc for linux) in `PATH`.

Просмотреть файл

108
hatlib/arg_info.py Normal file
Просмотреть файл

@ -0,0 +1,108 @@
import ctypes
import numpy as np
import sys
from dataclasses import dataclass
from typing import Any, Tuple
from functools import reduce
from typing import List
@dataclass
class ArgInfo:
"""Extracts necessary information from the description of a function argument in a hat file"""
hat_declared_type: str
numpy_shape: Tuple[int]
numpy_strides: Tuple[int]
numpy_dtype: type
element_num_bytes: int
ctypes_pointer_type: Any
usage: str = ""
def __init__(self, param_description):
self.hat_declared_type = param_description["declared_type"]
self.numpy_shape = tuple(param_description["shape"])
self.usage = param_description["usage"]
if self.hat_declared_type == "float16_t*":
self.numpy_dtype = np.float16
self.element_num_bytes = 2
self.ctypes_pointer_type = ctypes.POINTER(
ctypes.c_uint16) # same bitwidth as float16
elif self.hat_declared_type == "float*":
self.numpy_dtype = np.float32
self.element_num_bytes = 4
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_float)
elif self.hat_declared_type == "double*":
self.numpy_dtype = np.float64
self.element_num_bytes = 8
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_double)
elif self.hat_declared_type == "int64_t*":
self.numpy_dtype = np.int64
self.element_num_bytes = 8
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int64)
elif self.hat_declared_type == "int32_t*":
self.numpy_dtype = np.int32
self.element_num_bytes = 4
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int32)
elif self.hat_declared_type == "int16_t*":
self.numpy_dtype = np.int16
self.element_num_bytes = 2
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int16)
elif self.hat_declared_type == "int8_t*":
self.numpy_dtype = np.int8
self.element_num_bytes = 1
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int8)
else:
raise NotImplementedError(
f"Unsupported declared_type {self.hat_declared_type} in hat file")
self.numpy_strides = tuple(
[self.element_num_bytes * x for x in param_description["affine_map"]])
def verify_args(args, arg_infos, function_name):
""" Verifies that a list of arguments matches a list of argument descriptions in a HAT file
"""
# check number of args
if len(args) != len(arg_infos):
sys.exit(
f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")
# for each arg
for i in range(len(args)):
arg = args[i]
arg_info = arg_infos[i]
# confirm that the arg is a numpy ndarray
if not isinstance(arg, np.ndarray):
sys.exit(
"Error calling {function_name}(...): expected argument {i} to be <class 'numpy.ndarray'> but received {type(arg)}")
# confirm that the arg dtype matches the dexcription in the hat package
if arg_info.numpy_dtype != arg.dtype:
sys.exit(
f"Error calling {function_name}(...): expected argument {i} to have dtype={arg_info.numpy_dtype} but received dtype={arg.dtype}")
# confirm that the arg shape is correct
if arg_info.numpy_shape != arg.shape:
sys.exit(
f"Error calling {function_name}(...): expected argument {i} to have shape={arg_info.numpy_shape} but received shape={arg.shape}")
# confirm that the arg strides are correct
if arg_info.numpy_strides != arg.strides:
sys.exit(
f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}")
def generate_input_sets(parameters: List[ArgInfo], input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape)
for p in parameters]
set_size = reduce(
lambda x, y: x + y, [size * p.element_num_bytes for size, p in zip(shapes_to_sizes, parameters)])
num_input_sets = (input_sets_minimum_size_MB * 1024 *
1024 // set_size) + 1 + num_additional
input_sets = [[np.random.random(p.numpy_shape).astype(
p.numpy_dtype) for p in parameters] for _ in range(num_input_sets)]
return input_sets[0] if len(input_sets) == 1 else input_sets

Просмотреть файл

@ -7,18 +7,17 @@ import sys
import time
import toml
import traceback
from functools import reduce
from pathlib import Path
from typing import List
if __package__:
from .hat_file import HATFile
from .hat_to_dynamic import create_dynamic_package
from .hat import load, ArgInfo
from .hat import load, ArgInfo, generate_input_sets
else:
from hat_file import HATFile
from hat_to_dynamic import create_dynamic_package
from hat import load, ArgInfo
from hat import load, ArgInfo, generate_input_sets
class Benchmark:
"""A basic python-based benchmark.
@ -27,6 +26,7 @@ class Benchmark:
Requirements:
A compilation toolchain in your PATH: cl.exe & link.exe (Windows), gcc (Linux), or clang (macOS)
"""
def __init__(self, hat_path):
self.hat_path = Path(hat_path)
@ -36,14 +36,15 @@ class Benchmark:
# create dictionary of function descriptions defined in the hat file
t = toml.load(self.hat_path)
function_descriptions = t["functions"]
self.hat_arg_descriptions = {key : [ArgInfo(d) for d in val["arguments"]] for key, val in function_descriptions.items()}
self.hat_arg_descriptions = {key: [ArgInfo(
d) for d in val["arguments"]] for key, val in function_descriptions.items()}
def run(self,
function_name: str,
warmup_iterations: int = 10,
min_timing_iterations: int = 100,
min_time_in_sec: int = 10,
input_sets_minimum_size_MB = 50) -> float:
input_sets_minimum_size_MB=50) -> float:
"""Runs benchmarking for a function.
Multiple inputs are run through the function until both minimum time and minimum iterations have been reached.
The mean duration is then calculated as mean_duration = total_time_elapsed / total_iterations_performed.
@ -62,8 +63,10 @@ class Benchmark:
# TODO: support packing and unpacking functions
mean_elapsed_time, batch_timings = self._profile(function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB)
print(f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s")
mean_elapsed_time, batch_timings = self._profile(
function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB)
print(
f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s")
return mean_elapsed_time, batch_timings
@ -78,29 +81,30 @@ class Benchmark:
perf_counter_scale = 1
return perf_counter, perf_counter_scale
def generate_input_sets(parameters: List[ArgInfo], input_sets_minimum_size_MB: int, num_additional: int = 10):
shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters]
set_size = reduce(lambda x, y: x + y, [size * p.element_num_bytes for size, p in zip(shapes_to_sizes, parameters)])
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
print(f"[Benchmarking] Using {num_input_sets} input sets, each {set_size} bytes")
return [[np.random.random(p.numpy_shape).astype(p.numpy_dtype) for p in parameters] for _ in range(num_input_sets)]
parameters = self.hat_arg_descriptions[function_name]
# generate sufficient input sets to overflow the L3 cache, since we don't know the size of the model
# we'll make a guess based on the minimum input set size
input_sets = generate_input_sets(parameters, input_sets_minimum_size_MB)
input_sets = generate_input_sets(
parameters, input_sets_minimum_size_MB, num_additional=10)
set_size = 0
for i in input_sets[0]:
set_size += i.size * i.dtype.itemsize
print(
f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes")
perf_counter, perf_counter_scale = get_perf_counter()
print(f"[Benchmarking] Warming up for {warmup_iterations} iterations...")
print(
f"[Benchmarking] Warming up for {warmup_iterations} iterations...")
for _ in range(warmup_iterations):
for calling_args in input_sets:
self.hat_package[function_name](*calling_args)
print(f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations...")
print(
f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations...")
start_time = perf_counter()
end_time = perf_counter()
@ -115,7 +119,8 @@ class Benchmark:
i = iterations % i_max
iterations += 1
end_time = perf_counter()
batch_timings.append((end_time - batch_start_time) / perf_counter_scale)
batch_timings.append(
(end_time - batch_start_time) / perf_counter_scale)
elapsed_time = ((end_time - start_time) / perf_counter_scale)
mean_elapsed_time = elapsed_time / iterations
@ -135,8 +140,8 @@ def write_runtime_to_hat_file(hat_path, function_name, mean_time_secs):
# Workaround to remove extra empty lines
with open(hat_path, "r") as f:
lines = f.readlines()
lines = [lines[i] for i in range(len(lines)) if not(lines[i] == "\n" \
and i < len(lines)-1 and lines[i+1] == "\n")]
lines = [lines[i] for i in range(len(lines)) if not(lines[i] == "\n"
and i < len(lines)-1 and lines[i+1] == "\n")]
with open(hat_path, "w") as f:
f.writelines(lines)
@ -148,27 +153,29 @@ def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=1
functions = benchmark.hat_functions
for function_name in functions:
print(f"\nBenchmarking function: {function_name}")
if "Initialize" in function_name or "_debug_check_allclose" in function_name : # Skip init and debug functions
if "Initialize" in function_name or "_debug_check_allclose" in function_name: # Skip init and debug functions
continue
try:
_, batch_timings = benchmark.run(function_name,
warmup_iterations=batch_size,
min_timing_iterations=batch_size,
min_time_in_sec=min_time_in_sec,
input_sets_minimum_size_MB=input_sets_minimum_size_MB)
warmup_iterations=batch_size,
min_timing_iterations=batch_size,
min_time_in_sec=min_time_in_sec,
input_sets_minimum_size_MB=input_sets_minimum_size_MB)
sorted_batch_means = np.array(sorted(batch_timings)) / batch_size
num_batches = len(batch_timings)
mean_of_means = sorted_batch_means.mean()
median_of_means = sorted_batch_means[num_batches//2]
mean_of_small_means = sorted_batch_means[0 : num_batches//2].mean()
robust_mean_of_means = sorted_batch_means[num_batches//5 : -num_batches//5].mean()
mean_of_small_means = sorted_batch_means[0: num_batches//2].mean()
robust_mean_of_means = sorted_batch_means[num_batches //
5: -num_batches//5].mean()
min_of_means = sorted_batch_means[0]
if store_in_hat:
write_runtime_to_hat_file(hat_path, function_name, mean_of_means)
write_runtime_to_hat_file(
hat_path, function_name, mean_of_means)
results.append({"function_name": function_name,
"mean": mean_of_means,
"median_of_means": median_of_means,
@ -178,9 +185,11 @@ def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=1
})
except Exception as e:
exc_type, exc_val, exc_tb = sys.exc_info()
traceback.print_exception(exc_type, exc_val, exc_tb, file=sys.stderr)
traceback.print_exception(
exc_type, exc_val, exc_tb, file=sys.stderr)
print("\nException message: ", e)
print(f"WARNING: Failed to run function {function_name}, skipping this benchmark.")
print(
f"WARNING: Failed to run function {function_name}, skipping this benchmark.")
return results
@ -191,27 +200,28 @@ def main(argv):
" hatlib.benchmark_hat_package <hat_path>\n")
arg_parser.add_argument("hat_path",
help="Path to the HAT file",
default=None)
help="Path to the HAT file",
default=None)
arg_parser.add_argument("--store_in_hat",
help="If set, will write the duration as meta-data back into the hat file",
action='store_true')
help="If set, will write the duration as meta-data back into the hat file",
action='store_true')
arg_parser.add_argument("--results_file",
help="Full path where the results will be written",
default="results.csv")
help="Full path where the results will be written",
default="results.csv")
arg_parser.add_argument("--batch_size",
help="The number of function calls in each batch (at least one full batch is executed)",
default=10)
help="The number of function calls in each batch (at least one full batch is executed)",
default=10)
arg_parser.add_argument("--min_time_in_sec",
help="Minimum number of seconds to run the benchmark for",
default=30)
help="Minimum number of seconds to run the benchmark for",
default=30)
arg_parser.add_argument("--input_sets_minimum_size_MB",
help="Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on an desktop CPU)",
default=50)
help="Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on an desktop CPU)",
default=50)
args = vars(arg_parser.parse_args(argv))
results = run_benchmark(args["hat_path"], args["store_in_hat"], batch_size=int(args["batch_size"]), min_time_in_sec=int(args["min_time_in_sec"]), input_sets_minimum_size_MB=int(args["input_sets_minimum_size_MB"]))
results = run_benchmark(args["hat_path"], args["store_in_hat"], batch_size=int(args["batch_size"]), min_time_in_sec=int(
args["min_time_in_sec"]), input_sets_minimum_size_MB=int(args["input_sets_minimum_size_MB"]))
df = pd.DataFrame(results)
df.to_csv(args["results_file"], index=False)
pd.options.display.float_format = '{:8.8f}'.format
@ -219,8 +229,10 @@ def main(argv):
print(f"Results saved to {args['results_file']}")
def main_command():
main(sys.argv[1:]) # drop the first argument (program name)
main(sys.argv[1:]) # drop the first argument (program name)
if __name__ == "__main__":
main_command()

389
hatlib/cuda_loader.py Normal file
Просмотреть файл

@ -0,0 +1,389 @@
import os
import pathlib
import sys
import numpy as np
from functools import reduce
from typing import Dict, List
# CUDA stuff
# TODO: move from pvnrtc module to cuda entirely to reduce dependencies
from pynvrtc.compiler import Program
from cuda import cuda, nvrtc
try:
from .arg_info import ArgInfo, verify_args, generate_input_sets
except:
from arg_info import ArgInfo, verify_args, generate_input_sets
# lifted from https://github.com/NVIDIA/jitify/blob/master/jitify.hpp
HEADER_MAP: Dict[str, str] = {
'float.h':
"""
#pragma once
#define FLT_RADIX 2
#define FLT_MANT_DIG 24
#define DBL_MANT_DIG 53
#define FLT_DIG 6
#define DBL_DIG 15
#define FLT_MIN_EXP -125
#define DBL_MIN_EXP -1021
#define FLT_MIN_10_EXP -37
#define DBL_MIN_10_EXP -307
#define FLT_MAX_EXP 128
#define DBL_MAX_EXP 1024
#define FLT_MAX_10_EXP 38
#define DBL_MAX_10_EXP 308
#define FLT_MAX 3.4028234e38f
#define DBL_MAX 1.7976931348623157e308
#define FLT_EPSILON 1.19209289e-7f
#define DBL_EPSILON 2.220440492503130e-16
#define FLT_MIN 1.1754943e-38f
#define DBL_MIN 2.2250738585072013e-308
#define FLT_ROUNDS 1
#if defined __cplusplus && __cplusplus >= 201103L
#define FLT_EVAL_METHOD 0
#define DECIMAL_DIG 21
#endif
""",
'limits.h':
"""
#pragma once
#if defined _WIN32 || defined _WIN64
#define __WORDSIZE 32
#else
#if defined __x86_64__ && !defined __ILP32__
#define __WORDSIZE 64
#else
#define __WORDSIZE 32
#endif
#endif
#define MB_LEN_MAX 16
#define CHAR_BIT 8
#define SCHAR_MIN (-128)
#define SCHAR_MAX 127
#define UCHAR_MAX 255
enum {
_JITIFY_CHAR_IS_UNSIGNED = (char)-1 >= 0,
CHAR_MIN = _JITIFY_CHAR_IS_UNSIGNED ? 0 : SCHAR_MIN,
CHAR_MAX = _JITIFY_CHAR_IS_UNSIGNED ? UCHAR_MAX : SCHAR_MAX,
};
#define SHRT_MIN (-32768)
#define SHRT_MAX 32767
#define USHRT_MAX 65535
#define INT_MIN (-INT_MAX - 1)
#define INT_MAX 2147483647
#define UINT_MAX 4294967295U
#if __WORDSIZE == 64
# define LONG_MAX 9223372036854775807L
#else
# define LONG_MAX 2147483647L
#endif
#define LONG_MIN (-LONG_MAX - 1L)
#if __WORDSIZE == 64
#define ULONG_MAX 18446744073709551615UL
#else
#define ULONG_MAX 4294967295UL
#endif
#define LLONG_MAX 9223372036854775807LL
#define LLONG_MIN (-LLONG_MAX - 1LL)
#define ULLONG_MAX 18446744073709551615ULL
""",
'stdint.h':
"""
#pragma once
#include <climits>
namespace __jitify_stdint_ns {
typedef signed char int8_t;
typedef signed short int16_t;
typedef signed int int32_t;
typedef signed long long int64_t;
typedef signed char int_fast8_t;
typedef signed short int_fast16_t;
typedef signed int int_fast32_t;
typedef signed long long int_fast64_t;
typedef signed char int_least8_t;
typedef signed short int_least16_t;
typedef signed int int_least32_t;
typedef signed long long int_least64_t;
typedef signed long long intmax_t;
typedef signed long intptr_t; //optional
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef unsigned char uint_fast8_t;
typedef unsigned short uint_fast16_t;
typedef unsigned int uint_fast32_t;
typedef unsigned long long uint_fast64_t;
typedef unsigned char uint_least8_t;
typedef unsigned short uint_least16_t;
typedef unsigned int uint_least32_t;
typedef unsigned long long uint_least64_t;
typedef unsigned long long uintmax_t;
#define INT8_MIN SCHAR_MIN
#define INT16_MIN SHRT_MIN
#if defined _WIN32 || defined _WIN64
#define WCHAR_MIN 0
#define WCHAR_MAX USHRT_MAX
typedef unsigned long long uintptr_t; //optional
#else
#define WCHAR_MIN INT_MIN
#define WCHAR_MAX INT_MAX
typedef unsigned long uintptr_t; //optional
#endif
#define INT32_MIN INT_MIN
#define INT64_MIN LLONG_MIN
#define INT8_MAX SCHAR_MAX
#define INT16_MAX SHRT_MAX
#define INT32_MAX INT_MAX
#define INT64_MAX LLONG_MAX
#define UINT8_MAX UCHAR_MAX
#define UINT16_MAX USHRT_MAX
#define UINT32_MAX UINT_MAX
#define UINT64_MAX ULLONG_MAX
#define INTPTR_MIN LONG_MIN
#define INTMAX_MIN LLONG_MIN
#define INTPTR_MAX LONG_MAX
#define INTMAX_MAX LLONG_MAX
#define UINTPTR_MAX ULONG_MAX
#define UINTMAX_MAX ULLONG_MAX
#define PTRDIFF_MIN INTPTR_MIN
#define PTRDIFF_MAX INTPTR_MAX
#define SIZE_MAX UINT64_MAX
} // namespace __jitify_stdint_ns
namespace std { using namespace __jitify_stdint_ns; }
using namespace __jitify_stdint_ns;
""",
'math.h':
"""
#pragma once
namespace __jitify_math_ns {
#if __cplusplus >= 201103L
#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\
inline double f(double x) { return ::f(x); } \\
inline float f##f(float x) { return ::f(x); } \\
/*inline long double f##l(long double x) { return ::f(x); }*/ \\
inline float f(float x) { return ::f(x); } \\
/*inline long double f(long double x) { return ::f(x); }*/
#else
#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\
inline double f(double x) { return ::f(x); } \\
inline float f##f(float x) { return ::f(x); } \\
/*inline long double f##l(long double x) { return ::f(x); }*/
#endif
DEFINE_MATH_UNARY_FUNC_WRAPPER(cos)
DEFINE_MATH_UNARY_FUNC_WRAPPER(sin)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tan)
DEFINE_MATH_UNARY_FUNC_WRAPPER(acos)
DEFINE_MATH_UNARY_FUNC_WRAPPER(asin)
DEFINE_MATH_UNARY_FUNC_WRAPPER(atan)
template<typename T> inline T atan2(T y, T x) { return ::atan2(y, x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(cosh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(sinh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tanh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(exp)
template<typename T> inline T frexp(T x, int* exp) { return ::frexp(x, exp); }
template<typename T> inline T ldexp(T x, int exp) { return ::ldexp(x, exp); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(log)
DEFINE_MATH_UNARY_FUNC_WRAPPER(log10)
template<typename T> inline T modf(T x, T* intpart) { return ::modf(x, intpart); }
template<typename T> inline T pow(T x, T y) { return ::pow(x, y); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(sqrt)
DEFINE_MATH_UNARY_FUNC_WRAPPER(ceil)
DEFINE_MATH_UNARY_FUNC_WRAPPER(floor)
template<typename T> inline T fmod(T n, T d) { return ::fmod(n, d); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(fabs)
template<typename T> inline T abs(T x) { return ::abs(x); }
#if __cplusplus >= 201103L
DEFINE_MATH_UNARY_FUNC_WRAPPER(acosh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(asinh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(atanh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(exp2)
DEFINE_MATH_UNARY_FUNC_WRAPPER(expm1)
template<typename T> inline int ilogb(T x) { return ::ilogb(x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(log1p)
DEFINE_MATH_UNARY_FUNC_WRAPPER(log2)
DEFINE_MATH_UNARY_FUNC_WRAPPER(logb)
template<typename T> inline T scalbn (T x, int n) { return ::scalbn(x, n); }
template<typename T> inline T scalbln(T x, long n) { return ::scalbn(x, n); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(cbrt)
template<typename T> inline T hypot(T x, T y) { return ::hypot(x, y); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(erf)
DEFINE_MATH_UNARY_FUNC_WRAPPER(erfc)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tgamma)
DEFINE_MATH_UNARY_FUNC_WRAPPER(lgamma)
DEFINE_MATH_UNARY_FUNC_WRAPPER(trunc)
DEFINE_MATH_UNARY_FUNC_WRAPPER(round)
template<typename T> inline long lround(T x) { return ::lround(x); }
template<typename T> inline long long llround(T x) { return ::llround(x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(rint)
template<typename T> inline long lrint(T x) { return ::lrint(x); }
template<typename T> inline long long llrint(T x) { return ::llrint(x); }
DEFINE_MATH_UNARY_FUNC_WRAPPER(nearbyint)
// TODO: remainder, remquo, copysign, nan, nextafter, nexttoward, fdim,
// fmax, fmin, fma
#endif
#undef DEFINE_MATH_UNARY_FUNC_WRAPPER
} // namespace __jitify_math_ns
namespace std { using namespace __jitify_math_ns; }
#define M_PI 3.14159265358979323846
// Note: Global namespace already includes CUDA math funcs
//using namespace __jitify_math_ns;
""",
'cuda_fp16.h': "",
}
HEADER_MAP['climits'] = HEADER_MAP['limits.h']
def ASSERT_DRV(err):
if isinstance(err, cuda.CUresult):
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(
cuda.cuGetErrorString(err)[1].decode('utf-8')))
elif isinstance(err, nvrtc.nvrtcResult):
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("Nvrtc Error: {}".format(err))
else:
raise RuntimeError("Unknown error type: {}".format(err))
def _find_cuda_incl_path() -> pathlib.Path:
"Tries to find the CUDA include path."
cuda_path = os.getenv("CUDA_PATH")
if not cuda_path:
if sys.platform == 'linux':
cuda_path = pathlib.Path("/usr/local/cuda/include")
if not (cuda_path.exists() and cuda_path.is_dir()):
cuda_path = None
elif sys.platform == 'win32':
...
elif sys.platform == 'darwin':
...
else:
cuda_path /= "include"
return cuda_path
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name):
src = cuda_src_path.read_text()
prog = Program(src=src, name=func_name,
headers=HEADER_MAP.values(), include_names=HEADER_MAP.keys())
ptx = prog.compile([
'-use_fast_math',
'-default-device',
'-std=c++11',
'-arch=sm_52', # TODO: is this needed?
])
return ptx
def initialize_cuda():
# Initialize CUDA Driver API
err, = cuda.cuInit(0)
ASSERT_DRV(err)
# Retrieve handle for device 0
# TODO: add support for multiple CUDA devices?
err, cuDevice = cuda.cuDeviceGet(0)
ASSERT_DRV(err)
# Create context
err, context = cuda.cuCtxCreate(0, cuDevice)
ASSERT_DRV(err)
def get_func_from_ptx(ptx, func_name):
# Note: Incompatible --gpu-architecture would be detected here
err, ptx_mod = cuda.cuModuleLoadData(ptx.encode('utf-8'))
ASSERT_DRV(err)
err, kernel = cuda.cuModuleGetFunction(ptx_mod, func_name.encode('utf-8'))
ASSERT_DRV(err)
return kernel
def _arg_size(arg_info: ArgInfo):
return arg_info.element_num_bytes * reduce(lambda x, y: x*y, arg_info.numpy_shape)
def transfer_mem_host_to_cuda(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'input' in arg_info.usage:
err, = cuda.cuMemcpyHtoD(
device_arg, host_arg.ctypes.data, _arg_size(arg_info))
ASSERT_DRV(err)
def transfer_mem_cuda_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'output' in arg_info.usage:
err, = cuda.cuMemcpyDtoH(
host_arg.ctypes.data, device_arg, _arg_size(arg_info))
ASSERT_DRV(err)
def allocate_cuda_mem(arg_infos: List[ArgInfo]):
device_mem = []
for arg in arg_infos:
err, mem = cuda.cuMemAlloc(_arg_size(arg))
ASSERT_DRV(err)
device_mem.append(mem)
return device_mem
def device_args_to_ptr_list(device_args: List):
# CUDA python example says this is subject to change
ptrs = [
np.array([int(d_arg)], dtype=np.uint64) for d_arg in device_args
]
ptrs = np.array([ptr.ctypes.data for ptr in ptrs], dtype=np.uint64)
return ptrs
def create_loader_for_device_function(device_func, hat_details):
hat_path: pathlib.Path = hat_details.path
cuda_src_path: pathlib.Path = hat_path.parent / device_func["provider"]
func_name = device_func["name"]
ptx = compile_cuda_program(cuda_src_path, func_name)
initialize_cuda()
kernel = get_func_from_ptx(ptx, func_name)
hat_arg_descriptions = device_func["arguments"]
arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
launch_parameters = device_func["launch_parameters"]
def f(*args):
verify_args(args, arg_infos, func_name)
device_mem = allocate_cuda_mem(arg_infos)
transfer_mem_host_to_cuda(
device_args=device_mem, host_args=args, arg_infos=arg_infos)
ptrs = device_args_to_ptr_list(device_mem)
err, stream = cuda.cuStreamCreate(0)
ASSERT_DRV(err)
err, = cuda.cuLaunchKernel(
kernel,
*launch_parameters, # [ grid[x-z], block[x-z] ]
0, # dynamic shared memory
stream, # stream
ptrs.ctypes.data, # kernel arguments
0, # extra (ignore)
)
ASSERT_DRV(err)
err, = cuda.cuStreamSynchronize(stream)
ASSERT_DRV(err)
transfer_mem_cuda_to_host(
device_args=device_mem, host_args=args, arg_infos=arg_infos)
return f

165
hatlib/hat.py Normal file
Просмотреть файл

@ -0,0 +1,165 @@
"""Loads a dynamically-linked HAT package in Python
Call 'load' to load a HAT package in Python. After loading, call the HAT
functions using numpy arrays as arguments. The shape, element type, and order
of each numpy array should exactly match the requirements of the HAT function.
For example:
import numpy as np
import hatlib as hat
# load the package
package = hat.load("my_package.hat")
# print the function names
for name in package.names:
print(name)
# create numpy arguments with the correct shape, dtype, and order
A = np.ones([256,32], dtype=np.float32, order="C")
B = np.ones([32,256], dtype=np.float32, order="C")
D = np.ones([256,32], dtype=np.float32, order="C")
E = np.ones([256,32], dtype=np.float32, order="C")
# call a package function named 'my_func_698b5e5c'
package.my_func_698b5e5c(A, B, D, E)
"""
import ctypes
import numpy as np
import pathlib
import sys
import toml
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Tuple
try:
from . import hat_file
from .arg_info import ArgInfo, verify_args, generate_input_sets
except:
import hat_file
from arg_info import ArgInfo, verify_args, generate_input_sets
try:
try:
from . import cuda_loader
except ModuleNotFoundError:
import cuda_loader
except:
CUDA_AVAILABLE = False
else:
CUDA_AVAILABLE = True
# Remove when ROCM is finally available
ROCM_AVAILABLE = False
def generate_input_sets_for_hat_file(hat_path):
hat_path = pathlib.Path(hat_path).absolute()
t: hat_file.HATFile = toml.load(hat_path)
return {
func_name:
generate_input_sets(list(map(ArgInfo, func_desc["arguments"])))
for func_name, func_desc in t["functions"].items()
}
class AttributeDict(OrderedDict):
""" Dictionary that allows entries to be accessed like attributes
"""
__getattr__ = OrderedDict.__getitem__
@property
def names(self):
return list(self.keys())
def __getitem__(self, key):
for k, v in self.items():
if k.startswith(key):
return v
return OrderedDict.__getitem__(key)
def hat_description_to_python_function(hat_description: hat_file.HATFile,
hat_details: AttributeDict):
""" Creates a callable function based on a function description in a HAT
package
"""
for func_name, func_desc in hat_description["functions"].items():
func_desc: hat_file.Function
func_name: str
launches = func_desc.get("launches")
if not launches:
hat_arg_descriptions = func_desc["arguments"]
function_name = func_desc["name"]
hat_library: ctypes.CDLL = hat_details.shared_lib
def f(*args):
# verify that the (numpy) input args match the description in
# the hat file
arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
verify_args(args, arg_infos, function_name)
# prepare the args to the hat package
hat_args = [
arg.ctypes.data_as(arg_info.ctypes_pointer_type)
for arg, arg_info in zip(args, arg_infos)
]
# call the function in the hat package
hat_library[function_name](*hat_args)
yield func_name, f
else:
device_func = hat_description.get("device_functions",
{}).get(launches)
func_runtime = func_desc.get("runtime")
if not device_func:
raise RuntimeError(
f"Couldn't find device function for loader: " + launches)
if not func_runtime:
raise RuntimeError(f"Couldn't find runtime for loader: " +
launches)
if func_runtime == "CUDA" and CUDA_AVAILABLE:
yield func_name, cuda_loader.create_loader_for_device_function(
device_func, hat_details)
elif func_runtime == "ROCM" and ROCM_AVAILABLE:
yield func_name, rocm_loader.create_loader_for_device_function(
device_func, hat_details)
def load(hat_path):
""" Creates a class with static functions based on the function
descriptions in a HAT package
"""
# load the function decscriptions from the hat file
hat_path = pathlib.Path(hat_path).absolute()
t: hat_file.HATFile = toml.load(hat_path)
hat_details = AttributeDict({"path": hat_path})
# function_descriptions = t["functions"]
hat_binary_filename = t["dependencies"]["link_target"]
hat_binary_path = hat_path.parent / hat_binary_filename
# check that the HAT library has a supported file extension
supported_extensions = [".dll", ".so"]
extension = hat_binary_path.suffix
if extension and extension not in supported_extensions:
sys.exit(f"Unsupported HAT library extension: {extension}")
# load the hat_library:
hat_library = ctypes.cdll.LoadLibrary(
str(hat_binary_path)) if extension else None
hat_details["shared_lib"] = hat_library
# create dictionary of functions defined in the hat file
function_dict = AttributeDict(
dict(hat_description_to_python_function(t, hat_details)))
return function_dict

Просмотреть файл

@ -10,6 +10,7 @@ import tomlkit
# TODO : type-checking on leaf node values
def _read_toml_file(filepath):
path = os.path.abspath(filepath)
toml_doc = None
@ -18,11 +19,13 @@ def _read_toml_file(filepath):
toml_doc = tomlkit.parse(file_contents)
return toml_doc
def _check_required_table_entry(table, key):
if key not in table:
# TODO : add more context to this error message
raise ValueError(f"Invalid HAT file: missing required key {key}")
def _check_required_table_entries(table, keys):
for key in keys:
_check_required_table_entry(table, key)
@ -34,26 +37,32 @@ class ParameterType(Enum):
Element = "element"
Void = "void"
class UsageType(Enum):
Input = "input"
Output = "output"
InputOutput = "input_output"
class CallingConventionType(Enum):
StdCall = "stdcall"
CDecl = "cdecl"
FastCall = "fastcall"
VectorCall = "vectorcall"
Device = "devicecall"
class TargetType(Enum):
CPU = "CPU"
GPU = "GPU"
class OperatingSystem(Enum):
Windows = "windows"
MacOS = "macos"
Linux = "linux"
@dataclass
class AuxiliarySupportedTable:
AuxiliaryKey = "auxiliary"
@ -70,6 +79,7 @@ class AuxiliarySupportedTable:
else:
return {}
@dataclass
class Description(AuxiliarySupportedTable):
TableName: str = "description"
@ -91,10 +101,12 @@ class Description(AuxiliarySupportedTable):
@staticmethod
def parse_from_table(table):
return Description(author=table["author"],
version=table["version"],
license_url=table["license_url"],
auxiliary=AuxiliarySupportedTable.parse_auxiliary(table))
return Description(
author=table["author"],
version=table["version"],
license_url=table["license_url"],
auxiliary=AuxiliarySupportedTable.parse_auxiliary(table))
@dataclass
class Parameter:
@ -135,9 +147,14 @@ class Parameter:
# TODO : change "usage" to "role" in schema
@staticmethod
def parse_from_table(param_table):
required_table_entries = ["name", "description", "logical_type", "declared_type", "element_type", "usage"]
required_table_entries = [
"name", "description", "logical_type", "declared_type",
"element_type", "usage"
]
_check_required_table_entries(param_table, required_table_entries)
affine_array_required_table_entries = ["shape", "affine_map", "affine_offset"]
affine_array_required_table_entries = [
"shape", "affine_map", "affine_offset"
]
runtime_array_required_table_entries = ["size"]
name = param_table["name"]
@ -147,14 +164,23 @@ class Parameter:
element_type = param_table["element_type"]
usage = UsageType(param_table["usage"])
param = Parameter(name=name, description=description, logical_type=logical_type, declared_type=declared_type, element_type=element_type, usage=usage)
param = Parameter(name=name,
description=description,
logical_type=logical_type,
declared_type=declared_type,
element_type=element_type,
usage=usage)
if logical_type == ParameterType.AffineArray:
_check_required_table_entries(param_table, affine_array_required_table_entries)
_check_required_table_entries(param_table,
affine_array_required_table_entries)
param.shape = param_table["shape"]
param.affine_map = param_table["affine_map"]
param.affine_offset = param_table["affine_offset"]
elif logical_type == ParameterType.RuntimeArray:
_check_required_table_entries(param_table, runtime_array_required_table_entries)
_check_required_table_entries(
param_table, runtime_array_required_table_entries)
param.size = param_table["size"]
return param
@ -162,13 +188,20 @@ class Parameter:
@dataclass
class Function(AuxiliarySupportedTable):
name: str = ""
description: str = ""
calling_convention: CallingConventionType = None
# required
arguments: list = field(default_factory=list)
return_info: Parameter = None
calling_convention: CallingConventionType = None
description: str = ""
hat_file: any = None
link_target: Path = None
name: str = ""
return_info: Parameter = None
# optional
launch_parameters: list = field(default_factory=list)
launches: str = ""
provider: str = ""
runtime: str = ""
def to_table(self):
table = tomlkit.table()
@ -179,7 +212,22 @@ class Function(AuxiliarySupportedTable):
arg_array = tomlkit.array()
for arg_table in arg_tables:
arg_array.append(arg_table)
table.add("arguments", arg_array) # TODO : figure out why this isn't indenting after serialization in some cases
table.add(
"arguments", arg_array
) # TODO : figure out why this isn't indenting after serialization in some cases
if self.launch_parameters:
table.add("launch_parameters", self.launch_parameters)
if self.launches:
table.add("launches", self.launches)
if self.provider:
table.add("provider", self.provider)
if self.runtime:
table.add("runtime", self.runtime)
table.add("return", self.return_info.to_table())
self.add_auxiliary_table(table)
@ -188,72 +236,153 @@ class Function(AuxiliarySupportedTable):
@staticmethod
def parse_from_table(function_table):
required_table_entries = ["name", "description", "calling_convention", "arguments", "return"]
required_table_entries = [
"name", "description", "calling_convention", "arguments", "return"
]
_check_required_table_entries(function_table, required_table_entries)
arguments = [Parameter.parse_from_table(param_table) for param_table in function_table["arguments"]]
arguments = [
Parameter.parse_from_table(param_table)
for param_table in function_table["arguments"]
]
launch_parameters = function_table[
"launch_parameters"] if "launch_parameters" in function_table else []
launches = function_table[
"launches"] if "launches" in function_table else ""
provider = function_table[
"provider"] if "provider" in function_table else ""
runtime = function_table[
"runtime"] if "runtime" in function_table else ""
return_info = Parameter.parse_from_table(function_table["return"])
return Function(name=function_table["name"],
description=function_table["description"],
calling_convention=CallingConventionType(function_table["calling_convention"]),
arguments=arguments,
return_info=return_info,
auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table))
return Function(
name=function_table["name"],
description=function_table["description"],
calling_convention=CallingConventionType(
function_table["calling_convention"]),
arguments=arguments,
return_info=return_info,
launch_parameters=launch_parameters,
launches=launches,
provider=provider,
runtime=runtime,
auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table))
class FunctionTable:
TableName = "functions"
class FunctionTableCommon:
def __init__(self, function_map):
self.function_map = function_map
self.functions = self.function_map.values()
def to_table(self):
serialized_map = { function_key : self.function_map[function_key].to_table() for function_key in self.function_map }
func_table = tomlkit.table()
for function_key in self.function_map:
func_table.add(function_key, self.function_map[function_key].to_table())
func_table.add(function_key,
self.function_map[function_key].to_table())
return func_table
@staticmethod
def parse_from_table(all_functions_table):
function_map = {function_key: Function.parse_from_table(all_functions_table[function_key]) for function_key in all_functions_table}
return FunctionTable(function_map)
@classmethod
def parse_from_table(cls, all_functions_table):
function_map = {
function_key:
Function.parse_from_table(all_functions_table[function_key])
for function_key in all_functions_table
}
return cls(function_map)
class FunctionTable(FunctionTableCommon):
TableName = "functions"
class DeviceFunctionTable(FunctionTableCommon):
TableName = "device_functions"
@dataclass
class Target:
@dataclass
class Required:
@dataclass
class CPU:
TableName = TargetType.CPU.value
# required
architecture: str = ""
extensions: list = field(default_factory=list)
# optional
runtime: str = ""
def to_table(self):
table = tomlkit.table()
table.add("architecture", self.architecture)
table.add("extensions", self.extensions)
if self.runtime:
table.add("runtime", self.runtime)
return table
@staticmethod
def parse_from_table(table):
required_table_entries = ["architecture", "extensions"]
_check_required_table_entries(table, required_table_entries)
return Target.Required.CPU(architecture=table["architecture"], extensions=table["extensions"])
# TODO : support GPU
runtime = table.get("runtime", "")
return Target.Required.CPU(
architecture=table["architecture"],
extensions=table["extensions"],
runtime=runtime)
@dataclass
class GPU:
TableName = TargetType.CPU.value
TableName = TargetType.GPU.value
blocks: int = 0
instruction_set_version: str = ""
min_threads: int = 0
min_global_memory_KB: int = 0
min_shared_memory_KB: int = 0
min_texture_memory_KB: int = 0
model: str = ""
runtime: str = ""
def to_table(self):
return tomlkit.table()
table = tomlkit.table()
table.add("model", self.model)
table.add("runtime", self.runtime)
table.add("blocks", self.blocks)
table.add("instruction_set_version",
self.instruction_set_version)
table.add("min_threads", self.min_threads)
table.add("min_global_memory_KB", self.min_global_memory_KB)
table.add("min_shared_memory_KB", self.min_shared_memory_KB)
table.add("min_texture_memory_KB", self.min_texture_memory_KB)
return table
@staticmethod
def parse_from_table(table):
pass
required_table_entries = [
"runtime",
"model",
]
_check_required_table_entries(table, required_table_entries)
return Target.Required.GPU(
runtime=table["runtime"],
model=table["model"],
blocks=table["blocks"],
instruction_set_version=table["instruction_set_version"],
min_threads=table["min_threads"],
min_global_memory_KB=table["min_global_memory_KB"],
min_shared_memory_KB=table["min_shared_memory_KB"],
min_texture_memory_KB=table["min_texture_memory_KB"])
TableName = "required"
os: OperatingSystem = None
@ -264,7 +393,7 @@ class Target:
table = tomlkit.table()
table.add("os", self.os.value)
table.add(Target.Required.CPU.TableName, self.cpu.to_table())
if self.gpu is not None:
if self.gpu and self.gpu.runtime:
table.add(Target.Required.GPU.TableName, self.gpu.to_table())
return table
@ -272,9 +401,11 @@ class Target:
def parse_from_table(table):
required_table_entries = ["os", Target.Required.CPU.TableName]
_check_required_table_entries(table, required_table_entries)
cpu_info = Target.Required.CPU.parse_from_table(table[Target.Required.CPU.TableName])
cpu_info = Target.Required.CPU.parse_from_table(
table[Target.Required.CPU.TableName])
if Target.Required.GPU.TableName in table:
gpu_info = Target.Required.GPU.parse_from_table(table[Target.Required.GPU.TableName])
gpu_info = Target.Required.GPU.parse_from_table(
table[Target.Required.GPU.TableName])
else:
gpu_info = Target.Required.GPU()
return Target.Required(os=table["os"], cpu=cpu_info, gpu=gpu_info)
@ -298,20 +429,24 @@ class Target:
table = tomlkit.table()
table.add(Target.Required.TableName, self.required.to_table())
if self.optimized_for is not None:
table.add(Target.OptimizedFor.TableName, self.optimized_for.to_table())
table.add(Target.OptimizedFor.TableName,
self.optimized_for.to_table())
return table
@staticmethod
def parse_from_table(target_table):
required_table_entries = [Target.Required.TableName]
_check_required_table_entries(target_table, required_table_entries)
required_data = Target.Required.parse_from_table(target_table[Target.Required.TableName])
required_data = Target.Required.parse_from_table(
target_table[Target.Required.TableName])
if Target.OptimizedFor.TableName in target_table:
optimized_for_data = Target.OptimizedFor.parse_from_table(target_table[Target.OptimizedFor.TableName])
optimized_for_data = Target.OptimizedFor.parse_from_table(
target_table[Target.OptimizedFor.TableName])
else:
optimized_for_data = Target.OptimizedFor()
return Target(required=required_data, optimized_for=optimized_for_data)
@dataclass
class LibraryReference:
name: str = ""
@ -328,8 +463,8 @@ class LibraryReference:
@staticmethod
def parse_from_table(table):
return LibraryReference(name=table["name"],
version=table["version"],
target_file=table["target_file"])
version=table["version"],
target_file=table["target_file"])
@dataclass
@ -355,12 +490,18 @@ class Dependencies(AuxiliarySupportedTable):
@staticmethod
def parse_from_table(dependencies_table):
required_table_entries = ["link_target", "deploy_files", "dynamic"]
_check_required_table_entries(dependencies_table, required_table_entries)
dynamic = [LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in dependencies_table["dynamic"]]
_check_required_table_entries(dependencies_table,
required_table_entries)
dynamic = [
LibraryReference.parse_from_table(lib_ref_table)
for lib_ref_table in dependencies_table["dynamic"]
]
return Dependencies(link_target=dependencies_table["link_target"],
deploy_files=dependencies_table["deploy_files"],
dynamic=dynamic,
auxiliary=AuxiliarySupportedTable.parse_auxiliary(dependencies_table))
auxiliary=AuxiliarySupportedTable.parse_auxiliary(
dependencies_table))
@dataclass
class CompiledWith:
@ -386,13 +527,18 @@ class CompiledWith:
@staticmethod
def parse_from_table(compiled_with_table):
required_table_entries = ["compiler", "flags", "crt", "libraries"]
_check_required_table_entries(compiled_with_table, required_table_entries)
libraries = [LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in compiled_with_table["libraries"]]
_check_required_table_entries(compiled_with_table,
required_table_entries)
libraries = [
LibraryReference.parse_from_table(lib_ref_table)
for lib_ref_table in compiled_with_table["libraries"]
]
return CompiledWith(compiler=compiled_with_table["compiler"],
flags=compiled_with_table["flags"],
crt=compiled_with_table["crt"],
libraries=libraries)
@dataclass
class Declaration:
TableName = "declaration"
@ -406,14 +552,16 @@ class Declaration:
@staticmethod
def parse_from_table(declaration_table):
required_table_entries = ["code"]
_check_required_table_entries(declaration_table, required_table_entries)
_check_required_table_entries(declaration_table,
required_table_entries)
return Declaration(code=declaration_table["code"])
@dataclass
class HATFile:
"""Encapsulates a HAT file. An instance of this class can be created by calling the
"""Encapsulates a HAT file. An instance of this class can be created by calling the
Deserialize class method e.g.:
some_hat_file = Deserialize('someFile.hat')
some_hat_file = Deserialize('someFile.hat')
Similarly, HAT files can be serialized but creating/modifying a HATFile instance
and then calling Serilize e.g.:
some_hat_file.name = 'some new name'
@ -422,8 +570,11 @@ class HATFile:
name: str = ""
description: Description = None
_function_table: FunctionTable = None
_device_function_table: DeviceFunctionTable = None
functions: list = field(default_factory=list)
device_functions: list = field(default_factory=list)
function_map: dict = field(default_factory=dict)
device_function_map: list = field(default_factory=list)
target: Target = None
dependencies: Dependencies = None
compiled_with: CompiledWith = None
@ -438,7 +589,13 @@ class HATFile:
self.function_map = self._function_table.function_map
for func in self.functions:
func.hat_file = self
func.link_target = Path(self.path).resolve().parent / self.dependencies.link_target
func.link_target = Path(self.path).resolve(
).parent / self.dependencies.link_target
if not self._device_function_table:
self._device_function_table = DeviceFunctionTable({})
self.device_function_map = self._device_function_table.function_map
self.device_functions = self._device_function_table.functions
def Serialize(self, filepath=None):
"""Serilizes the HATFile to disk using the file location specified by `filepath`.
@ -447,7 +604,11 @@ class HATFile:
filepath = self.path
root_table = tomlkit.table()
root_table.add(Description.TableName, self.description.to_table())
root_table.add(FunctionTable.TableName, self._function_table.to_table())
root_table.add(FunctionTable.TableName,
self._function_table.to_table())
if self.device_function_map:
root_table.add(DeviceFunctionTable.TableName,
self._device_function_table.to_table())
root_table.add(Target.TableName, self.target.to_table())
root_table.add(Dependencies.TableName, self.dependencies.to_table())
root_table.add(CompiledWith.TableName, self.compiled_with.to_table())
@ -460,23 +621,34 @@ class HATFile:
out_file.write(self.HATEpilogue.format(name))
@staticmethod
def Deserialize(filepath):
def Deserialize(filepath) -> "HATFile":
"""Creates an instance of A HATFile class by deserializing the contents of the file at `filepath`"""
hat_toml = _read_toml_file(filepath)
name = os.path.splitext(os.path.basename(filepath))[0]
required_entries = [Description.TableName,
FunctionTable.TableName,
Target.TableName,
Dependencies.TableName,
CompiledWith.TableName,
Declaration.TableName]
required_entries = [
Description.TableName, FunctionTable.TableName, Target.TableName,
Dependencies.TableName, CompiledWith.TableName,
Declaration.TableName
]
_check_required_table_entries(hat_toml, required_entries)
hat_file = HATFile(name=name,
description=Description.parse_from_table(hat_toml[Description.TableName]),
_function_table=FunctionTable.parse_from_table(hat_toml[FunctionTable.TableName]),
target=Target.parse_from_table(hat_toml[Target.TableName]),
dependencies=Dependencies.parse_from_table(hat_toml[Dependencies.TableName]),
compiled_with=CompiledWith.parse_from_table(hat_toml[CompiledWith.TableName]),
declaration=Declaration.parse_from_table(hat_toml[Declaration.TableName]),
path=Path(filepath).resolve())
device_function_table = None
if DeviceFunctionTable.TableName in hat_toml:
device_function_table = DeviceFunctionTable.parse_from_table(
hat_toml[DeviceFunctionTable.TableName])
hat_file = HATFile(
name=name,
description=Description.parse_from_table(
hat_toml[Description.TableName]),
_function_table=FunctionTable.parse_from_table(
hat_toml[FunctionTable.TableName]),
_device_function_table=device_function_table,
target=Target.parse_from_table(
hat_toml[Target.TableName]),
dependencies=Dependencies.parse_from_table(
hat_toml[Dependencies.TableName]),
compiled_with=CompiledWith.parse_from_table(
hat_toml[CompiledWith.TableName]),
declaration=Declaration.parse_from_table(
hat_toml[Declaration.TableName]),
path=Path(filepath).resolve())
return hat_file

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -1,4 +1,3 @@
#!/usr/bin/env python3
import unittest
import sys, os
@ -8,6 +7,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from benchmark_hat_package import run_benchmark
class BenchmarkHATPackage_test(unittest.TestCase):
def test_benchmark(self):
A = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
@ -23,9 +23,16 @@ class BenchmarkHATPackage_test(unittest.TestCase):
package = acc.Package()
package.add(nest, args=(A, B, C), base_name="test_function")
package.build(name="BenchmarkHATPackage_test_benchmark", output_dir="test_acccgen", format=acc.Package.Format.HAT_DYNAMIC)
package.build(name="BenchmarkHATPackage_test_benchmark",
output_dir="test_acccgen",
format=acc.Package.Format.HAT_DYNAMIC)
run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
store_in_hat=False,
batch_size=2,
min_time_in_sec=1,
input_sets_minimum_size_MB=1)
run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat", store_in_hat=False, batch_size=2, min_time_in_sec=1, input_sets_minimum_size_MB=1)
if __name__ == '__main__':
unittest.main()
unittest.main()

Просмотреть файл

@ -10,8 +10,8 @@ from hat import load
from hat_to_dynamic import create_dynamic_package
from hat_to_lib import create_static_package
class HAT_test(unittest.TestCase):
class HAT_test(unittest.TestCase):
def test_load(self):
# Generate a HAT package
@ -30,9 +30,12 @@ class HAT_test(unittest.TestCase):
for mode in [acc.Package.Mode.RELEASE, acc.Package.Mode.DEBUG]:
package_name = f"HAT_test_load_{mode.value}"
package.build(name=package_name, output_dir="test_acccgen", mode=mode)
package.build(name=package_name,
output_dir="test_acccgen",
mode=mode)
create_dynamic_package(f"test_acccgen/{package_name}.hat", f"test_acccgen/{package_name}.dyn.hat")
create_dynamic_package(f"test_acccgen/{package_name}.hat",
f"test_acccgen/{package_name}.dyn.hat")
hat_package = load(f"test_acccgen/{package_name}.dyn.hat")
@ -40,7 +43,7 @@ class HAT_test(unittest.TestCase):
print(name)
# create numpy arguments with the correct shape and dtype
A = np.random.rand(16, 16).astype(np.float32)
A = np.random.rand(16, 16).astype(np.float32)
B = np.random.rand(16, 16).astype(np.float32)
B_ref = B + A

Просмотреть файл

@ -0,0 +1,92 @@
#!/usr/bin/env python3
from pathlib import Path
import unittest
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from hat_file import (CallingConventionType, CompiledWith, Declaration,
Dependencies, Description, Function, FunctionTable,
HATFile, OperatingSystem, Parameter, ParameterType,
Target, UsageType)
class HATFile_test(unittest.TestCase):
def test_file_basic_serialize(self):
# Construct a HAT file from scratch
# Start with a function definition
my_function = Function(
name="my_function",
description="Some description",
calling_convention=CallingConventionType.StdCall,
return_info=Parameter(logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"))
# Create the function table
functions = FunctionTable({"my_function": my_function})
# Create the HATFile object
hat_file1 = HATFile(
name="test_file",
description=Description(
version="0.0.1",
author="me",
license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
),
_function_table=functions,
target=Target(required=Target.Required(os=OperatingSystem.Windows,
cpu=Target.Required.CPU(
architecture="Haswell",
extensions=["AVX2"]),
gpu=None),
optimized_for=Target.OptimizedFor()),
dependencies=Dependencies(link_target="my_lib.lib"),
compiled_with=CompiledWith(compiler="VC++"),
declaration=Declaration(),
path=Path(".").resolve())
# Serialize it to disk
test_file_name = "test_file_serialize.hat"
try:
hat_file1.Serialize(test_file_name)
# Deserialize it and verify it has what we expect
hat_file2 = HATFile.Deserialize(test_file_name)
finally:
# Remove the file
os.remove(test_file_name)
# Do basic verification that the deserialized HatFile contains what we specified
# when we created the HATFile directly
self.assertEqual(hat_file1.description, hat_file2.description)
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
self.assertEqual(hat_file1.compiled_with.to_table(),
hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(
os.path.join(os.path.dirname(__file__), "..", "..", "samples",
"sample_gemm_library.hat"))
description = {
"author": "John Doe",
"version": "1.2.3.5",
"license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
# Do basic verification of known values in the file
# Verify the description has entries we expect
self.assertLessEqual(description.items(),
hat_file1.description.to_table().items())
# Verify the list of functions
self.assertTrue(len(hat_file1.functions) == 2)
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -0,0 +1,92 @@
#!/usr/bin/env python3
from pathlib import Path
import unittest
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from hat_file import (CallingConventionType, CompiledWith, Declaration,
Dependencies, Description, Function, FunctionTable,
HATFile, OperatingSystem, Parameter, ParameterType,
Target, UsageType)
class HATFile_test(unittest.TestCase):
def test_file_basic_serialize(self):
# Construct a HAT file from scratch
# Start with a function definition
my_function = Function(
name="my_function",
description="Some description",
calling_convention=CallingConventionType.StdCall,
return_info=Parameter(logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"))
# Create the function table
functions = FunctionTable({"my_function": my_function})
# Create the HATFile object
hat_file1 = HATFile(
name="test_file",
description=Description(
version="0.0.1",
author="me",
license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
),
_function_table=functions,
target=Target(required=Target.Required(os=OperatingSystem.Windows,
cpu=Target.Required.CPU(
architecture="Haswell",
extensions=["AVX2"]),
gpu=None),
optimized_for=Target.OptimizedFor()),
dependencies=Dependencies(link_target="my_lib.lib"),
compiled_with=CompiledWith(compiler="VC++"),
declaration=Declaration(),
path=Path(".").resolve())
# Serialize it to disk
test_file_name = "test_file_serialize.hat"
try:
hat_file1.Serialize(test_file_name)
# Deserialize it and verify it has what we expect
hat_file2 = HATFile.Deserialize(test_file_name)
finally:
# Remove the file
os.remove(test_file_name)
# Do basic verification that the deserialized HatFile contains what we specified
# when we created the HATFile directly
self.assertEqual(hat_file1.description, hat_file2.description)
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
self.assertEqual(hat_file1.compiled_with.to_table(),
hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(
os.path.join(os.path.dirname(__file__), "..", "..", "samples",
"sample_gemm_library.hat"))
description = {
"author": "John Doe",
"version": "1.2.3.5",
"license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
# Do basic verification of known values in the file
# Verify the description has entries we expect
self.assertLessEqual(description.items(),
hat_file1.description.to_table().items())
# Verify the list of functions
self.assertTrue(len(hat_file1.function_map) == 2)
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -1,6 +1,6 @@
# HAT TOML Schema
[toml-schema]
version = "0.0.0.2"
version = "0.0.0.3"
# Types to be used elsewhere in this schema
[types]
@ -66,6 +66,10 @@ version = "0.0.0.2"
[types.functionType]
type = "table"
##########
# Required
##########
# The name of the function
[types.functionType.name]
type = "string"
@ -77,8 +81,8 @@ version = "0.0.0.2"
# The calling convention for this function
[types.functionType.calling_convention]
type = "string"
allowedvalues = [ "stdcall", "cdecl", "fastcall", "vectorcall" ]
allowedvalues = [ "stdcall", "cdecl", "fastcall", "vectorcall", "device" ]
# An array of arguments to the function
[types.functionType.arguments]
type = "array"
@ -88,6 +92,30 @@ version = "0.0.0.2"
[types.functionType.return]
typeof = "paramType"
##########
# Optional
##########
# The parameters needed to launch this function, if applicable
[types.functionType.launch_parameters]
type = "array"
optional = true
# The function that is launched by this function
[types.functionType.launches]
type = "string"
optional = true
# The provider of this function, if any
[types.functionType.provider]
type = "string"
optional = true
# The runtime used by the function
[types.functionType.runtime]
type = "string"
optional = true
# Optional additional usage-specific information about the function that isn't part of this schema
[types.functionType.auxiliary]
type = "table"
@ -136,12 +164,18 @@ version = "0.0.0.2"
type = "table"
optional = true
# Collection of functions declared within the HAT file and their metadata
# Collection of host functions declared within the HAT file and their metadata
# The keys in a collection are not prescribed by the schema, and in this case are the names of the functions as the HAT format does not support function overloading.
[elements.functions]
type = "collection"
typeof = "functionType"
# Collection of device functions declared within the HAT file and their metadata
# The keys in a collection are not prescribed by the schema, and in this case are the names of the functions as the HAT format does not support function overloading.
[elements.device_functions]
type = "collection"
typeof = "functionType"
# Table of information about the target device the functions described in this HAT file are intended to be used with
[elements.target]
type = "table"
@ -168,11 +202,17 @@ version = "0.0.0.2"
type = "array"
arraytype = "string"
# Optional CPU runtime library
[elements.target.required.CPU.runtime]
type = "string"
allowedvalues = [ "openmp" ]
optional = true
# Optional additional information not defined in this schema
[elements.target.required.CPU.auxiliary]
type = "table"
optional = true
# Required GPU characteristics if there are GPU functions in this HAT package
[elements.target.required.GPU]
type = "table"
@ -194,7 +234,7 @@ version = "0.0.0.2"
# Minimum global memory in KB that will be allocated
[elements.target.required.GPU.min_global_memory_KB]
type = "integer"
# Minimum shared memory in KB that will be allocated
[elements.target.required.GPU.min_shared_memory_KB]
type = "integer"

Просмотреть файл

@ -29,7 +29,7 @@ install_requires =
tomlkit
vswhere; sys_platform == "win32"
package_dir =
hatlib = tools
hatlib = hatlib
[options.entry_points]
console_scripts =

Просмотреть файл

@ -1,173 +0,0 @@
#!/usr/bin/env python3
"""Loads a dynamically-linked HAT package in Python
Call 'load' to load a HAT package in Python. After loading, call the HAT functions using numpy
arrays as arguments. The shape, element type, and order of each numpy array should exactly match
the requirements of the HAT function.
For example:
import numpy as np
import hatlib as hat
# load the package
package = hat.load("my_package.hat")
# print the function names
for name in package.names:
print(name)
# create numpy arguments with the correct shape, dtype, and order
A = np.ones([256,32], dtype=np.float32, order="C")
B = np.ones([32,256], dtype=np.float32, order="C")
D = np.ones([256,32], dtype=np.float32, order="C")
E = np.ones([256,32], dtype=np.float32, order="C")
# call a package function named 'my_func_698b5e5c'
package.my_func_698b5e5c(A, B, D, E)
"""
import sys
import toml
import ctypes
import os
import numpy as np
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Tuple
@dataclass
class ArgInfo:
"""Extracts necessary information from the description of a function argument in a hat file"""
hat_declared_type: str
numpy_shape: Tuple[int]
numpy_strides: Tuple[int]
numpy_dtype: type
element_num_bytes: int
ctypes_pointer_type: Any
def __init__(self, param_description):
self.hat_declared_type = param_description["declared_type"]
self.numpy_shape = tuple(param_description["shape"])
if self.hat_declared_type == "float16_t*":
self.numpy_dtype = np.float16
self.element_num_bytes = 2
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_uint16) # same bitwidth as float16
elif self.hat_declared_type == "float*":
self.numpy_dtype = np.float32
self.element_num_bytes = 4
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_float)
elif self.hat_declared_type == "double*":
self.numpy_dtype = np.float64
self.element_num_bytes = 8
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_double)
elif self.hat_declared_type == "int64_t*":
self.numpy_dtype = np.int64
self.element_num_bytes = 8
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int64)
elif self.hat_declared_type == "int32_t*":
self.numpy_dtype = np.int32
self.element_num_bytes = 4
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int32)
elif self.hat_declared_type == "int16_t*":
self.numpy_dtype = np.int16
self.element_num_bytes = 2
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int16)
elif self.hat_declared_type == "int8_t*":
self.numpy_dtype = np.int8
self.element_num_bytes = 1
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int8)
else:
raise NotImplementedError(f"Unsupported declared_type {self.hat_declared_type} in hat file")
self.numpy_strides = tuple([self.element_num_bytes * x for x in param_description["affine_map"]])
def verify_args(args, arg_infos, function_name):
""" Verifies that a list of arguments matches a list of argument descriptions in a HAT file
"""
# check number of args
if len(args) != len(arg_infos):
sys.exit(f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")
# for each arg
for i in range(len(args)):
arg = args[i]
arg_info = arg_infos[i]
# confirm that the arg is a numpy ndarray
if not isinstance(arg, np.ndarray):
sys.exit("Error calling {function_name}(...): expected argument {i} to be <class 'numpy.ndarray'> but received {type(arg)}")
# confirm that the arg dtype matches the dexcription in the hat package
if arg_info.numpy_dtype != arg.dtype:
sys.exit(f"Error calling {function_name}(...): expected argument {i} to have dtype={arg_info.numpy_dtype} but received dtype={arg.dtype}")
# confirm that the arg shape is correct
if arg_info.numpy_shape != arg.shape:
sys.exit(f"Error calling {function_name}(...): expected argument {i} to have shape={arg_info.numpy_shape} but received shape={arg.shape}")
# confirm that the arg strides are correct
if arg_info.numpy_strides != arg.strides:
sys.exit(f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}")
def hat_description_to_python_function(hat_description, hat_library):
""" Creates a callable function based on a function description in a HAT package
"""
hat_arg_descriptions = hat_description["arguments"]
function_name = hat_description["name"]
def f(*args):
# verify that the (numpy) input args match the description in the hat file
arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
verify_args(args, arg_infos, function_name)
# prepare the args to the hat package
hat_args = [arg.ctypes.data_as(arg_info.ctypes_pointer_type) for arg, arg_info in zip (args, arg_infos)]
# call the function in the hat package
hat_library[function_name](*hat_args)
return f
class AttributeDict(OrderedDict):
""" Dictionary that allows entries to be accessed like attributes
"""
__getattr__ = OrderedDict.__getitem__
@property
def names(self):
return list(self.keys())
def __getitem__(self, key):
for k, v in self.items():
if k.startswith(key):
return v
return OrderedDict.__getitem__(key)
def load(hat_path):
""" Creates a class with static functions based on the function descriptions in a HAT package
"""
# load the function decscriptions from the hat file
hat_path = os.path.abspath(hat_path)
t = toml.load(hat_path)
function_descriptions = t["functions"]
hat_binary_filename = t["dependencies"]["link_target"]
hat_binary_path = os.path.join(os.path.dirname(hat_path), hat_binary_filename)
# check that the HAT library has a supported file extension
supported_extensions = [".dll", ".so"]
_, extension = os.path.splitext(hat_binary_path)
if extension not in supported_extensions:
sys.exit(f"Unsupported HAT library extension: {extension}")
# load the hat_library:
hat_library = ctypes.cdll.LoadLibrary(hat_binary_path)
# create dictionary of functions defined in the hat file
function_dict = AttributeDict({key : hat_description_to_python_function(val, hat_library) for key,val in function_descriptions.items()})
return function_dict

Просмотреть файл

@ -1,79 +0,0 @@
#!/usr/bin/env python3
from pathlib import Path
import unittest
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from hat_file import CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile, OperatingSystem, Parameter, ParameterType, Target, UsageType
class HATFile_test(unittest.TestCase):
def test_file_basic_serialize(self):
# Construct a HAT file from scratch
# Start with a function definition
my_function = Function(name="my_function",
description="Some description",
calling_convention=CallingConventionType.StdCall,
return_info=Parameter(logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"))
# Create the function table
functions = FunctionTable({"my_function" : my_function})
# Create the HATFile object
hat_file1 = HATFile(name="test_file",
description=Description(version="0.0.1",
author="me",
license_url="https://www.apache.org/licenses/LICENSE-2.0.html"),
_function_table=functions,
target=Target(
required=Target.Required(os=OperatingSystem.Windows,
cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
gpu=None),
optimized_for=Target.OptimizedFor()),
dependencies=Dependencies(link_target="my_lib.lib"),
compiled_with=CompiledWith(compiler="VC++"),
declaration=Declaration(),
path=Path(".").resolve())
# Serialize it to disk
test_file_name = "test_file_serialize.hat"
try:
hat_file1.Serialize(test_file_name)
# Deserialize it and verify it has what we expect
hat_file2 = HATFile.Deserialize(test_file_name)
finally:
# Remove the file
os.remove(test_file_name)
# Do basic verification that the deserialized HatFile contains what we specified
# when we created the HATFile directly
self.assertEqual(hat_file1.description, hat_file2.description)
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(os.path.join(os.path.dirname(__file__), "..", "..", "samples", "sample_gemm_library.hat"))
description = {
"author": "John Doe",
"version": "1.2.3.5",
"license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
# Do basic verification of known values in the file
# Verify the description has entries we expect
self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
# Verify the list of functions
self.assertTrue(len(hat_file1.functions) == 2)
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -1,79 +0,0 @@
#!/usr/bin/env python3
from pathlib import Path
import unittest
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from hat_file import CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile, OperatingSystem, Parameter, ParameterType, Target, UsageType
class HATFile_test(unittest.TestCase):
def test_file_basic_serialize(self):
# Construct a HAT file from scratch
# Start with a function definition
my_function = Function(name="my_function",
description="Some description",
calling_convention=CallingConventionType.StdCall,
return_info=Parameter(logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"))
# Create the function table
functions = FunctionTable({"my_function" : my_function})
# Create the HATFile object
hat_file1 = HATFile(name="test_file",
description=Description(version="0.0.1",
author="me",
license_url="https://www.apache.org/licenses/LICENSE-2.0.html"),
_function_table=functions,
target=Target(
required=Target.Required(os=OperatingSystem.Windows,
cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
gpu=None),
optimized_for=Target.OptimizedFor()),
dependencies=Dependencies(link_target="my_lib.lib"),
compiled_with=CompiledWith(compiler="VC++"),
declaration=Declaration(),
path=Path(".").resolve())
# Serialize it to disk
test_file_name = "test_file_serialize.hat"
try:
hat_file1.Serialize(test_file_name)
# Deserialize it and verify it has what we expect
hat_file2 = HATFile.Deserialize(test_file_name)
finally:
# Remove the file
os.remove(test_file_name)
# Do basic verification that the deserialized HatFile contains what we specified
# when we created the HATFile directly
self.assertEqual(hat_file1.description, hat_file2.description)
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(os.path.join(os.path.dirname(__file__), "..", "..", "samples", "sample_gemm_library.hat"))
description = {
"author": "John Doe",
"version": "1.2.3.5",
"license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
# Do basic verification of known values in the file
# Verify the description has entries we expect
self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
# Verify the list of functions
self.assertTrue(len(hat_file1.function_map) == 2)
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
if __name__ == '__main__':
unittest.main()