Refactors and restructures code for tighter integration of HAT object model (#36)

Changes:
* Utility functions in `hatlib` now make use of the HAT object model (OM). Previously, `tomlkit` was being used directly, which bypassed the checks and the OM construction of child objects.
* `hatlib.load` now returns a `HATPackage` object that results from deserializing the HAT file passed in. Additionally, if functions can be loaded, a non-empty `AttributeDict` is returned as a second value.
* Moved test folder to be a sibling of library folder. Now the package no longer contains the tests when installed. Additionally, the library can be tested by running `python -m unittest discover test` in the root of the repository.
* Added a YAPF style config to `setup.cfg` and formatted all touched files with `yapf`.
This commit is contained in:
Kern Handa 2022-03-30 13:50:00 -07:00 коммит произвёл GitHub
Родитель b2ebaf5323
Коммит 544fcc0e01
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
22 изменённых файлов: 656 добавлений и 656 удалений

4
.github/workflows/ci.yml поставляемый
Просмотреть файл

@ -31,8 +31,8 @@ jobs:
python -m pip install -r hatlib/requirements.txt
- name: Unittest
run: |
python -m pip install -r hatlib/test/requirements.txt
python -m unittest discover hatlib/test
python -m pip install -r test/requirements.txt
python -m unittest discover test
- name: Build whl
run: |
python -m pip install build

Просмотреть файл

@ -2,9 +2,9 @@ import ctypes
import numpy as np
import sys
from dataclasses import dataclass
from typing import Any, Tuple
from functools import reduce
from typing import List
from typing import Any, List, Tuple
from . import hat_file
@dataclass
@ -16,17 +16,16 @@ class ArgInfo:
numpy_dtype: type
element_num_bytes: int
ctypes_pointer_type: Any
usage: str = ""
usage: hat_file.UsageType = None
def __init__(self, param_description):
self.hat_declared_type = param_description["declared_type"]
self.numpy_shape = tuple(param_description["shape"])
self.usage = param_description["usage"]
def __init__(self, param_description: hat_file.Parameter):
self.hat_declared_type = param_description.declared_type
self.numpy_shape = tuple(param_description.shape)
self.usage = param_description.usage
if self.hat_declared_type == "float16_t*":
self.numpy_dtype = np.float16
self.element_num_bytes = 2
self.ctypes_pointer_type = ctypes.POINTER(
ctypes.c_uint16) # same bitwidth as float16
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_uint16) # same bitwidth as float16
elif self.hat_declared_type == "float*":
self.numpy_dtype = np.float32
self.element_num_bytes = 4
@ -53,23 +52,18 @@ class ArgInfo:
self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int8)
else:
raise NotImplementedError(
f"Unsupported declared_type {self.hat_declared_type} in hat file"
)
raise NotImplementedError(f"Unsupported declared_type {self.hat_declared_type} in hat file")
self.numpy_strides = tuple([
self.element_num_bytes * x for x in param_description["affine_map"]
])
self.numpy_strides = tuple([self.element_num_bytes * x for x in param_description.affine_map])
def verify_args(args, arg_infos, function_name):
# TODO: Update this to take a HATFunction instead of arg_infos and function_name
def verify_args(args: List, arg_infos: List[ArgInfo], function_name: str):
""" Verifies that a list of arguments matches a list of argument descriptions in a HAT file
"""
# check number of args
if len(args) != len(arg_infos):
sys.exit(
f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}"
)
sys.exit(f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")
# for each arg
for i in range(len(args)):
@ -99,24 +93,3 @@ def verify_args(args, arg_infos, function_name):
sys.exit(
f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}"
)
def generate_input_sets(parameters: List["ArgInfo"],
                        input_sets_minimum_size_MB: int = 0,
                        num_additional: int = 0):
    """Generate randomly filled input sets for a list of argument descriptions.

    Each input set is a list holding one numpy array per entry in
    `parameters`, created with that entry's `numpy_shape` and cast to its
    `numpy_dtype`. Values come from `np.random.random`, so after the cast
    integer dtypes will typically hold zeros.

    Args:
        parameters: argument descriptions; each must expose `numpy_shape`,
            `numpy_dtype` and `element_num_bytes`.
        input_sets_minimum_size_MB: lower bound, in MB, on the combined size
            of the generated sets. Enough sets are generated to reach it.
        num_additional: extra input sets to generate on top of that.

    Returns:
        A single input set (a list of arrays) when exactly one set was
        generated, otherwise a list of input sets.
    """
    # Seed the product with 1 so an empty shape tuple counts as one element
    # instead of making reduce() raise TypeError on an empty sequence.
    element_counts = [reduce(lambda x, y: x * y, p.numpy_shape, 1) for p in parameters]

    # sum() (unlike a seedless reduce) is well defined for an empty parameter list.
    set_size = sum(count * p.element_num_bytes for count, p in zip(element_counts, parameters))

    num_input_sets = 1 + num_additional
    if set_size > 0:
        # A zero-byte set can never reach the requested minimum size, so the
        # division is skipped in that case rather than raising ZeroDivisionError.
        num_input_sets += input_sets_minimum_size_MB * 1024 * 1024 // set_size

    input_sets = [[np.random.random(p.numpy_shape).astype(p.numpy_dtype) for p in parameters]
                  for _ in range(num_input_sets)]

    return input_sets[0] if len(input_sets) == 1 else input_sets

Просмотреть файл

@ -9,14 +9,8 @@ import toml
import traceback
from pathlib import Path
if __package__:
from .hat_file import HATFile
from .hat_to_dynamic import create_dynamic_package
from .hat import load, ArgInfo, generate_input_sets
else:
from hat_file import HATFile
from hat_to_dynamic import create_dynamic_package
from hat import load, ArgInfo, generate_input_sets
from .hat_file import HATFile
from .hat import load, generate_input_sets_for_func
class Benchmark:
@ -26,18 +20,13 @@ class Benchmark:
Requirements:
A compilation toolchain in your PATH: cl.exe & link.exe (Windows), gcc (Linux), or clang (macOS)
"""
def __init__(self, hat_path):
self.hat_path = Path(hat_path)
self.hat_package = load(self.hat_path)
self.hat_functions = self.hat_package.names
def __init__(self, hat_path: str):
self.hat_path = hat_path
self.hat_package, self.func_dict = load(self.hat_path)
self.hat_functions = self.func_dict.names
# create dictionary of function descriptions defined in the hat file
t = toml.load(self.hat_path)
function_descriptions = t["functions"]
self.hat_arg_descriptions = {key: [ArgInfo(
d) for d in val["arguments"]] for key, val in function_descriptions.items()}
self.function_descriptions = self.hat_package.hat_file.function_map
def run(self,
function_name: str,
@ -58,20 +47,22 @@ class Benchmark:
Mean duration in seconds,
Vector of timings in seconds for each batch that was run
"""
if function_name not in self.hat_package.names:
if function_name not in self.hat_functions:
raise ValueError(f"{function_name} is not found")
# TODO: support packing and unpacking functions
mean_elapsed_time, batch_timings = self._profile(
function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB)
function_name, warmup_iterations, min_timing_iterations,
min_time_in_sec, input_sets_minimum_size_MB)
print(
f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s")
f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s"
)
return mean_elapsed_time, batch_timings
def _profile(self, function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB):
def _profile(self, function_name, warmup_iterations, min_timing_iterations,
min_time_in_sec, input_sets_minimum_size_MB):
def get_perf_counter():
if hasattr(time, 'perf_counter_ns'):
perf_counter = time.perf_counter_ns
@ -81,19 +72,21 @@ class Benchmark:
perf_counter_scale = 1
return perf_counter, perf_counter_scale
parameters = self.hat_arg_descriptions[function_name]
func = self.function_descriptions[function_name]
# generate sufficient input sets to overflow the L3 cache, since we don't know the size of the model
# we'll make a guess based on the minimum input set size
input_sets = generate_input_sets(
parameters, input_sets_minimum_size_MB, num_additional=10)
input_sets = generate_input_sets_for_func(func,
input_sets_minimum_size_MB,
num_additional=10)
set_size = 0
for i in input_sets[0]:
set_size += i.size * i.dtype.itemsize
print(
f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes")
f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes"
)
perf_counter, perf_counter_scale = get_perf_counter()
print(
@ -101,10 +94,11 @@ class Benchmark:
for _ in range(warmup_iterations):
for calling_args in input_sets:
self.hat_package[function_name](*calling_args)
self.func_dict[function_name](*calling_args)
print(
f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations...")
f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations..."
)
start_time = perf_counter()
end_time = perf_counter()
@ -115,7 +109,7 @@ class Benchmark:
while ((end_time - start_time) / perf_counter_scale) < min_time_in_sec:
batch_start_time = perf_counter()
for _ in range(min_timing_iterations):
self.hat_package[function_name](*input_sets[i])
self.func_dict[function_name](*input_sets[i])
i = iterations % i_max
iterations += 1
end_time = perf_counter()
@ -140,13 +134,20 @@ def write_runtime_to_hat_file(hat_path, function_name, mean_time_secs):
# Workaround to remove extra empty lines
with open(hat_path, "r") as f:
lines = f.readlines()
lines = [lines[i] for i in range(len(lines)) if not(lines[i] == "\n"
and i < len(lines)-1 and lines[i+1] == "\n")]
lines = [
lines[i] for i in range(len(lines))
if not (lines[i] == "\n" and i < len(lines) -
1 and lines[i + 1] == "\n")
]
with open(hat_path, "w") as f:
f.writelines(lines)
def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=10, input_sets_minimum_size_MB=50):
def run_benchmark(hat_path,
store_in_hat=False,
batch_size=10,
min_time_in_sec=10,
input_sets_minimum_size_MB=50):
results = []
benchmark = Benchmark(hat_path)
@ -157,71 +158,89 @@ def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=1
continue
try:
_, batch_timings = benchmark.run(function_name,
warmup_iterations=batch_size,
min_timing_iterations=batch_size,
min_time_in_sec=min_time_in_sec,
input_sets_minimum_size_MB=input_sets_minimum_size_MB)
_, batch_timings = benchmark.run(
function_name,
warmup_iterations=batch_size,
min_timing_iterations=batch_size,
min_time_in_sec=min_time_in_sec,
input_sets_minimum_size_MB=input_sets_minimum_size_MB)
sorted_batch_means = np.array(sorted(batch_timings)) / batch_size
num_batches = len(batch_timings)
mean_of_means = sorted_batch_means.mean()
median_of_means = sorted_batch_means[num_batches//2]
mean_of_small_means = sorted_batch_means[0: num_batches//2].mean()
robust_mean_of_means = sorted_batch_means[num_batches //
5: -num_batches//5].mean()
median_of_means = sorted_batch_means[num_batches // 2]
mean_of_small_means = sorted_batch_means[0:num_batches // 2].mean()
robust_means = sorted_batch_means[(num_batches //
5):(-num_batches // 5)]
robust_mean_of_means = robust_means.mean()
min_of_means = sorted_batch_means[0]
if store_in_hat:
write_runtime_to_hat_file(
hat_path, function_name, mean_of_means)
results.append({"function_name": function_name,
"mean": mean_of_means,
"median_of_means": median_of_means,
"mean_of_small_means": mean_of_small_means,
"robust_mean": robust_mean_of_means,
"min_of_means": min_of_means,
})
write_runtime_to_hat_file(hat_path, function_name,
mean_of_means)
results.append({
"function_name": function_name,
"mean": mean_of_means,
"median_of_means": median_of_means,
"mean_of_small_means": mean_of_small_means,
"robust_mean": robust_mean_of_means,
"min_of_means": min_of_means,
})
except Exception as e:
exc_type, exc_val, exc_tb = sys.exc_info()
traceback.print_exception(
exc_type, exc_val, exc_tb, file=sys.stderr)
traceback.print_exception(exc_type,
exc_val,
exc_tb,
file=sys.stderr)
print("\nException message: ", e)
print(
f"WARNING: Failed to run function {function_name}, skipping this benchmark.")
f"WARNING: Failed to run function {function_name}, skipping this benchmark."
)
return results
def main(argv):
arg_parser = argparse.ArgumentParser(
description="Benchmarks each function in a HAT package and estimates its duration.\n"
description=
"Benchmarks each function in a HAT package and estimates its duration.\n"
"Example:\n"
" hatlib.benchmark_hat_package <hat_path>\n")
arg_parser.add_argument("hat_path",
help="Path to the HAT file",
default=None)
arg_parser.add_argument("--store_in_hat",
help="If set, will write the duration as meta-data back into the hat file",
action='store_true')
arg_parser.add_argument(
"--store_in_hat",
help=
"If set, will write the duration as meta-data back into the hat file",
action='store_true')
arg_parser.add_argument("--results_file",
help="Full path where the results will be written",
default="results.csv")
arg_parser.add_argument("--batch_size",
help="The number of function calls in each batch (at least one full batch is executed)",
default=10)
arg_parser.add_argument("--min_time_in_sec",
help="Minimum number of seconds to run the benchmark for",
default=30)
arg_parser.add_argument("--input_sets_minimum_size_MB",
help="Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on an desktop CPU)",
default=50)
arg_parser.add_argument(
"--batch_size",
help=
"The number of function calls in each batch (at least one full batch is executed)",
default=10)
arg_parser.add_argument(
"--min_time_in_sec",
help="Minimum number of seconds to run the benchmark for",
default=30)
arg_parser.add_argument(
"--input_sets_minimum_size_MB",
help=
"Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on an desktop CPU)",
default=50)
args = vars(arg_parser.parse_args(argv))
results = run_benchmark(args["hat_path"], args["store_in_hat"], batch_size=int(args["batch_size"]), min_time_in_sec=int(
args["min_time_in_sec"]), input_sets_minimum_size_MB=int(args["input_sets_minimum_size_MB"]))
results = run_benchmark(args["hat_path"],
args["store_in_hat"],
batch_size=int(args["batch_size"]),
min_time_in_sec=int(args["min_time_in_sec"]),
input_sets_minimum_size_MB=int(
args["input_sets_minimum_size_MB"]))
df = pd.DataFrame(results)
df.to_csv(args["results_file"], index=False)
pd.options.display.float_format = '{:8.8f}'.format

Просмотреть файл

@ -10,19 +10,15 @@ from typing import List
from pynvrtc.compiler import Program
from cuda import cuda, nvrtc
try:
from .arg_info import ArgInfo, verify_args
from .gpu_headers import CUDA_HEADER_MAP
except:
from arg_info import ArgInfo, verify_args
from gpu_headers import CUDA_HEADER_MAP
from .arg_info import ArgInfo, verify_args
from .gpu_headers import CUDA_HEADER_MAP
from .hat_file import Function
def ASSERT_DRV(err):
if isinstance(err, cuda.CUresult):
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(
cuda.cuGetErrorString(err)[1].decode('utf-8')))
raise RuntimeError("Cuda Error: {}".format(cuda.cuGetErrorString(err)[1].decode('utf-8')))
elif isinstance(err, nvrtc.nvrtcResult):
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("Nvrtc Error: {}".format(err))
@ -51,15 +47,12 @@ def _find_cuda_incl_path() -> pathlib.Path:
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name):
src = cuda_src_path.read_text()
prog = Program(src=src,
name=func_name,
headers=CUDA_HEADER_MAP.values(),
include_names=CUDA_HEADER_MAP.keys())
prog = Program(src=src, name=func_name, headers=CUDA_HEADER_MAP.values(), include_names=CUDA_HEADER_MAP.keys())
ptx = prog.compile([
'-use_fast_math',
'-default-device',
'-std=c++11',
'-arch=sm_52', # TODO: is this needed?
'-arch=sm_52', # TODO: is this needed?
])
return ptx
@ -91,27 +84,20 @@ def get_func_from_ptx(ptx, func_name):
def _arg_size(arg_info: ArgInfo):
return arg_info.element_num_bytes * reduce(lambda x, y: x * y,
arg_info.numpy_shape)
return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
def transfer_mem_host_to_cuda(device_args: List, host_args: List[np.array],
arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args,
arg_infos):
if 'input' in arg_info.usage:
err, = cuda.cuMemcpyHtoD(device_arg, host_arg.ctypes.data,
_arg_size(arg_info))
def transfer_mem_host_to_cuda(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'input' in arg_info.usage.value:
err, = cuda.cuMemcpyHtoD(device_arg, host_arg.ctypes.data, _arg_size(arg_info))
ASSERT_DRV(err)
def transfer_mem_cuda_to_host(device_args: List, host_args: List[np.array],
arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args,
arg_infos):
if 'output' in arg_info.usage:
err, = cuda.cuMemcpyDtoH(host_arg.ctypes.data, device_arg,
_arg_size(arg_info))
def transfer_mem_cuda_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'output' in arg_info.usage.value:
err, = cuda.cuMemcpyDtoH(host_arg.ctypes.data, device_arg, _arg_size(arg_info))
ASSERT_DRV(err)
@ -133,10 +119,12 @@ def device_args_to_ptr_list(device_args: List):
return ptrs
def create_loader_for_device_function(device_func, hat_details):
hat_path: pathlib.Path = hat_details.path
cuda_src_path: pathlib.Path = hat_path.parent / device_func["provider"]
func_name = device_func["name"]
def create_loader_for_device_function(device_func: Function, hat_dir_path: str):
if not device_func.provider:
raise RuntimeError("Expected a provider for the device function")
cuda_src_path: pathlib.Path = pathlib.Path(hat_dir_path) / device_func.provider
func_name = device_func.name
ptx = compile_cuda_program(cuda_src_path, func_name)
@ -144,16 +132,14 @@ def create_loader_for_device_function(device_func, hat_details):
kernel = get_func_from_ptx(ptx, func_name)
hat_arg_descriptions = device_func["arguments"]
hat_arg_descriptions = device_func.arguments
arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
launch_parameters = device_func["launch_parameters"]
launch_parameters = device_func.launch_parameters
def f(*args):
verify_args(args, arg_infos, func_name)
device_mem = allocate_cuda_mem(arg_infos)
transfer_mem_host_to_cuda(device_args=device_mem,
host_args=args,
arg_infos=arg_infos)
transfer_mem_host_to_cuda(device_args=device_mem, host_args=args, arg_infos=arg_infos)
ptrs = device_args_to_ptr_list(device_mem)
err, stream = cuda.cuStreamCreate(0)
@ -161,18 +147,16 @@ def create_loader_for_device_function(device_func, hat_details):
err, = cuda.cuLaunchKernel(
kernel,
*launch_parameters, # [ grid[x-z], block[x-z] ]
0, # dynamic shared memory
stream, # stream
ptrs.ctypes.data, # kernel arguments
0, # extra (ignore)
*launch_parameters, # [ grid[x-z], block[x-z] ]
0, # dynamic shared memory
stream, # stream
ptrs.ctypes.data, # kernel arguments
0, # extra (ignore)
)
ASSERT_DRV(err)
err, = cuda.cuStreamSynchronize(stream)
ASSERT_DRV(err)
transfer_mem_cuda_to_host(device_args=device_mem,
host_args=args,
arg_infos=arg_infos)
transfer_mem_cuda_to_host(device_args=device_mem, host_args=args, arg_infos=arg_infos)
return f

Просмотреть файл

@ -24,161 +24,66 @@ For example:
# call a package function named 'my_func_698b5e5c'
package.my_func_698b5e5c(A, B, D, E)
"""
import numpy as np
from typing import Tuple, Union
from functools import reduce
import ctypes
import pathlib
import sys
import toml
from collections import OrderedDict
from functools import partial
from . import hat_file
from . import hat_package
from .arg_info import ArgInfo
try:
from . import hat_file
from .arg_info import ArgInfo, verify_args, generate_input_sets
except:
import hat_file
from arg_info import ArgInfo, verify_args, generate_input_sets
try:
try:
from . import cuda_loader
except ModuleNotFoundError:
import cuda_loader
except:
CUDA_AVAILABLE = False
else:
CUDA_AVAILABLE = True
def generate_input_sets_for_func(func: hat_file.Function,
input_sets_minimum_size_MB: int = 0,
num_additional: int = 0):
parameters = list(map(ArgInfo, func.arguments))
shapes_to_sizes = [
reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters
]
set_size = reduce(
lambda x, y: x + y,
map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes,
parameters))
try:
try:
from . import rocm_loader
except ModuleNotFoundError:
import rocm_loader
except:
ROCM_AVAILABLE = False
else:
ROCM_AVAILABLE = True
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 //
set_size) + 1 + num_additional
input_sets = [[
np.random.random(p.numpy_shape).astype(p.numpy_dtype)
for p in parameters
] for _ in range(num_input_sets)]
NOTIFY_ABOUT_CUDA = not CUDA_AVAILABLE
NOTIFY_ABOUT_ROCM = not ROCM_AVAILABLE
return input_sets[0] if len(input_sets) == 1 else input_sets
def generate_input_sets_for_hat_file(hat_path):
hat_path = pathlib.Path(hat_path).absolute()
t: hat_file.HATFile = toml.load(hat_path)
t = hat_file.HATFile.Deserialize(hat_path)
return {
func_name:
generate_input_sets(list(map(ArgInfo, func_desc["arguments"])))
for func_name, func_desc in t["functions"].items()
func_name: generate_input_sets_for_func(func_desc)
for func_name, func_desc in t.function_map.items()
}
class AttributeDict(OrderedDict):
""" Dictionary that allows entries to be accessed like attributes
def load(
hat_path,
try_dynamic_load=True
) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
"""
__getattr__ = OrderedDict.__getitem__
@property
def names(self):
return list(self.keys())
def __getitem__(self, key):
for k, v in self.items():
if k.startswith(key):
return v
return OrderedDict.__getitem__(key)
def hat_description_to_python_function(hat_description: hat_file.HATFile,
hat_details: AttributeDict):
""" Creates a callable function based on a function description in a HAT
package
Returns a HATPackage object loaded from the path provided. If
`try_dynamic_load` is True, a non-empty dictionary object that can be used
to invoke the functions in the HATPackage on the current system is the
second returned object, `None` otherwise.
"""
for func_name, func_desc in hat_description["functions"].items():
pkg = hat_package.HATPackage(hat_file_path=hat_path)
func_desc: hat_file.Function
func_name: str
# TODO: Add heuristics to determine whether loading is possible on this system
function_dict = None
launches = func_desc.get("launches")
if not launches:
hat_library: ctypes.CDLL = hat_details.shared_lib
if try_dynamic_load:
try:
function_dict = hat_package.hat_package_to_func_dict(pkg)
except:
# TODO: Figure out how to communicate failure better
pass
def f(function_name, hat_arg_descriptions, *args):
# verify that the (numpy) input args match the description in
# the hat file
arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
verify_args(args, arg_infos, function_name)
# prepare the args to the hat package
hat_args = [
arg.ctypes.data_as(arg_info.ctypes_pointer_type)
for arg, arg_info in zip(args, arg_infos)
]
# call the function in the hat package
hat_library[function_name](*hat_args)
yield func_name, partial(f, func_desc["name"], func_desc["arguments"])
else:
device_func = hat_description.get("device_functions",
{}).get(launches)
func_runtime = func_desc.get("runtime")
if not device_func:
raise RuntimeError(
f"Couldn't find device function for loader: " + launches)
if not func_runtime:
raise RuntimeError(f"Couldn't find runtime for loader: " +
launches)
if func_runtime == "CUDA":
global NOTIFY_ABOUT_CUDA
if CUDA_AVAILABLE:
yield (func_name,
cuda_loader.create_loader_for_device_function(
device_func, hat_details))
elif NOTIFY_ABOUT_CUDA:
print("CUDA functionality not available on this machine."
" Please install the cuda and pvnrtc python modules")
NOTIFY_ABOUT_CUDA = False
elif func_runtime == "ROCM":
global NOTIFY_ABOUT_ROCM
if ROCM_AVAILABLE:
yield (func_name,
rocm_loader.create_loader_for_device_function(
device_func, hat_details))
elif NOTIFY_ABOUT_ROCM:
print("ROCm functionality not available on this machine."
" Please install the ROCm 4.2 or higher")
NOTIFY_ABOUT_ROCM = False
def load(hat_path):
""" Creates a class with static functions based on the function
descriptions in a HAT package
"""
# load the function decscriptions from the hat file
hat_path = pathlib.Path(hat_path).absolute()
t: hat_file.HATFile = toml.load(hat_path)
hat_details = AttributeDict({"path": hat_path})
# function_descriptions = t["functions"]
hat_binary_filename = t["dependencies"]["link_target"]
hat_binary_path = hat_path.parent / hat_binary_filename
# check that the HAT library has a supported file extension
supported_extensions = [".dll", ".so"]
extension = hat_binary_path.suffix
if extension and extension not in supported_extensions:
sys.exit(f"Unsupported HAT library extension: {extension}")
# load the hat_library:
hat_library = ctypes.cdll.LoadLibrary(
str(hat_binary_path)) if extension else None
hat_details["shared_lib"] = hat_library
# create dictionary of functions defined in the hat file
function_dict = AttributeDict(
dict(hat_description_to_python_function(t, hat_details)))
return function_dict
return pkg, function_dict

Просмотреть файл

@ -1,12 +1,12 @@
#!/usr/bin/env python3
# Utility to parse the TOML metadata from HAT files
from enum import Enum
from dataclasses import dataclass, field
import os
from pathlib import Path
import tomlkit
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List
# TODO : type-checking on leaf node values
@ -105,7 +105,8 @@ class Description(AuxiliarySupportedTable):
author=table["author"],
version=table["version"],
license_url=table["license_url"],
auxiliary=AuxiliarySupportedTable.parse_auxiliary(table))
auxiliary=AuxiliarySupportedTable.parse_auxiliary(table)
)
@dataclass
@ -147,14 +148,9 @@ class Parameter:
# TODO : change "usage" to "role" in schema
@staticmethod
def parse_from_table(param_table):
required_table_entries = [
"name", "description", "logical_type", "declared_type",
"element_type", "usage"
]
required_table_entries = ["name", "description", "logical_type", "declared_type", "element_type", "usage"]
_check_required_table_entries(param_table, required_table_entries)
affine_array_required_table_entries = [
"shape", "affine_map", "affine_offset"
]
affine_array_required_table_entries = ["shape", "affine_map", "affine_offset"]
runtime_array_required_table_entries = ["size"]
name = param_table["name"]
@ -164,23 +160,23 @@ class Parameter:
element_type = param_table["element_type"]
usage = UsageType(param_table["usage"])
param = Parameter(name=name,
description=description,
logical_type=logical_type,
declared_type=declared_type,
element_type=element_type,
usage=usage)
param = Parameter(
name=name,
description=description,
logical_type=logical_type,
declared_type=declared_type,
element_type=element_type,
usage=usage
)
if logical_type == ParameterType.AffineArray:
_check_required_table_entries(param_table,
affine_array_required_table_entries)
_check_required_table_entries(param_table, affine_array_required_table_entries)
param.shape = param_table["shape"]
param.affine_map = param_table["affine_map"]
param.affine_offset = param_table["affine_offset"]
elif logical_type == ParameterType.RuntimeArray:
_check_required_table_entries(
param_table, runtime_array_required_table_entries)
_check_required_table_entries(param_table, runtime_array_required_table_entries)
param.size = param_table["size"]
return param
@ -189,7 +185,7 @@ class Parameter:
@dataclass
class Function(AuxiliarySupportedTable):
# required
arguments: list = field(default_factory=list)
arguments: List[Parameter] = field(default_factory=list)
calling_convention: CallingConventionType = None
description: str = ""
hat_file: any = None
@ -214,7 +210,7 @@ class Function(AuxiliarySupportedTable):
arg_array.append(arg_table)
table.add(
"arguments", arg_array
) # TODO : figure out why this isn't indenting after serialization in some cases
) # TODO : figure out why this isn't indenting after serialization in some cases
if self.launch_parameters:
table.add("launch_parameters", self.launch_parameters)
@ -236,44 +232,36 @@ class Function(AuxiliarySupportedTable):
@staticmethod
def parse_from_table(function_table):
required_table_entries = [
"name", "description", "calling_convention", "arguments", "return"
]
required_table_entries = ["name", "description", "calling_convention", "arguments", "return"]
_check_required_table_entries(function_table, required_table_entries)
arguments = [
Parameter.parse_from_table(param_table)
for param_table in function_table["arguments"]
]
arguments = [Parameter.parse_from_table(param_table) for param_table in function_table["arguments"]]
launch_parameters = function_table[
"launch_parameters"] if "launch_parameters" in function_table else []
launch_parameters = function_table["launch_parameters"] if "launch_parameters" in function_table else []
launches = function_table[
"launches"] if "launches" in function_table else ""
launches = function_table["launches"] if "launches" in function_table else ""
provider = function_table[
"provider"] if "provider" in function_table else ""
provider = function_table["provider"] if "provider" in function_table else ""
runtime = function_table[
"runtime"] if "runtime" in function_table else ""
runtime = function_table["runtime"] if "runtime" in function_table else ""
return_info = Parameter.parse_from_table(function_table["return"])
return Function(
name=function_table["name"],
description=function_table["description"],
calling_convention=CallingConventionType(
function_table["calling_convention"]),
calling_convention=CallingConventionType(function_table["calling_convention"]),
arguments=arguments,
return_info=return_info,
launch_parameters=launch_parameters,
launches=launches,
provider=provider,
runtime=runtime,
auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table))
auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table)
)
class FunctionTableCommon:
def __init__(self, function_map):
self.function_map = function_map
self.functions = self.function_map.values()
@ -281,15 +269,13 @@ class FunctionTableCommon:
def to_table(self):
func_table = tomlkit.table()
for function_key in self.function_map:
func_table.add(function_key,
self.function_map[function_key].to_table())
func_table.add(function_key, self.function_map[function_key].to_table())
return func_table
@classmethod
def parse_from_table(cls, all_functions_table):
function_map = {
function_key:
Function.parse_from_table(all_functions_table[function_key])
function_key: Function.parse_from_table(all_functions_table[function_key])
for function_key in all_functions_table
}
return cls(function_map)
@ -305,8 +291,10 @@ class DeviceFunctionTable(FunctionTableCommon):
@dataclass
class Target:
@dataclass
class Required:
@dataclass
class CPU:
TableName = TargetType.CPU.value
@ -336,9 +324,8 @@ class Target:
runtime = table.get("runtime", "")
return Target.Required.CPU(
architecture=table["architecture"],
extensions=table["extensions"],
runtime=runtime)
architecture=table["architecture"], extensions=table["extensions"], runtime=runtime
)
@dataclass
class GPU:
@ -357,8 +344,7 @@ class Target:
table.add("model", self.model)
table.add("runtime", self.runtime)
table.add("blocks", self.blocks)
table.add("instruction_set_version",
self.instruction_set_version)
table.add("instruction_set_version", self.instruction_set_version)
table.add("min_threads", self.min_threads)
table.add("min_global_memory_KB", self.min_global_memory_KB)
table.add("min_shared_memory_KB", self.min_shared_memory_KB)
@ -382,7 +368,8 @@ class Target:
min_threads=table["min_threads"],
min_global_memory_KB=table["min_global_memory_KB"],
min_shared_memory_KB=table["min_shared_memory_KB"],
min_texture_memory_KB=table["min_texture_memory_KB"])
min_texture_memory_KB=table["min_texture_memory_KB"]
)
TableName = "required"
os: OperatingSystem = None
@ -401,11 +388,9 @@ class Target:
def parse_from_table(table):
required_table_entries = ["os", Target.Required.CPU.TableName]
_check_required_table_entries(table, required_table_entries)
cpu_info = Target.Required.CPU.parse_from_table(
table[Target.Required.CPU.TableName])
cpu_info = Target.Required.CPU.parse_from_table(table[Target.Required.CPU.TableName])
if Target.Required.GPU.TableName in table:
gpu_info = Target.Required.GPU.parse_from_table(
table[Target.Required.GPU.TableName])
gpu_info = Target.Required.GPU.parse_from_table(table[Target.Required.GPU.TableName])
else:
gpu_info = Target.Required.GPU()
return Target.Required(os=table["os"], cpu=cpu_info, gpu=gpu_info)
@ -429,19 +414,16 @@ class Target:
table = tomlkit.table()
table.add(Target.Required.TableName, self.required.to_table())
if self.optimized_for is not None:
table.add(Target.OptimizedFor.TableName,
self.optimized_for.to_table())
table.add(Target.OptimizedFor.TableName, self.optimized_for.to_table())
return table
@staticmethod
def parse_from_table(target_table):
required_table_entries = [Target.Required.TableName]
_check_required_table_entries(target_table, required_table_entries)
required_data = Target.Required.parse_from_table(
target_table[Target.Required.TableName])
required_data = Target.Required.parse_from_table(target_table[Target.Required.TableName])
if Target.OptimizedFor.TableName in target_table:
optimized_for_data = Target.OptimizedFor.parse_from_table(
target_table[Target.OptimizedFor.TableName])
optimized_for_data = Target.OptimizedFor.parse_from_table(target_table[Target.OptimizedFor.TableName])
else:
optimized_for_data = Target.OptimizedFor()
return Target(required=required_data, optimized_for=optimized_for_data)
@ -462,9 +444,7 @@ class LibraryReference:
@staticmethod
def parse_from_table(table):
return LibraryReference(name=table["name"],
version=table["version"],
target_file=table["target_file"])
return LibraryReference(name=table["name"], version=table["version"], target_file=table["target_file"])
@dataclass
@ -490,17 +470,14 @@ class Dependencies(AuxiliarySupportedTable):
@staticmethod
def parse_from_table(dependencies_table):
required_table_entries = ["link_target", "deploy_files", "dynamic"]
_check_required_table_entries(dependencies_table,
required_table_entries)
dynamic = [
LibraryReference.parse_from_table(lib_ref_table)
for lib_ref_table in dependencies_table["dynamic"]
]
return Dependencies(link_target=dependencies_table["link_target"],
deploy_files=dependencies_table["deploy_files"],
dynamic=dynamic,
auxiliary=AuxiliarySupportedTable.parse_auxiliary(
dependencies_table))
_check_required_table_entries(dependencies_table, required_table_entries)
dynamic = [LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in dependencies_table["dynamic"]]
return Dependencies(
link_target=dependencies_table["link_target"],
deploy_files=dependencies_table["deploy_files"],
dynamic=dynamic,
auxiliary=AuxiliarySupportedTable.parse_auxiliary(dependencies_table)
)
@dataclass
@ -527,16 +504,16 @@ class CompiledWith:
@staticmethod
def parse_from_table(compiled_with_table):
required_table_entries = ["compiler", "flags", "crt", "libraries"]
_check_required_table_entries(compiled_with_table,
required_table_entries)
_check_required_table_entries(compiled_with_table, required_table_entries)
libraries = [
LibraryReference.parse_from_table(lib_ref_table)
for lib_ref_table in compiled_with_table["libraries"]
LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in compiled_with_table["libraries"]
]
return CompiledWith(compiler=compiled_with_table["compiler"],
flags=compiled_with_table["flags"],
crt=compiled_with_table["crt"],
libraries=libraries)
return CompiledWith(
compiler=compiled_with_table["compiler"],
flags=compiled_with_table["flags"],
crt=compiled_with_table["crt"],
libraries=libraries
)
@dataclass
@ -552,8 +529,7 @@ class Declaration:
@staticmethod
def parse_from_table(declaration_table):
required_table_entries = ["code"]
_check_required_table_entries(declaration_table,
required_table_entries)
_check_required_table_entries(declaration_table, required_table_entries)
return Declaration(code=declaration_table["code"])
@ -573,8 +549,8 @@ class HATFile:
_device_function_table: DeviceFunctionTable = None
functions: list = field(default_factory=list)
device_functions: list = field(default_factory=list)
function_map: dict = field(default_factory=dict)
device_function_map: list = field(default_factory=list)
function_map: Dict[str, Function] = field(default_factory=dict)
device_function_map: Dict[str, Function] = field(default_factory=dict)
target: Target = None
dependencies: Dependencies = None
compiled_with: CompiledWith = None
@ -589,8 +565,7 @@ class HATFile:
self.function_map = self._function_table.function_map
for func in self.functions:
func.hat_file = self
func.link_target = Path(self.path).resolve(
).parent / self.dependencies.link_target
func.link_target = Path(self.path).resolve().parent / self.dependencies.link_target
if not self._device_function_table:
self._device_function_table = DeviceFunctionTable({})
@ -604,11 +579,9 @@ class HATFile:
filepath = self.path
root_table = tomlkit.table()
root_table.add(Description.TableName, self.description.to_table())
root_table.add(FunctionTable.TableName,
self._function_table.to_table())
root_table.add(FunctionTable.TableName, self._function_table.to_table())
if self.device_function_map:
root_table.add(DeviceFunctionTable.TableName,
self._device_function_table.to_table())
root_table.add(DeviceFunctionTable.TableName, self._device_function_table.to_table())
root_table.add(Target.TableName, self.target.to_table())
root_table.add(Dependencies.TableName, self.dependencies.to_table())
root_table.add(CompiledWith.TableName, self.compiled_with.to_table())
@ -626,29 +599,22 @@ class HATFile:
hat_toml = _read_toml_file(filepath)
name = os.path.splitext(os.path.basename(filepath))[0]
required_entries = [
Description.TableName, FunctionTable.TableName, Target.TableName,
Dependencies.TableName, CompiledWith.TableName,
Declaration.TableName
Description.TableName, FunctionTable.TableName, Target.TableName, Dependencies.TableName,
CompiledWith.TableName, Declaration.TableName
]
_check_required_table_entries(hat_toml, required_entries)
device_function_table = None
if DeviceFunctionTable.TableName in hat_toml:
device_function_table = DeviceFunctionTable.parse_from_table(
hat_toml[DeviceFunctionTable.TableName])
device_function_table = DeviceFunctionTable.parse_from_table(hat_toml[DeviceFunctionTable.TableName])
hat_file = HATFile(
name=name,
description=Description.parse_from_table(
hat_toml[Description.TableName]),
_function_table=FunctionTable.parse_from_table(
hat_toml[FunctionTable.TableName]),
description=Description.parse_from_table(hat_toml[Description.TableName]),
_function_table=FunctionTable.parse_from_table(hat_toml[FunctionTable.TableName]),
_device_function_table=device_function_table,
target=Target.parse_from_table(
hat_toml[Target.TableName]),
dependencies=Dependencies.parse_from_table(
hat_toml[Dependencies.TableName]),
compiled_with=CompiledWith.parse_from_table(
hat_toml[CompiledWith.TableName]),
declaration=Declaration.parse_from_table(
hat_toml[Declaration.TableName]),
path=Path(filepath).resolve())
target=Target.parse_from_table(hat_toml[Target.TableName]),
dependencies=Dependencies.parse_from_table(hat_toml[Dependencies.TableName]),
compiled_with=CompiledWith.parse_from_table(hat_toml[CompiledWith.TableName]),
declaration=Declaration.parse_from_table(hat_toml[Declaration.TableName]),
path=Path(filepath).resolve()
)
return hat_file

Просмотреть файл

@ -2,11 +2,19 @@
# Utility to parse and validate a HAT package
from .hat_file import HATFile
import ctypes
from typing import List, Union
from collections import OrderedDict
from functools import partial
from .hat_file import HATFile, Function, Parameter
from .arg_info import ArgInfo, verify_args
import os
class HATPackage:
def __init__(self, hat_file_path):
"""A HAT Package is defined to be a HAT file and corresponding binary file, located in the same directory.
The binary file is specified in the HAT file's link_target attribute.
@ -18,16 +26,18 @@ class HATPackage:
self.hat_file = HATFile.Deserialize(hat_file_path)
self.link_target = self.hat_file.dependencies.link_target
self.link_target_path = os.path.join(os.path.split(self.hat_file_path)[0], self.hat_file.dependencies.link_target)
if not os.path.isfile(self.link_target_path):
raise ValueError(f"HAT file {self.hat_file_path} references link_target {self.hat_file.dependencies.link_target} which is not found in same directory as HAT file (expecting it to be in {os.path.split(self.hat_file_path)[0]}")
self.link_target_path = os.path.join(
os.path.split(self.hat_file_path)[0], self.hat_file.dependencies.link_target
)
self.functions = self.hat_file.functions
def get_functions(self):
return self.hat_file.functions
def get_functions_for_target(self, os: str, arch: str, required_extensions:list = []):
def get_functions_for_target(self, os: str, arch: str, required_extensions: list = []):
all_functions = self.get_functions()
def matches_target(hat_function):
hat_file = hat_function.hat_file
if hat_file.target.required.os != os or hat_file.target.required.cpu.architecture != arch:
@ -38,3 +48,140 @@ class HATPackage:
return True
return list(filter(matches_target, all_functions))
class AttributeDict(OrderedDict):
    """Ordered dictionary whose entries can also be read like attributes.

    Item lookup (``d[key]``) is lenient: an exact key match wins; otherwise the
    value of the first entry (in insertion order) whose key starts with ``key``
    is returned. Attribute lookup (``d.name``) is bound directly to
    ``OrderedDict.__getitem__`` and therefore only matches exact keys.
    """
    # Only consulted when regular attribute lookup fails. Note that an unknown
    # name surfaces as KeyError (not AttributeError) because the dictionary
    # lookup is used as-is.
    __getattr__ = OrderedDict.__getitem__

    @property
    def names(self):
        """Return the keys as a list, in insertion order."""
        return list(self.keys())

    def __getitem__(self, key):
        """Return the value stored under ``key`` or under the first key it prefixes.

        Raises:
            KeyError: If ``key`` is neither a key nor a prefix of any key.
        """
        # An exact match must take precedence; otherwise an earlier entry that
        # merely shares the prefix would shadow the requested key.
        if key in self:
            return OrderedDict.__getitem__(self, key)

        for k, v in self.items():
            if k.startswith(key):
                return v

        # Delegate to the base class so a miss raises the usual KeyError.
        # (The unbound call needs ``self``; omitting it raised TypeError.)
        return OrderedDict.__getitem__(self, key)
def _make_cpu_func(shared_lib: ctypes.CDLL, function_name: str, arg_infos: List[Parameter]):
    """Wrap the symbol ``function_name`` of ``shared_lib`` in a Python callable.

    Args:
        shared_lib: Loaded library that is indexed by ``function_name``.
        function_name: Name of the symbol to look up and call.
        arg_infos: Parameter descriptions from the HAT file; each one is
            converted to an ``ArgInfo``.

    Returns:
        A function taking the call arguments positionally. It verifies them
        against the parameter descriptions and then invokes the native symbol.
    """
    parsed_infos = [ArgInfo(param) for param in arg_infos]
    native_fn = shared_lib[function_name]

    def f(*args):
        # Check that the (numpy) inputs match the description in the HAT file.
        verify_args(args, parsed_infos, function_name)

        # Convert every argument into the pointer type its description asks for.
        pointers = []
        for value, info in zip(args, parsed_infos):
            pointers.append(value.ctypes.data_as(info.ctypes_pointer_type))

        # Invoke the function inside the HAT package binary.
        native_fn(*pointers)

    return f
def _make_device_func(func_runtime: str, hat_dir_path: str, func: Function):
    """Create a callable for a device function using the loader of its runtime.

    Args:
        func_runtime: Runtime name; ``"CUDA"`` and ``"ROCM"`` are recognized.
        hat_dir_path: Directory passed on to the loader together with ``func``.
        func: Description of the device function.

    Returns:
        Whatever the selected loader's ``create_loader_for_device_function``
        returns, or ``None`` when ``func_runtime`` is not a recognized runtime.
    """
    # The loader modules are imported lazily so that only the backend actually
    # requested has to be importable.
    if func_runtime == "CUDA":
        from . import cuda_loader as loader
    elif func_runtime == "ROCM":
        from . import rocm_loader as loader
    else:
        return None

    return loader.create_loader_for_device_function(func, hat_dir_path)
def _load_pkg_binary_module(hat_pkg: HATPackage):
    """Load the binary referenced by ``hat_pkg.link_target_path`` via ctypes.

    Returns:
        The loaded library, or ``None`` when the link target is not an existing
        file or its name has no extension.

    Raises:
        RuntimeError: If the link target has an extension other than ``.dll``
            or ``.so``.
    """
    if not os.path.isfile(hat_pkg.link_target_path):
        return None

    suffix = os.path.splitext(hat_pkg.link_target_path)[1]
    if not suffix:
        # Nothing to load when the link target carries no extension.
        return None

    if suffix not in (".dll", ".so"):
        # TODO: Should this be an error? Maybe just move on to the
        # device function section?
        raise RuntimeError(f"Unsupported HAT library extension: {suffix}")

    # Load the library through its absolute path.
    return ctypes.cdll.LoadLibrary(os.path.abspath(hat_pkg.link_target_path))
def hat_package_to_func_dict(hat_pkg: HATPackage) -> AttributeDict:
    """Build a dictionary of Python callables for the functions of a HAT package.

    Functions without a ``launches`` entry are bound to symbols of the
    package's shared library. Functions that launch a device function are bound
    through the CUDA or ROCm loader. When the loader for a function's runtime
    cannot be imported, a notice is printed once for that runtime and the
    affected functions are left out of the result.

    Args:
        hat_pkg: The package whose functions should be made callable.

    Returns:
        An ``AttributeDict`` keyed like the HAT file's function map.

    Raises:
        RuntimeError: If a function without ``launches`` is found but no
            package binary could be loaded, if a launched device function is
            missing from the device function map, or if a launching function
            has no runtime. Errors raised while loading the package binary
            propagate unchanged.
    """
    # Probe for the optional GPU loaders. The imports are only used to find out
    # whether each backend is usable on this machine. Any ordinary failure
    # (missing module, missing native library, unsupported platform) marks the
    # backend as unavailable; KeyboardInterrupt/SystemExit are not swallowed.
    try:
        try:
            from . import cuda_loader    # noqa: F401
        except ModuleNotFoundError:
            import cuda_loader    # noqa: F401
    except Exception:
        cuda_available = False
    else:
        cuda_available = True

    try:
        try:
            from . import rocm_loader    # noqa: F401
        except ModuleNotFoundError:
            import rocm_loader    # noqa: F401
    except Exception:
        rocm_available = False
    else:
        rocm_available = True

    # Each "backend unavailable" notice is printed at most once.
    notify_about_cuda = True
    notify_about_rocm = True

    func_dict = AttributeDict()
    shared_lib = _load_pkg_binary_module(hat_pkg)
    hat_dir_path, _ = os.path.split(hat_pkg.hat_file_path)

    for func_name, func_desc in hat_pkg.hat_file.function_map.items():
        launches = func_desc.launches

        if not launches:
            # CPU function: it can only be called through the package binary.
            if not shared_lib:
                raise RuntimeError(
                    f"Couldn't load function {func_name}: the package binary "
                    f"{hat_pkg.link_target_path} is missing or is not a loadable shared library"
                )
            func_dict[func_name] = _make_cpu_func(shared_lib, func_desc.name, func_desc.arguments)
            continue

        device_func = hat_pkg.hat_file.device_function_map.get(launches)
        func_runtime = func_desc.runtime
        if not device_func:
            raise RuntimeError(f"Couldn't find device function for loader: {launches}")
        if not func_runtime:
            raise RuntimeError(f"Couldn't find runtime for loader: {launches}")

        # TODO: Generalize this concept to work so it's not CUDA/ROCM specific
        if func_runtime == "CUDA" and not cuda_available:
            # TODO: printing to stdout only makes sense in tool mode
            if notify_about_cuda:
                print(
                    "CUDA functionality not available on this machine. Please install the cuda and pvnrtc python modules"
                )
                notify_about_cuda = False
            continue

        if func_runtime == "ROCM" and not rocm_available:
            # TODO: printing to stdout only makes sense in tool mode
            if notify_about_rocm:
                print("ROCm functionality not available on this machine. Please install ROCm 4.2 or higher")
                notify_about_rocm = False
            continue

        func_dict[func_name] = _make_device_func(
            func_runtime=func_runtime, hat_dir_path=hat_dir_path, func=device_func
        )

    return func_dict

Просмотреть файл

@ -1,5 +1,4 @@
#!/usr/bin/env python3
"""Converts a statically-linked HAT package into a Dynamically-linked HAT package
HAT packages come in two varieties: statically-linked and dynamically-linked. A statically-linked
@ -24,21 +23,23 @@ import argparse
import shutil
from secrets import token_hex
if __package__:
from .hat_file import HATFile, OperatingSystem
from .platform_utilities import get_platform, ensure_compiler_in_path, run_command
else:
from hat_file import HATFile, OperatingSystem
from platform_utilities import get_platform, ensure_compiler_in_path, run_command
from .hat_file import HATFile, OperatingSystem
from .hat_package import HATPackage
from .platform_utilities import get_platform, ensure_compiler_in_path, run_command
def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=True):
def linux_create_dynamic_package(input_hat_path,
input_hat_binary_path,
output_hat_path,
hat_file,
quiet=True):
"""Creates a dynamic HAT (.so) from a static HAT (.o/.a) on a Linux/macOS platform"""
# Confirm that this is a static hat library
_, extension = os.path.splitext(input_hat_binary_path)
if extension not in [".o", ".a"]:
sys.exit(
f"ERROR: Expected input library to have extension .o or .a, but received {input_hat_binary_path} instead")
f"ERROR: Expected input library to have extension .o or .a, but received {input_hat_binary_path} instead"
)
# Create a C source file to resolve inline functions defined in the static HAT package
include_path = os.path.dirname(input_hat_binary_path)
@ -48,7 +49,8 @@ def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_h
f.write(f"#include <{os.path.basename(input_hat_path)}>")
# compile it separately so that we can suppress the warnings about the missing terminating ' character
run_command(
f'gcc -c -w -fPIC -o "{inline_obj_path}" -I"{include_path}" "{inline_c_path}"', quiet=quiet)
f'gcc -c -w -fPIC -o "{inline_obj_path}" -I"{include_path}" "{inline_c_path}"',
quiet=quiet)
# create new HAT binary
prefix, _ = os.path.splitext(output_hat_path)
@ -58,7 +60,8 @@ def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_h
libraries = " ".join(
[d.target_file for d in hat_file.dependencies.dynamic])
run_command(
f'gcc -shared -fPIC -o "{output_hat_binary_path}" "{inline_obj_path}" "{input_hat_binary_path}" {libraries}', quiet=quiet)
f'gcc -shared -fPIC -o "{output_hat_binary_path}" "{inline_obj_path}" "{input_hat_binary_path}" {libraries}',
quiet=quiet)
# create new HAT file
# previous dependencies are now part of the binary
@ -67,15 +70,22 @@ def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_h
output_hat_binary_path)
hat_file.Serialize(output_hat_path)
return HATPackage(output_hat_path)
def windows_create_dynamic_package(input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=True):
def windows_create_dynamic_package(input_hat_path,
input_hat_binary_path,
output_hat_path,
hat_file,
quiet=True):
"""Creates a Windows dynamic HAT package (.dll) from a static HAT package (.obj/.lib)"""
# Confirm that this is a static hat library
_, extension = os.path.splitext(input_hat_binary_path)
if extension not in [".obj", ".lib"]:
sys.exit(
f"ERROR: Expected input library to have extension .obj or .lib, but received {input_hat_binary_path} instead")
f"ERROR: Expected input library to have extension .obj or .lib, but received {input_hat_binary_path} instead"
)
# Create all file in a directory named build
if not os.path.exists("build"):
@ -91,9 +101,11 @@ def windows_create_dynamic_package(input_hat_path, input_hat_binary_path, output
# Resolve inline functions defined in the static HAT package
f.write("#include <{}>\n".format(os.path.basename(input_hat_path)))
f.write(
"BOOL APIENTRY DllMain(HMODULE, DWORD, LPVOID) { return TRUE; }\n")
"BOOL APIENTRY DllMain(HMODULE, DWORD, LPVOID) { return TRUE; }\n"
)
run_command(
f'cl.exe /nologo /I"{os.path.dirname(input_hat_path)}" /Fodllmain.obj /c dllmain.cpp', quiet=quiet)
f'cl.exe /nologo /I"{os.path.dirname(input_hat_path)}" /Fodllmain.obj /c dllmain.cpp',
quiet=quiet)
# create the new HAT binary dll
# always create a new dll (avoids case where dll is already loaded)
@ -121,36 +133,54 @@ def windows_create_dynamic_package(input_hat_path, input_hat_binary_path, output
finally:
os.chdir(cwd) # restore the current working directory
return HATPackage(output_hat_path)
def parse_args():
"""Parses and checks the command line arguments"""
parser = argparse.ArgumentParser(description="Creates a dynamically-linked HAT package from a statically-linked HAT package.\n"
"Example:\n"
" hatlib.hat_to_dynamic input.hat output.hat\n")
parser = argparse.ArgumentParser(
description=
"Creates a dynamically-linked HAT package from a statically-linked HAT package.\n"
"Example:\n"
" hatlib.hat_to_dynamic input.hat output.hat\n")
parser.add_argument("input_hat_path", type=str,
help="Path to the existing HAT file, which represents a statically-linked HAT package")
parser.add_argument("output_hat_path", type=str,
help="Path to the new HAT file, which will represent a dynamically-linked HAT package")
parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose output")
parser.add_argument(
"input_hat_path",
type=str,
help=
"Path to the existing HAT file, which represents a statically-linked HAT package"
)
parser.add_argument(
"output_hat_path",
type=str,
help=
"Path to the new HAT file, which will represent a dynamically-linked HAT package"
)
parser.add_argument('-v',
"--verbose",
action='store_true',
help="Enable verbose output")
args = parser.parse_args()
# check args
if not os.path.exists(args.input_hat_path):
sys.exit(f"ERROR: File {args.input_hat_path} not found")
if os.path.abspath(args.input_hat_path) == os.path.abspath(args.output_hat_path):
if os.path.abspath(args.input_hat_path) == os.path.abspath(
args.output_hat_path):
sys.exit("ERROR: Output file must be different from input file")
_, extension = os.path.splitext(args.input_hat_path)
if extension != ".hat":
sys.exit(
f"ERROR: Expected input file to have extension .hat, but received {extension} instead")
f"ERROR: Expected input file to have extension .hat, but received {extension} instead"
)
_, extension = os.path.splitext(args.output_hat_path)
if extension != ".hat":
sys.exit(
f"ERROR: Expected output file to have extension .hat, but received {extension} instead")
f"ERROR: Expected output file to have extension .hat, but received {extension} instead"
)
return args
@ -165,22 +195,30 @@ def create_dynamic_package(input_hat_path, output_hat_path, quiet=True):
# get the absolute path to the input binary
input_hat_binary_filename = hat_file.dependencies.link_target
input_hat_binary_path = os.path.join(
os.path.dirname(input_hat_path), input_hat_binary_filename)
input_hat_binary_path = os.path.join(os.path.dirname(input_hat_path),
input_hat_binary_filename)
# create the dynamic package
output_hat_path = os.path.abspath(output_hat_path)
if platform == OperatingSystem.Windows:
windows_create_dynamic_package(
input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=quiet)
windows_create_dynamic_package(input_hat_path,
input_hat_binary_path,
output_hat_path,
hat_file,
quiet=quiet)
elif platform in [OperatingSystem.Linux, OperatingSystem.MacOS]:
linux_create_dynamic_package(
input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=quiet)
linux_create_dynamic_package(input_hat_path,
input_hat_binary_path,
output_hat_path,
hat_file,
quiet=quiet)
def main():
args = parse_args()
create_dynamic_package(args.input_hat_path, args.output_hat_path, quiet=not args.verbose)
create_dynamic_package(args.input_hat_path,
args.output_hat_path,
quiet=not args.verbose)
if __name__ == "__main__":

Просмотреть файл

@ -23,12 +23,8 @@ import os
import argparse
import shutil
if __package__:
from .hat_file import HATFile, OperatingSystem
from .platform_utilities import get_platform, ensure_compiler_in_path, run_command
else:
from hat_file import HATFile, OperatingSystem
from platform_utilities import get_platform, ensure_compiler_in_path, run_command
from .hat_file import HATFile, OperatingSystem
from .platform_utilities import get_platform, ensure_compiler_in_path, run_command
def linux_create_static_package(input_hat_binary_path, output_hat_path, hat_file, quiet=True):

Просмотреть файл

@ -8,10 +8,7 @@ import shutil
import subprocess
import sys
if __package__:
from .hat_file import OperatingSystem
else:
from hat_file import OperatingSystem
from .hat_file import OperatingSystem
def _preprocess_command(command_to_run, shell):
@ -33,13 +30,8 @@ def _dump_file_contents(iostream):
def run_command(
command_to_run,
working_directory=None,
stdout=None,
stderr=None,
shell=False,
pretend=False,
quiet=True):
command_to_run, working_directory=None, stdout=None, stderr=None, shell=False, pretend=False, quiet=True
):
if not working_directory:
working_directory = os.getcwd()
@ -50,20 +42,14 @@ def run_command(
command_to_run = _preprocess_command(command_to_run, shell)
if not pretend:
with subprocess.Popen(
command_to_run,
close_fds=(platform.system() != "Windows"),
shell=shell,
stdout=stdout,
stderr=stderr,
cwd=working_directory) as proc:
with subprocess.Popen(command_to_run, close_fds=(platform.system() != "Windows"), shell=shell, stdout=stdout,
stderr=stderr, cwd=working_directory) as proc:
proc.wait()
if proc.returncode:
_dump_file_contents(stderr)
_dump_file_contents(stdout)
raise subprocess.CalledProcessError(
proc.returncode, command_to_run)
raise subprocess.CalledProcessError(proc.returncode, command_to_run)
def get_platform():
@ -84,8 +70,7 @@ def linux_ensure_compiler_in_path():
Prompts the user if not found."""
compiler = os.environ.get("CXX") or (os.environ.get("CC") or "gcc")
if not shutil.which(compiler):
sys.exit(
'ERROR: Could not find any valid C or C++ compiler, please install gcc before continuing')
sys.exit('ERROR: Could not find any valid C or C++ compiler, please install gcc before continuing')
def windows_ensure_compiler_in_path():
@ -94,15 +79,14 @@ def windows_ensure_compiler_in_path():
import vswhere
vs_path = vswhere.get_latest_path()
if not vs_path:
sys.exit(
"ERROR: Could not find Visual Studio, please ensure that you have Visual Studio installed")
sys.exit("ERROR: Could not find Visual Studio, please ensure that you have Visual Studio installed")
# Check if cl.exe is in PATH
if not shutil.which("cl"): # returns 0 if found, !0 otherwise
vcvars_script_path = os.path.join(
vs_path, r"VC\Auxiliary\Build\vcvars64.bat")
if not shutil.which("cl"): # returns 0 if found, !0 otherwise
vcvars_script_path = os.path.join(vs_path, r"VC\Auxiliary\Build\vcvars64.bat")
sys.exit(
f'ERROR: Could not find cl.exe, please run "{vcvars_script_path}" (including quotes) to setup your command prompt')
f'ERROR: Could not find cl.exe, please run "{vcvars_script_path}" (including quotes) to setup your command prompt'
)
def ensure_compiler_in_path():

Просмотреть файл

@ -4,7 +4,8 @@ lifted from https://github.com/jatinx/PyHIP
TODO: move to a submodule
"""
import sys, ctypes
import ctypes
import sys
_libhip_libname = 'libamdhip64.so'
@ -15,7 +16,7 @@ else:
# Currently we do not support windows, mainly because I do not have a windows build of hip
raise RuntimeError('Only linux is supported')
if _libhip == None:
if _libhip is None:
raise OSError('hiprtc library not found')
@ -482,7 +483,7 @@ def hipMalloc(count, ctype=None):
ptr = ctypes.c_void_p()
status = _libhip.hipMalloc(ctypes.byref(ptr), count)
hipCheckStatus(status)
if ctype != None:
if ctype is not None:
ptr = ctypes.cast(ptr, ctypes.POINTER(ctype))
return ptr

Просмотреть файл

@ -4,7 +4,8 @@ lifted from https://github.com/jatinx/PyHIP
TODO: move to a submodule
"""
import sys, ctypes
import ctypes
import sys
# _libhiprtc_libname = 'libhiprtc.so' # Currently its the same library
_libhiprtc_libname = 'libamdhip64.so'
@ -16,7 +17,7 @@ else:
# Currently we do not support windows, mainly because I do not have a windows build of hip
raise RuntimeError('Only linux is supported')
if _libhiprtc == None:
if _libhiprtc is None:
raise OSError('hiprtc library not found')

Просмотреть файл

@ -4,21 +4,15 @@ import numpy as np
from functools import reduce
from typing import List
try:
from .arg_info import ArgInfo, verify_args
from .gpu_headers import ROCM_HEADER_MAP
from .pyhip_hip import *
from .pyhip_hiprtc import *
except ModuleNotFoundError:
from arg_info import ArgInfo, verify_args
from gpu_headers import ROCM_HEADER_MAP
from pyhip_hip import *
from pyhip_hiprtc import *
from .arg_info import ArgInfo, verify_args
from .hat_file import Function
from .gpu_headers import ROCM_HEADER_MAP
from .pyhip.hip import *
from .pyhip.hiprtc import *
def _arg_size(arg_info: ArgInfo):
return arg_info.element_num_bytes * reduce(lambda x, y: x * y,
arg_info.numpy_shape)
return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
def initialize_rocm():
@ -29,15 +23,14 @@ def initialize_rocm():
def compile_rocm_program(rocm_src_path: pathlib.Path, func_name):
src = rocm_src_path.read_text()
prog = hiprtcCreateProgram(source=src,
name=func_name + ".cu",
header_names=ROCM_HEADER_MAP.keys(),
header_sources=ROCM_HEADER_MAP.values())
prog = hiprtcCreateProgram(
source=src,
name=func_name + ".cu",
header_names=ROCM_HEADER_MAP.keys(),
header_sources=ROCM_HEADER_MAP.values()
)
device_properties = hipGetDeviceProperties(0)
hiprtcCompileProgram(prog, [
f'--offload-arch={device_properties.gcnArchName}',
'-D__HIP_PLATFORM_AMD__'
])
hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}', '-D__HIP_PLATFORM_AMD__'])
print(hiprtcGetProgramLog(prog))
code = hiprtcGetCode(prog)
@ -59,24 +52,16 @@ def allocate_rocm_mem(arg_infos: List[ArgInfo]):
return device_mem
def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array],
arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args,
arg_infos):
if 'input' in arg_info.usage:
hipMemcpy_htod(dst=device_arg,
src=host_arg.ctypes.data,
count=_arg_size(arg_info))
def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'input' in arg_info.usage.value:
hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=_arg_size(arg_info))
def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array],
arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args,
arg_infos):
if 'output' in arg_info.usage:
hipMemcpy_dtoh(dst=host_arg.ctypes.data,
src=device_arg,
count=_arg_size(arg_info))
def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'output' in arg_info.usage.value:
hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=_arg_size(arg_info))
def device_args_to_ptr_list(device_args: List):
@ -86,10 +71,12 @@ def device_args_to_ptr_list(device_args: List):
return ptrs
def create_loader_for_device_function(device_func, hat_details):
hat_path: pathlib.Path = hat_details.path
rocm_src_path: pathlib.Path = hat_path.parent / device_func["provider"]
func_name = device_func["name"]
def create_loader_for_device_function(device_func: Function, hat_dir_path: str):
if not device_func.provider:
raise RuntimeError("Expected a provider for the device function")
rocm_src_path: pathlib.Path = pathlib.Path(hat_dir_path) / device_func.provider
func_name = device_func.name
rocm_program = compile_rocm_program(rocm_src_path, func_name)
@ -97,33 +84,28 @@ def create_loader_for_device_function(device_func, hat_details):
kernel = get_func_from_rocm_program(rocm_program, func_name)
hat_arg_descriptions = device_func["arguments"]
hat_arg_descriptions = device_func.arguments
arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
launch_parameters = device_func["launch_parameters"]
launch_parameters = device_func.launch_parameters
class DataStruct(ctypes.Structure):
_fields_ = [(f"arg{i}", ctypes.c_void_p)
for i in range(len(arg_infos))]
_fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(len(arg_infos))]
def f(*args):
verify_args(args, arg_infos, func_name)
device_mem = allocate_rocm_mem(arg_infos)
transfer_mem_host_to_rocm(device_args=device_mem,
host_args=args,
arg_infos=arg_infos)
transfer_mem_host_to_rocm(device_args=device_mem, host_args=args, arg_infos=arg_infos)
data = DataStruct(*device_mem)
hipModuleLaunchKernel(
kernel,
*launch_parameters, # [ grid[x-z], block[x-z] ]
0, # dynamic shared memory
0, # stream
data, # data
*launch_parameters, # [ grid[x-z], block[x-z] ]
0, # dynamic shared memory
0, # stream
data, # data
)
hipDeviceSynchronize()
transfer_mem_rocm_to_host(device_args=device_mem,
host_args=args,
arg_infos=arg_infos)
transfer_mem_rocm_to_host(device_args=device_mem, host_args=args, arg_infos=arg_infos)
return f

Просмотреть файл

@ -5,10 +5,7 @@ from ast import arg
import enum
import sys
if __package__:
from . import hat
else:
import hat
from . import hat
def verify_hat_package(hat_path):
@ -20,29 +17,25 @@ def verify_hat_package(hat_path):
print("Inputs before function call:")
for i, func_input in enumerate(func_inputs):
print(
f"\tInput {i}: {','.join(map(str, func_input.flatten()[:32]))}"
)
print(f"\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")
fn(*inputs[name])
print("Inputs after function call:")
for i, func_input in enumerate(func_inputs):
print(
f"\tInput {i}: {','.join(map(str, func_input.flatten()[:32]))}"
)
print(f"\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")
def main():
arg_parser = argparse.ArgumentParser(
description="Executes every available function in the hat package \
with randomized inputs. Meant for quick verification.\n"
"Example:\n"
" hatlib.verify_hat_package <hat_path>\n")
description=(
"Executes every available function in the hat package with randomized inputs. Meant for quick verification.\n"
"Example:\n"
" hatlib.verify_hat_package <hat_path>\n"
)
)
arg_parser.add_argument("hat_path",
help="Path to the HAT file",
default=None)
arg_parser.add_argument("hat_path", help="Path to the HAT file", default=None)
args = vars(arg_parser.parse_args())
verify_hat_package(args["hat_path"])

Просмотреть файл

@ -4,4 +4,21 @@ build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
version_scheme = "python-simplified-semver"
local_scheme = "no-local-version"
local_scheme = "no-local-version"
[tool.yapf]
allow_multiline_dictionary_keys = false
allow_split_before_dict_value = true
based_on_style = "pep8"
coalesce_brackets = true
column_limit = 120
dedent_closing_brackets = true
each_dict_entry_on_separate_line = true
force_multiline_dict = true
indent_dictionary_value = true
spaces_before_comment = 4
split_before_first_argument = true
split_before_logical_operator = true
[tool.pydocstyle]
max-line-length = 120

Просмотреть файл

@ -36,3 +36,6 @@ console_scripts =
hatlib.benchmark_hat = hatlib.benchmark_hat_package:main_command
hatlib.hat_to_dynamic = hatlib.hat_to_dynamic:main
hatlib.verify_hat = hatlib.verify_hat_package:main
[pycodestyle]
max-line-length = 120

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -3,13 +3,11 @@ import unittest
import os
import sys
import accera as acc
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from benchmark_hat_package import run_benchmark
from hatlib import run_benchmark
class BenchmarkHATPackage_test(unittest.TestCase):
def test_benchmark(self):
A = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
B = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
@ -24,21 +22,23 @@ class BenchmarkHATPackage_test(unittest.TestCase):
package = acc.Package()
package.add(nest, args=(A, B, C), base_name="test_function")
package.build(name="BenchmarkHATPackage_test_benchmark",
output_dir="test_acccgen",
format=acc.Package.Format.HAT_DYNAMIC)
package.build(
name="BenchmarkHATPackage_test_benchmark", output_dir="test_acccgen", format=acc.Package.Format.HAT_DYNAMIC
)
run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
store_in_hat=False,
batch_size=2,
min_time_in_sec=1,
input_sets_minimum_size_MB=1)
run_benchmark(
"test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
store_in_hat=False,
batch_size=2,
min_time_in_sec=1,
input_sets_minimum_size_MB=1
)
def test_benchmark_multiple_functions(self):
A = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
B = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
C = acc.Array(role=acc.Array.Role.INPUT_OUTPUT, shape=(256, 256))
D = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256)) # dummy argument
D = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256)) # dummy argument
nest = acc.Nest(shape=(256, 256, 256))
i, j, k = nest.get_indices()
@ -54,15 +54,17 @@ class BenchmarkHATPackage_test(unittest.TestCase):
# with the correct signature
package.add(nest, args=(A, B, C, D), base_name="test_function_dummy")
package.build(name="BenchmarkHATPackage_test_benchmark",
output_dir="test_acccgen",
format=acc.Package.Format.HAT_DYNAMIC)
package.build(
name="BenchmarkHATPackage_test_benchmark", output_dir="test_acccgen", format=acc.Package.Format.HAT_DYNAMIC
)
run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
store_in_hat=False,
batch_size=2,
min_time_in_sec=1,
input_sets_minimum_size_MB=1)
run_benchmark(
"test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
store_in_hat=False,
batch_size=2,
min_time_in_sec=1,
input_sets_minimum_size_MB=1
)
if __name__ == '__main__':

Просмотреть файл

@ -4,12 +4,7 @@ import numpy as np
import os
import sys
import unittest
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from hat import load
from hat_to_dynamic import create_dynamic_package
from hat_to_lib import create_static_package
from hatlib import load, create_dynamic_package, create_static_package
class HAT_test(unittest.TestCase):
@ -32,17 +27,13 @@ class HAT_test(unittest.TestCase):
for mode in [acc.Package.Mode.RELEASE, acc.Package.Mode.DEBUG]:
package_name = f"HAT_test_load_{mode.value}"
package.build(name=package_name,
output_dir="test_acccgen",
format=acc.Package.Format.HAT_STATIC,
mode=mode)
package.build(name=package_name, output_dir="test_acccgen", format=acc.Package.Format.HAT_STATIC, mode=mode)
create_dynamic_package(f"test_acccgen/{package_name}.hat",
f"test_acccgen/{package_name}.dyn.hat")
create_dynamic_package(f"test_acccgen/{package_name}.hat", f"test_acccgen/{package_name}.dyn.hat")
hat_package = load(f"test_acccgen/{package_name}.dyn.hat")
_, func_dict = load(f"test_acccgen/{package_name}.dyn.hat")
for name in hat_package.names:
for name in func_dict.names:
print(name)
# create numpy arguments with the correct shape and dtype
@ -51,7 +42,7 @@ class HAT_test(unittest.TestCase):
B_ref = B + A
# find the function by basename
test_function = hat_package["test_function"]
test_function = func_dict["test_function"]
test_function(A, B)
# check for correctness
@ -59,7 +50,7 @@ class HAT_test(unittest.TestCase):
# find the function by actual name
B_ref = B + A
test_function1 = hat_package[function.name]
test_function1 = func_dict[function.name]
test_function1(A, B)
# check for correctness

Просмотреть файл

@ -4,13 +4,10 @@ import os
import sys
import unittest
from pathlib import Path
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from hat_file import (CallingConventionType, CompiledWith, Declaration,
Dependencies, Description, Function, FunctionTable,
HATFile, OperatingSystem, Parameter, ParameterType,
Target, UsageType)
from hatlib import (
CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile,
OperatingSystem, Parameter, ParameterType, Target, UsageType
)
class HATFile_test(unittest.TestCase):
@ -22,34 +19,38 @@ class HATFile_test(unittest.TestCase):
name="my_function",
description="Some description",
calling_convention=CallingConventionType.StdCall,
return_info=Parameter(logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"))
return_info=Parameter(
logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"
)
)
# Create the function table
functions = FunctionTable({"my_function": my_function})
# Create the HATFile object
hat_file1 = HATFile(
name="test_file",
description=Description(
version="0.0.1",
author="me",
license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
version="0.0.1", author="me", license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
),
_function_table=functions,
target=Target(required=Target.Required(os=OperatingSystem.Windows,
cpu=Target.Required.CPU(
architecture="Haswell",
extensions=["AVX2"]),
gpu=None),
optimized_for=Target.OptimizedFor()),
target=Target(
required=Target.Required(
os=OperatingSystem.Windows,
cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
gpu=None
),
optimized_for=Target.OptimizedFor()
),
dependencies=Dependencies(link_target="my_lib.lib"),
compiled_with=CompiledWith(compiler="VC++"),
declaration=Declaration(),
path=Path(".").resolve())
path=Path(".").resolve()
)
# Serialize it to disk
test_file_name = "test_file_serialize.hat"
@ -65,15 +66,14 @@ class HATFile_test(unittest.TestCase):
# specified when we created the HATFile directly
self.assertEqual(hat_file1.description, hat_file2.description)
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
self.assertEqual(hat_file1.compiled_with.to_table(),
hat_file2.compiled_with.to_table())
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(
os.path.join(os.path.dirname(__file__), "..", "..", "samples",
"sample_gemm_library.hat"))
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
)
description = {
"author": "John Doe",
"version": "1.2.3.5",
@ -82,8 +82,7 @@ class HATFile_test(unittest.TestCase):
# Do basic verification of known values in the file
# Verify the description has entries we expect
self.assertLessEqual(description.items(),
hat_file1.description.to_table().items())
self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
# Verify the list of functions
self.assertTrue(len(hat_file1.functions) == 2)
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)

Просмотреть файл

@ -4,13 +4,10 @@ import sys
import os
import unittest
from pathlib import Path
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from hat_file import (CallingConventionType, CompiledWith, Declaration,
Dependencies, Description, Function, FunctionTable,
HATFile, OperatingSystem, Parameter, ParameterType,
Target, UsageType)
from hatlib import (
CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile,
OperatingSystem, Parameter, ParameterType, Target, UsageType
)
class HATFile_test(unittest.TestCase):
@ -22,34 +19,38 @@ class HATFile_test(unittest.TestCase):
name="my_function",
description="Some description",
calling_convention=CallingConventionType.StdCall,
return_info=Parameter(logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"))
return_info=Parameter(
logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"
)
)
# Create the function table
functions = FunctionTable({"my_function": my_function})
# Create the HATFile object
hat_file1 = HATFile(
name="test_file",
description=Description(
version="0.0.1",
author="me",
license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
version="0.0.1", author="me", license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
),
_function_table=functions,
target=Target(required=Target.Required(os=OperatingSystem.Windows,
cpu=Target.Required.CPU(
architecture="Haswell",
extensions=["AVX2"]),
gpu=None),
optimized_for=Target.OptimizedFor()),
target=Target(
required=Target.Required(
os=OperatingSystem.Windows,
cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
gpu=None
),
optimized_for=Target.OptimizedFor()
),
dependencies=Dependencies(link_target="my_lib.lib"),
compiled_with=CompiledWith(compiler="VC++"),
declaration=Declaration(),
path=Path(".").resolve())
path=Path(".").resolve()
)
# Serialize it to disk
test_file_name = "test_file_serialize.hat"
@ -65,15 +66,14 @@ class HATFile_test(unittest.TestCase):
# specified when we created the HATFile directly
self.assertEqual(hat_file1.description, hat_file2.description)
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
self.assertEqual(hat_file1.compiled_with.to_table(),
hat_file2.compiled_with.to_table())
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(
os.path.join(os.path.dirname(__file__), "..", "..", "samples",
"sample_gemm_library.hat"))
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
)
description = {
"author": "John Doe",
"version": "1.2.3.5",
@ -82,8 +82,7 @@ class HATFile_test(unittest.TestCase):
# Do basic verification of known values in the file
# Verify the description has entries we expect
self.assertLessEqual(description.items(),
hat_file1.description.to_table().items())
self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
# Verify the list of functions
self.assertTrue(len(hat_file1.function_map) == 2)
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)