Mirror of https://github.com/microsoft/hat.git
Refactors and restructures code for tighter integration of HAT object model (#36)
Changes:
* Utility functions in `hatlib` now make use of the HAT object model (OM). Previously, `tomlkit` was being used directly, which bypassed the checks and OM construction of child objects.
* `hatlib.load` now returns a `HATPackage` object that results from deserializing the HAT file passed in. Additionally, if functions can be loaded, a non-empty `AttributeDict` is returned as a second value.
* Moved the test folder to be a sibling of the library folder, so the package no longer contains the tests when installed. Additionally, the library can be tested by running `python -m unittest discover test` in the root of the repository.
* Added a YAPF style config to setup.cfg and formatted all touched files with `yapf`.
This commit is contained in:
Parent: b2ebaf5323
Commit: 544fcc0e01
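For context, a minimal sketch of the new loading convention described above, assuming a hypothetical package file my_package.hat (the function name follows the example in hat.py's docstring below; A, B, D, E stand in for numpy arrays matching the function's HAT metadata):

    import hatlib

    # load() now returns the deserialized HATPackage plus, when dynamic
    # loading succeeds, an AttributeDict of callable functions (else None)
    pkg, funcs = hatlib.load("my_package.hat")

    print([f.name for f in pkg.functions])    # function metadata from the object model
    if funcs is not None:
        funcs.my_func_698b5e5c(A, B, D, E)    # A, B, D, E: placeholder numpy arrays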
CI workflow:

@@ -31,8 +31,8 @@ jobs:
       python -m pip install -r hatlib/requirements.txt
   - name: Unittest
     run: |
-      python -m pip install -r hatlib/test/requirements.txt
-      python -m unittest discover hatlib/test
+      python -m pip install -r test/requirements.txt
+      python -m unittest discover test
   - name: Build whl
     run: |
       python -m pip install build
hatlib/arg_info.py:

@@ -2,9 +2,9 @@ import ctypes
 import numpy as np
 import sys
 from dataclasses import dataclass
-from typing import Any, Tuple
-from functools import reduce
-from typing import List
+from typing import Any, List, Tuple
+
+from . import hat_file
 
 
 @dataclass
@@ -16,17 +16,16 @@ class ArgInfo:
     numpy_dtype: type
     element_num_bytes: int
     ctypes_pointer_type: Any
-    usage: str = ""
+    usage: hat_file.UsageType = None
 
-    def __init__(self, param_description):
-        self.hat_declared_type = param_description["declared_type"]
-        self.numpy_shape = tuple(param_description["shape"])
-        self.usage = param_description["usage"]
+    def __init__(self, param_description: hat_file.Parameter):
+        self.hat_declared_type = param_description.declared_type
+        self.numpy_shape = tuple(param_description.shape)
+        self.usage = param_description.usage
         if self.hat_declared_type == "float16_t*":
             self.numpy_dtype = np.float16
             self.element_num_bytes = 2
-            self.ctypes_pointer_type = ctypes.POINTER(
-                ctypes.c_uint16)  # same bitwidth as float16
+            self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_uint16)  # same bitwidth as float16
         elif self.hat_declared_type == "float*":
             self.numpy_dtype = np.float32
             self.element_num_bytes = 4
@@ -53,23 +52,18 @@ class ArgInfo:
             self.ctypes_pointer_type = ctypes.POINTER(ctypes.c_int8)
 
         else:
-            raise NotImplementedError(
-                f"Unsupported declared_type {self.hat_declared_type} in hat file"
-            )
+            raise NotImplementedError(f"Unsupported declared_type {self.hat_declared_type} in hat file")
 
-        self.numpy_strides = tuple([
-            self.element_num_bytes * x for x in param_description["affine_map"]
-        ])
+        self.numpy_strides = tuple([self.element_num_bytes * x for x in param_description.affine_map])
 
 
-def verify_args(args, arg_infos, function_name):
+# TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name
+def verify_args(args: List, arg_infos: List[ArgInfo], function_name: str):
     """ Verifies that a list of arguments matches a list of argument descriptions in a HAT file
     """
     # check number of args
     if len(args) != len(arg_infos):
-        sys.exit(
-            f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}"
-        )
+        sys.exit(f"Error calling {function_name}(...): expected {len(arg_infos)} arguments but received {len(args)}")
 
     # for each arg
     for i in range(len(args)):
@@ -99,24 +93,3 @@ def verify_args(args, arg_infos, function_name):
             sys.exit(
                 f"Error calling {function_name}(...): expected argument {i} to have strides={arg_info.numpy_strides} but received strides={arg.strides}"
             )
-
-
-def generate_input_sets(parameters: List[ArgInfo],
-                        input_sets_minimum_size_MB: int = 0,
-                        num_additional: int = 0):
-    shapes_to_sizes = [
-        reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters
-    ]
-    set_size = reduce(lambda x, y: x + y, [
-        size * p.element_num_bytes
-        for size, p in zip(shapes_to_sizes, parameters)
-    ])
-
-    num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 //
-                      set_size) + 1 + num_additional
-    input_sets = [[
-        np.random.random(p.numpy_shape).astype(p.numpy_dtype)
-        for p in parameters
-    ] for _ in range(num_input_sets)]
-
-    return input_sets[0] if len(input_sets) == 1 else input_sets
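To make the role of ArgInfo/verify_args concrete, here is a self-contained sketch of the shape/stride check the code above performs; the parameter metadata is illustrative, not from a real HAT file:

    import numpy as np

    # Illustrative stand-ins for values that normally come from a parsed
    # hat_file.Parameter: a 16x16 float32 affine array with row-major layout.
    shape = (16, 16)
    element_num_bytes = 4
    affine_map = [16, 1]
    expected_strides = tuple(element_num_bytes * x for x in affine_map)

    arg = np.zeros(shape, dtype=np.float32)   # a well-formed argument
    assert arg.shape == shape
    assert arg.strides == expected_strides    # verify_args checks this too

    bad = np.asfortranarray(arg)              # column-major: strides differ
    assert bad.strides != expected_strides    # verify_args would sys.exit here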
hatlib/benchmark_hat_package.py:

@@ -9,14 +9,8 @@ import toml
 import traceback
 from pathlib import Path
 
-if __package__:
-    from .hat_file import HATFile
-    from .hat_to_dynamic import create_dynamic_package
-    from .hat import load, ArgInfo, generate_input_sets
-else:
-    from hat_file import HATFile
-    from hat_to_dynamic import create_dynamic_package
-    from hat import load, ArgInfo, generate_input_sets
+from .hat_file import HATFile
+from .hat import load, generate_input_sets_for_func
 
 
 class Benchmark:
@@ -26,18 +20,13 @@ class Benchmark:
     Requirements:
         A compilation toolchain in your PATH: cl.exe & link.exe (Windows), gcc (Linux), or clang (macOS)
     """
 
-    def __init__(self, hat_path):
-        self.hat_path = Path(hat_path)
-
-        self.hat_package = load(self.hat_path)
-        self.hat_functions = self.hat_package.names
+    def __init__(self, hat_path: str):
+        self.hat_path = hat_path
+        self.hat_package, self.func_dict = load(self.hat_path)
+        self.hat_functions = self.func_dict.names
 
         # create dictionary of function descriptions defined in the hat file
-        t = toml.load(self.hat_path)
-        function_descriptions = t["functions"]
-        self.hat_arg_descriptions = {key: [ArgInfo(
-            d) for d in val["arguments"]] for key, val in function_descriptions.items()}
+        self.function_descriptions = self.hat_package.hat_file.function_map
 
     def run(self,
             function_name: str,
@@ -58,20 +47,22 @@ class Benchmark:
             Mean duration in seconds,
             Vector of timings in seconds for each batch that was run
         """
-        if function_name not in self.hat_package.names:
+        if function_name not in self.hat_functions:
             raise ValueError(f"{function_name} is not found")
 
         # TODO: support packing and unpacking functions
 
         mean_elapsed_time, batch_timings = self._profile(
-            function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB)
+            function_name, warmup_iterations, min_timing_iterations,
+            min_time_in_sec, input_sets_minimum_size_MB)
         print(
-            f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s")
+            f"[Benchmarking] Mean duration per iteration: {mean_elapsed_time:.8f}s"
+        )
 
         return mean_elapsed_time, batch_timings
 
-    def _profile(self, function_name, warmup_iterations, min_timing_iterations, min_time_in_sec, input_sets_minimum_size_MB):
-
+    def _profile(self, function_name, warmup_iterations, min_timing_iterations,
+                 min_time_in_sec, input_sets_minimum_size_MB):
         def get_perf_counter():
             if hasattr(time, 'perf_counter_ns'):
                 perf_counter = time.perf_counter_ns
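As a usage sketch of the benchmarking API above (the package path and function name are hypothetical; the keyword arguments mirror the run_benchmark call later in this file):

    from hatlib.benchmark_hat_package import Benchmark

    benchmark = Benchmark("my_package.hat")   # hypothetical HAT file
    print(benchmark.hat_functions)            # names from the loaded func_dict

    mean_time, batch_timings = benchmark.run(
        "my_func_698b5e5c",                   # hypothetical function name
        warmup_iterations=10,
        min_timing_iterations=10,
        min_time_in_sec=10,
        input_sets_minimum_size_MB=50)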
@@ -81,19 +72,21 @@ class Benchmark:
                 perf_counter_scale = 1
             return perf_counter, perf_counter_scale
 
-        parameters = self.hat_arg_descriptions[function_name]
+        func = self.function_descriptions[function_name]
 
         # generate sufficient input sets to overflow the L3 cache, since we don't know the size of the model
         # we'll make a guess based on the minimum input set size
-        input_sets = generate_input_sets(
-            parameters, input_sets_minimum_size_MB, num_additional=10)
+        input_sets = generate_input_sets_for_func(func,
+                                                  input_sets_minimum_size_MB,
+                                                  num_additional=10)
 
         set_size = 0
         for i in input_sets[0]:
             set_size += i.size * i.dtype.itemsize
 
         print(
-            f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes")
+            f"[Benchmarking] Using {len(input_sets)} input sets, each {set_size} bytes"
+        )
 
         perf_counter, perf_counter_scale = get_perf_counter()
         print(
@@ -101,10 +94,11 @@ class Benchmark:
 
         for _ in range(warmup_iterations):
             for calling_args in input_sets:
-                self.hat_package[function_name](*calling_args)
+                self.func_dict[function_name](*calling_args)
 
         print(
-            f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations...")
+            f"[Benchmarking] Timing for at least {min_time_in_sec}s and at least {min_timing_iterations} iterations..."
+        )
         start_time = perf_counter()
         end_time = perf_counter()
@@ -115,7 +109,7 @@ class Benchmark:
         while ((end_time - start_time) / perf_counter_scale) < min_time_in_sec:
             batch_start_time = perf_counter()
             for _ in range(min_timing_iterations):
-                self.hat_package[function_name](*input_sets[i])
+                self.func_dict[function_name](*input_sets[i])
             i = iterations % i_max
             iterations += 1
             end_time = perf_counter()
@@ -140,13 +134,20 @@ def write_runtime_to_hat_file(hat_path, function_name, mean_time_secs):
     # Workaround to remove extra empty lines
     with open(hat_path, "r") as f:
         lines = f.readlines()
-        lines = [lines[i] for i in range(len(lines)) if not(lines[i] == "\n"
-                 and i < len(lines)-1 and lines[i+1] == "\n")]
+        lines = [
+            lines[i] for i in range(len(lines))
+            if not (lines[i] == "\n" and i < len(lines) -
+                    1 and lines[i + 1] == "\n")
+        ]
     with open(hat_path, "w") as f:
         f.writelines(lines)
 
 
-def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=10, input_sets_minimum_size_MB=50):
+def run_benchmark(hat_path,
+                  store_in_hat=False,
+                  batch_size=10,
+                  min_time_in_sec=10,
+                  input_sets_minimum_size_MB=50):
     results = []
 
     benchmark = Benchmark(hat_path)
@@ -157,71 +158,89 @@ def run_benchmark(hat_path, store_in_hat=False, batch_size=10, min_time_in_sec=1
             continue
 
         try:
-            _, batch_timings = benchmark.run(function_name,
-                                             warmup_iterations=batch_size,
-                                             min_timing_iterations=batch_size,
-                                             min_time_in_sec=min_time_in_sec,
-                                             input_sets_minimum_size_MB=input_sets_minimum_size_MB)
+            _, batch_timings = benchmark.run(
+                function_name,
+                warmup_iterations=batch_size,
+                min_timing_iterations=batch_size,
+                min_time_in_sec=min_time_in_sec,
+                input_sets_minimum_size_MB=input_sets_minimum_size_MB)
 
             sorted_batch_means = np.array(sorted(batch_timings)) / batch_size
             num_batches = len(batch_timings)
 
             mean_of_means = sorted_batch_means.mean()
-            median_of_means = sorted_batch_means[num_batches//2]
-            mean_of_small_means = sorted_batch_means[0: num_batches//2].mean()
-            robust_mean_of_means = sorted_batch_means[num_batches //
-                                                      5: -num_batches//5].mean()
+            median_of_means = sorted_batch_means[num_batches // 2]
+            mean_of_small_means = sorted_batch_means[0:num_batches // 2].mean()
+            robust_means = sorted_batch_means[(num_batches //
+                                               5):(-num_batches // 5)]
+            robust_mean_of_means = robust_means.mean()
             min_of_means = sorted_batch_means[0]
 
             if store_in_hat:
-                write_runtime_to_hat_file(
-                    hat_path, function_name, mean_of_means)
-            results.append({"function_name": function_name,
-                            "mean": mean_of_means,
-                            "median_of_means": median_of_means,
-                            "mean_of_small_means": mean_of_small_means,
-                            "robust_mean": robust_mean_of_means,
-                            "min_of_means": min_of_means,
-                            })
+                write_runtime_to_hat_file(hat_path, function_name,
+                                          mean_of_means)
+            results.append({
+                "function_name": function_name,
+                "mean": mean_of_means,
+                "median_of_means": median_of_means,
+                "mean_of_small_means": mean_of_small_means,
+                "robust_mean": robust_mean_of_means,
+                "min_of_means": min_of_means,
+            })
         except Exception as e:
             exc_type, exc_val, exc_tb = sys.exc_info()
-            traceback.print_exception(
-                exc_type, exc_val, exc_tb, file=sys.stderr)
+            traceback.print_exception(exc_type,
+                                      exc_val,
+                                      exc_tb,
+                                      file=sys.stderr)
             print("\nException message: ", e)
             print(
-                f"WARNING: Failed to run function {function_name}, skipping this benchmark.")
+                f"WARNING: Failed to run function {function_name}, skipping this benchmark."
+            )
     return results
 
 
 def main(argv):
     arg_parser = argparse.ArgumentParser(
-        description="Benchmarks each function in a HAT package and estimates its duration.\n"
+        description=
+        "Benchmarks each function in a HAT package and estimates its duration.\n"
         "Example:\n"
         "    hatlib.benchmark_hat_package <hat_path>\n")
 
     arg_parser.add_argument("hat_path",
                             help="Path to the HAT file",
                             default=None)
-    arg_parser.add_argument("--store_in_hat",
-                            help="If set, will write the duration as meta-data back into the hat file",
-                            action='store_true')
+    arg_parser.add_argument(
+        "--store_in_hat",
+        help=
+        "If set, will write the duration as meta-data back into the hat file",
+        action='store_true')
     arg_parser.add_argument("--results_file",
                             help="Full path where the results will be written",
                             default="results.csv")
-    arg_parser.add_argument("--batch_size",
-                            help="The number of function calls in each batch (at least one full batch is executed)",
-                            default=10)
-    arg_parser.add_argument("--min_time_in_sec",
-                            help="Minimum number of seconds to run the benchmark for",
-                            default=30)
-    arg_parser.add_argument("--input_sets_minimum_size_MB",
-                            help="Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on a desktop CPU)",
-                            default=50)
+    arg_parser.add_argument(
+        "--batch_size",
+        help=
+        "The number of function calls in each batch (at least one full batch is executed)",
+        default=10)
+    arg_parser.add_argument(
+        "--min_time_in_sec",
+        help="Minimum number of seconds to run the benchmark for",
+        default=30)
+    arg_parser.add_argument(
+        "--input_sets_minimum_size_MB",
+        help=
+        "Minimum size in MB of the input sets. Typically this is large enough to ensure eviction of the biggest cache on the target (e.g. L3 on a desktop CPU)",
+        default=50)
 
     args = vars(arg_parser.parse_args(argv))
 
-    results = run_benchmark(args["hat_path"], args["store_in_hat"], batch_size=int(args["batch_size"]), min_time_in_sec=int(
-        args["min_time_in_sec"]), input_sets_minimum_size_MB=int(args["input_sets_minimum_size_MB"]))
+    results = run_benchmark(args["hat_path"],
+                            args["store_in_hat"],
+                            batch_size=int(args["batch_size"]),
+                            min_time_in_sec=int(args["min_time_in_sec"]),
+                            input_sets_minimum_size_MB=int(
+                                args["input_sets_minimum_size_MB"]))
    df = pd.DataFrame(results)
    df.to_csv(args["results_file"], index=False)
    pd.options.display.float_format = '{:8.8f}'.format
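The batch statistics computed in run_benchmark above are plain order statistics over per-batch means; as a self-contained sketch with made-up timings:

    import numpy as np

    batch_timings = [1.2, 0.9, 1.1, 5.0, 1.0, 0.95, 1.05, 1.0, 0.98, 1.02]  # seconds per batch (made up)
    batch_size = 10

    sorted_batch_means = np.array(sorted(batch_timings)) / batch_size
    num_batches = len(batch_timings)

    mean_of_means = sorted_batch_means.mean()
    median_of_means = sorted_batch_means[num_batches // 2]
    mean_of_small_means = sorted_batch_means[0:num_batches // 2].mean()
    # drop the fastest and slowest 20% of batches before averaging
    robust_mean = sorted_batch_means[(num_batches // 5):(-num_batches // 5)].mean()
    min_of_means = sorted_batch_means[0]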
hatlib/cuda_loader.py:

@@ -10,19 +10,15 @@ from typing import List
 from pynvrtc.compiler import Program
 from cuda import cuda, nvrtc
 
-try:
-    from .arg_info import ArgInfo, verify_args
-    from .gpu_headers import CUDA_HEADER_MAP
-except:
-    from arg_info import ArgInfo, verify_args
-    from gpu_headers import CUDA_HEADER_MAP
+from .arg_info import ArgInfo, verify_args
+from .gpu_headers import CUDA_HEADER_MAP
+from .hat_file import Function
 
 
 def ASSERT_DRV(err):
     if isinstance(err, cuda.CUresult):
         if err != cuda.CUresult.CUDA_SUCCESS:
-            raise RuntimeError("Cuda Error: {}".format(
-                cuda.cuGetErrorString(err)[1].decode('utf-8')))
+            raise RuntimeError("Cuda Error: {}".format(cuda.cuGetErrorString(err)[1].decode('utf-8')))
     elif isinstance(err, nvrtc.nvrtcResult):
         if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
             raise RuntimeError("Nvrtc Error: {}".format(err))
@@ -51,15 +47,12 @@ def _find_cuda_incl_path() -> pathlib.Path:
 def compile_cuda_program(cuda_src_path: pathlib.Path, func_name):
     src = cuda_src_path.read_text()
 
-    prog = Program(src=src,
-                   name=func_name,
-                   headers=CUDA_HEADER_MAP.values(),
-                   include_names=CUDA_HEADER_MAP.keys())
+    prog = Program(src=src, name=func_name, headers=CUDA_HEADER_MAP.values(), include_names=CUDA_HEADER_MAP.keys())
     ptx = prog.compile([
         '-use_fast_math',
         '-default-device',
         '-std=c++11',
-        '-arch=sm_52', # TODO: is this needed?
+        '-arch=sm_52',  # TODO: is this needed?
     ])
 
     return ptx
@@ -91,27 +84,20 @@ def get_func_from_ptx(ptx, func_name):
 
 
 def _arg_size(arg_info: ArgInfo):
-    return arg_info.element_num_bytes * reduce(lambda x, y: x * y,
-                                               arg_info.numpy_shape)
+    return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
 
 
-def transfer_mem_host_to_cuda(device_args: List, host_args: List[np.array],
-                              arg_infos: List[ArgInfo]):
-    for device_arg, host_arg, arg_info in zip(device_args, host_args,
-                                              arg_infos):
-        if 'input' in arg_info.usage:
-            err, = cuda.cuMemcpyHtoD(device_arg, host_arg.ctypes.data,
-                                     _arg_size(arg_info))
+def transfer_mem_host_to_cuda(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
+    for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
+        if 'input' in arg_info.usage.value:
+            err, = cuda.cuMemcpyHtoD(device_arg, host_arg.ctypes.data, _arg_size(arg_info))
             ASSERT_DRV(err)
 
 
-def transfer_mem_cuda_to_host(device_args: List, host_args: List[np.array],
-                              arg_infos: List[ArgInfo]):
-    for device_arg, host_arg, arg_info in zip(device_args, host_args,
-                                              arg_infos):
-        if 'output' in arg_info.usage:
-            err, = cuda.cuMemcpyDtoH(host_arg.ctypes.data, device_arg,
-                                     _arg_size(arg_info))
+def transfer_mem_cuda_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
+    for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
+        if 'output' in arg_info.usage.value:
+            err, = cuda.cuMemcpyDtoH(host_arg.ctypes.data, device_arg, _arg_size(arg_info))
             ASSERT_DRV(err)
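Note the usage check changed from a substring test on a raw string to a test on `arg_info.usage.value`, since `usage` is now a `hat_file.UsageType` enum. A hedged sketch of why the `.value` accessor matters; the member names and values here are assumptions modeled on the usage strings seen elsewhere in the diff, not the actual hat_file definitions:

    from enum import Enum

    class UsageType(Enum):            # assumed shape of hat_file.UsageType
        Input = "input"
        Output = "output"
        InputOutput = "input_output"

    usage = UsageType.InputOutput
    # The substring test works on the enum's string value, not the enum itself:
    assert 'input' in usage.value
    assert 'output' in usage.value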
@@ -133,10 +119,12 @@ def device_args_to_ptr_list(device_args: List):
     return ptrs
 
 
-def create_loader_for_device_function(device_func, hat_details):
-    hat_path: pathlib.Path = hat_details.path
-    cuda_src_path: pathlib.Path = hat_path.parent / device_func["provider"]
-    func_name = device_func["name"]
+def create_loader_for_device_function(device_func: Function, hat_dir_path: str):
+    if not device_func.provider:
+        raise RuntimeError("Expected a provider for the device function")
+
+    cuda_src_path: pathlib.Path = pathlib.Path(hat_dir_path) / device_func.provider
+    func_name = device_func.name
 
     ptx = compile_cuda_program(cuda_src_path, func_name)
 
@@ -144,16 +132,14 @@ def create_loader_for_device_function(device_func, hat_details):
 
     kernel = get_func_from_ptx(ptx, func_name)
 
-    hat_arg_descriptions = device_func["arguments"]
+    hat_arg_descriptions = device_func.arguments
     arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
-    launch_parameters = device_func["launch_parameters"]
+    launch_parameters = device_func.launch_parameters
 
     def f(*args):
         verify_args(args, arg_infos, func_name)
         device_mem = allocate_cuda_mem(arg_infos)
-        transfer_mem_host_to_cuda(device_args=device_mem,
-                                  host_args=args,
-                                  arg_infos=arg_infos)
+        transfer_mem_host_to_cuda(device_args=device_mem, host_args=args, arg_infos=arg_infos)
         ptrs = device_args_to_ptr_list(device_mem)
 
         err, stream = cuda.cuStreamCreate(0)
@@ -161,18 +147,16 @@ def create_loader_for_device_function(device_func, hat_details):
 
         err, = cuda.cuLaunchKernel(
             kernel,
-            *launch_parameters, # [ grid[x-z], block[x-z] ]
-            0, # dynamic shared memory
-            stream, # stream
-            ptrs.ctypes.data, # kernel arguments
-            0, # extra (ignore)
+            *launch_parameters,  # [ grid[x-z], block[x-z] ]
+            0,  # dynamic shared memory
+            stream,  # stream
+            ptrs.ctypes.data,  # kernel arguments
+            0,  # extra (ignore)
         )
         ASSERT_DRV(err)
         err, = cuda.cuStreamSynchronize(stream)
         ASSERT_DRV(err)
 
-        transfer_mem_cuda_to_host(device_args=device_mem,
-                                  host_args=args,
-                                  arg_infos=arg_infos)
+        transfer_mem_cuda_to_host(device_args=device_mem, host_args=args, arg_infos=arg_infos)
 
     return f
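Per the comment in the kernel launch above, `launch_parameters` holds the grid x/y/z dimensions followed by the block x/y/z dimensions, splatted straight into cuda.cuLaunchKernel. A tiny sketch with hypothetical values:

    # Hypothetical launch_parameters from a HAT device function:
    launch_parameters = [64, 1, 1, 256, 1, 1]   # grid(x,y,z), block(x,y,z)
    grid, block = launch_parameters[:3], launch_parameters[3:]
    total_threads = grid[0] * grid[1] * grid[2] * block[0] * block[1] * block[2]
    assert total_threads == 64 * 256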
hatlib/hat.py:
@@ -24,161 +24,66 @@ For example:
     # call a package function named 'my_func_698b5e5c'
     package.my_func_698b5e5c(A, B, D, E)
 """
+import numpy as np
+from typing import Tuple, Union
+from functools import reduce
 
-import ctypes
 import pathlib
-import sys
-import toml
-from collections import OrderedDict
-from functools import partial
+from . import hat_file
+from . import hat_package
+from .arg_info import ArgInfo
 
-try:
-    from . import hat_file
-    from .arg_info import ArgInfo, verify_args, generate_input_sets
-except:
-    import hat_file
-    from arg_info import ArgInfo, verify_args, generate_input_sets
 
-try:
-    try:
-        from . import cuda_loader
-    except ModuleNotFoundError:
-        import cuda_loader
-except:
-    CUDA_AVAILABLE = False
-else:
-    CUDA_AVAILABLE = True
+def generate_input_sets_for_func(func: hat_file.Function,
+                                 input_sets_minimum_size_MB: int = 0,
+                                 num_additional: int = 0):
+    parameters = list(map(ArgInfo, func.arguments))
+    shapes_to_sizes = [
+        reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters
+    ]
+    set_size = reduce(
+        lambda x, y: x + y,
+        map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes,
+            parameters))
 
-try:
-    try:
-        from . import rocm_loader
-    except ModuleNotFoundError:
-        import rocm_loader
-except:
-    ROCM_AVAILABLE = False
-else:
-    ROCM_AVAILABLE = True
+    num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 //
+                      set_size) + 1 + num_additional
+    input_sets = [[
+        np.random.random(p.numpy_shape).astype(p.numpy_dtype)
+        for p in parameters
+    ] for _ in range(num_input_sets)]
 
-NOTIFY_ABOUT_CUDA = not CUDA_AVAILABLE
-NOTIFY_ABOUT_ROCM = not ROCM_AVAILABLE
+    return input_sets[0] if len(input_sets) == 1 else input_sets
 
 
 def generate_input_sets_for_hat_file(hat_path):
     hat_path = pathlib.Path(hat_path).absolute()
-    t: hat_file.HATFile = toml.load(hat_path)
+    t = hat_file.HATFile.Deserialize(hat_path)
     return {
-        func_name:
-        generate_input_sets(list(map(ArgInfo, func_desc["arguments"])))
-        for func_name, func_desc in t["functions"].items()
+        func_name: generate_input_sets_for_func(func_desc)
+        for func_name, func_desc in t.function_map.items()
     }
 
 
-class AttributeDict(OrderedDict):
-    """ Dictionary that allows entries to be accessed like attributes
-    """
-    __getattr__ = OrderedDict.__getitem__
-
-    @property
-    def names(self):
-        return list(self.keys())
-
-    def __getitem__(self, key):
-        for k, v in self.items():
-            if k.startswith(key):
-                return v
-        return OrderedDict.__getitem__(key)
-
-
-def hat_description_to_python_function(hat_description: hat_file.HATFile,
-                                       hat_details: AttributeDict):
-    """ Creates a callable function based on a function description in a HAT
-    package
+def load(
+    hat_path,
+    try_dynamic_load=True
+) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
     """
-
-    for func_name, func_desc in hat_description["functions"].items():
-
-        func_desc: hat_file.Function
-        func_name: str
-
-        launches = func_desc.get("launches")
-        if not launches:
-            hat_library: ctypes.CDLL = hat_details.shared_lib
-
-            def f(function_name, hat_arg_descriptions, *args):
-                # verify that the (numpy) input args match the description in
-                # the hat file
-                arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
-                verify_args(args, arg_infos, function_name)
-
-                # prepare the args to the hat package
-                hat_args = [
-                    arg.ctypes.data_as(arg_info.ctypes_pointer_type)
-                    for arg, arg_info in zip(args, arg_infos)
-                ]
-
-                # call the function in the hat package
-                hat_library[function_name](*hat_args)
-
-            yield func_name, partial(f, func_desc["name"], func_desc["arguments"])
-
-        else:
-            device_func = hat_description.get("device_functions",
-                                              {}).get(launches)
-
-            func_runtime = func_desc.get("runtime")
-            if not device_func:
-                raise RuntimeError(
-                    f"Couldn't find device function for loader: " + launches)
-            if not func_runtime:
-                raise RuntimeError(f"Couldn't find runtime for loader: " +
-                                   launches)
-            if func_runtime == "CUDA":
-                global NOTIFY_ABOUT_CUDA
-                if CUDA_AVAILABLE:
-                    yield (func_name,
-                           cuda_loader.create_loader_for_device_function(
-                               device_func, hat_details))
-                elif NOTIFY_ABOUT_CUDA:
-                    print("CUDA functionality not available on this machine."
-                          " Please install the cuda and pvnrtc python modules")
-                    NOTIFY_ABOUT_CUDA = False
-            elif func_runtime == "ROCM":
-                global NOTIFY_ABOUT_ROCM
-                if ROCM_AVAILABLE:
-                    yield (func_name,
-                           rocm_loader.create_loader_for_device_function(
-                               device_func, hat_details))
-                elif NOTIFY_ABOUT_ROCM:
-                    print("ROCm functionality not available on this machine."
-                          " Please install the ROCm 4.2 or higher")
-                    NOTIFY_ABOUT_ROCM = False
-
-
-def load(hat_path):
-    """ Creates a class with static functions based on the function
-    descriptions in a HAT package
+    Returns a HATPackage object loaded from the path provided. If
+    `try_dynamic_load` is True, a non-empty dictionary object that can be used
+    to invoke the functions in the HATPackage on the current system is the
+    second returned object, `None` otherwise.
     """
-    # load the function decscriptions from the hat file
-    hat_path = pathlib.Path(hat_path).absolute()
-    t: hat_file.HATFile = toml.load(hat_path)
-    hat_details = AttributeDict({"path": hat_path})
 
-    # function_descriptions = t["functions"]
-    hat_binary_filename = t["dependencies"]["link_target"]
-    hat_binary_path = hat_path.parent / hat_binary_filename
+    pkg = hat_package.HATPackage(hat_file_path=hat_path)
 
-    # check that the HAT library has a supported file extension
-    supported_extensions = [".dll", ".so"]
-    extension = hat_binary_path.suffix
-    if extension and extension not in supported_extensions:
-        sys.exit(f"Unsupported HAT library extension: {extension}")
+    # TODO: Add heuristics to determine whether loading is possible on this system
+    function_dict = None
 
-    # load the hat_library:
-    hat_library = ctypes.cdll.LoadLibrary(
-        str(hat_binary_path)) if extension else None
-    hat_details["shared_lib"] = hat_library
+    if try_dynamic_load:
+        try:
+            function_dict = hat_package.hat_package_to_func_dict(pkg)
+        except:
+            # TODO: Figure out how to communicate failure better
+            pass
 
-    # create dictionary of functions defined in the hat file
-    function_dict = AttributeDict(
-        dict(hat_description_to_python_function(t, hat_details)))
-    return function_dict
+    return pkg, function_dict
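A hedged usage sketch of the rewritten hat.py surface (path hypothetical; with the default arguments, generate_input_sets_for_func produces a single input set per function):

    from hatlib.hat import load, generate_input_sets_for_hat_file

    pkg, funcs = load("my_package.hat", try_dynamic_load=True)
    input_sets = generate_input_sets_for_hat_file("my_package.hat")

    if funcs is not None:                     # dynamic load succeeded
        for name, args in input_sets.items():
            funcs[name](*args)                # one random argument list per function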
hatlib/hat_file.py:

@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 
 # Utility to parse the TOML metadata from HAT files
-from enum import Enum
-from dataclasses import dataclass, field
 import os
-from pathlib import Path
-
 import tomlkit
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List
 
 # TODO : type-checking on leaf node values
@@ -105,7 +105,8 @@ class Description(AuxiliarySupportedTable):
             author=table["author"],
             version=table["version"],
             license_url=table["license_url"],
-            auxiliary=AuxiliarySupportedTable.parse_auxiliary(table))
+            auxiliary=AuxiliarySupportedTable.parse_auxiliary(table)
+        )
 
 
 @dataclass
@@ -147,14 +148,9 @@ class Parameter:
     # TODO : change "usage" to "role" in schema
     @staticmethod
     def parse_from_table(param_table):
-        required_table_entries = [
-            "name", "description", "logical_type", "declared_type",
-            "element_type", "usage"
-        ]
+        required_table_entries = ["name", "description", "logical_type", "declared_type", "element_type", "usage"]
         _check_required_table_entries(param_table, required_table_entries)
-        affine_array_required_table_entries = [
-            "shape", "affine_map", "affine_offset"
-        ]
+        affine_array_required_table_entries = ["shape", "affine_map", "affine_offset"]
         runtime_array_required_table_entries = ["size"]
 
         name = param_table["name"]
@@ -164,23 +160,23 @@ class Parameter:
         element_type = param_table["element_type"]
         usage = UsageType(param_table["usage"])
 
-        param = Parameter(name=name,
-                          description=description,
-                          logical_type=logical_type,
-                          declared_type=declared_type,
-                          element_type=element_type,
-                          usage=usage)
+        param = Parameter(
+            name=name,
+            description=description,
+            logical_type=logical_type,
+            declared_type=declared_type,
+            element_type=element_type,
+            usage=usage
+        )
 
         if logical_type == ParameterType.AffineArray:
-            _check_required_table_entries(param_table,
-                                          affine_array_required_table_entries)
+            _check_required_table_entries(param_table, affine_array_required_table_entries)
             param.shape = param_table["shape"]
             param.affine_map = param_table["affine_map"]
             param.affine_offset = param_table["affine_offset"]
 
         elif logical_type == ParameterType.RuntimeArray:
-            _check_required_table_entries(
-                param_table, runtime_array_required_table_entries)
+            _check_required_table_entries(param_table, runtime_array_required_table_entries)
             param.size = param_table["size"]
 
         return param
@@ -189,7 +185,7 @@ class Parameter:
 @dataclass
 class Function(AuxiliarySupportedTable):
     # required
-    arguments: list = field(default_factory=list)
+    arguments: List[Parameter] = field(default_factory=list)
     calling_convention: CallingConventionType = None
     description: str = ""
     hat_file: any = None
@@ -214,7 +210,7 @@ class Function(AuxiliarySupportedTable):
             arg_array.append(arg_table)
         table.add(
             "arguments", arg_array
-        ) # TODO : figure out why this isn't indenting after serialization in some cases
+        )  # TODO : figure out why this isn't indenting after serialization in some cases
 
         if self.launch_parameters:
             table.add("launch_parameters", self.launch_parameters)
@@ -236,44 +232,36 @@ class Function(AuxiliarySupportedTable):
 
     @staticmethod
     def parse_from_table(function_table):
-        required_table_entries = [
-            "name", "description", "calling_convention", "arguments", "return"
-        ]
+        required_table_entries = ["name", "description", "calling_convention", "arguments", "return"]
         _check_required_table_entries(function_table, required_table_entries)
-        arguments = [
-            Parameter.parse_from_table(param_table)
-            for param_table in function_table["arguments"]
-        ]
+        arguments = [Parameter.parse_from_table(param_table) for param_table in function_table["arguments"]]
 
-        launch_parameters = function_table[
-            "launch_parameters"] if "launch_parameters" in function_table else []
+        launch_parameters = function_table["launch_parameters"] if "launch_parameters" in function_table else []
 
-        launches = function_table[
-            "launches"] if "launches" in function_table else ""
+        launches = function_table["launches"] if "launches" in function_table else ""
 
-        provider = function_table[
-            "provider"] if "provider" in function_table else ""
+        provider = function_table["provider"] if "provider" in function_table else ""
 
-        runtime = function_table[
-            "runtime"] if "runtime" in function_table else ""
+        runtime = function_table["runtime"] if "runtime" in function_table else ""
 
         return_info = Parameter.parse_from_table(function_table["return"])
 
         return Function(
             name=function_table["name"],
             description=function_table["description"],
-            calling_convention=CallingConventionType(
-                function_table["calling_convention"]),
+            calling_convention=CallingConventionType(function_table["calling_convention"]),
             arguments=arguments,
             return_info=return_info,
             launch_parameters=launch_parameters,
             launches=launches,
             provider=provider,
             runtime=runtime,
-            auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table))
+            auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table)
+        )
 
 
 class FunctionTableCommon:
 
     def __init__(self, function_map):
         self.function_map = function_map
         self.functions = self.function_map.values()
@@ -281,15 +269,13 @@ class FunctionTableCommon:
     def to_table(self):
         func_table = tomlkit.table()
         for function_key in self.function_map:
-            func_table.add(function_key,
-                           self.function_map[function_key].to_table())
+            func_table.add(function_key, self.function_map[function_key].to_table())
         return func_table
 
     @classmethod
     def parse_from_table(cls, all_functions_table):
         function_map = {
-            function_key:
-            Function.parse_from_table(all_functions_table[function_key])
+            function_key: Function.parse_from_table(all_functions_table[function_key])
             for function_key in all_functions_table
         }
         return cls(function_map)
@@ -305,8 +291,10 @@ class DeviceFunctionTable(FunctionTableCommon):
 
 @dataclass
 class Target:
+
     @dataclass
     class Required:
+
         @dataclass
         class CPU:
             TableName = TargetType.CPU.value
@@ -336,9 +324,8 @@ class Target:
                 runtime = table.get("runtime", "")
 
                 return Target.Required.CPU(
-                    architecture=table["architecture"],
-                    extensions=table["extensions"],
-                    runtime=runtime)
+                    architecture=table["architecture"], extensions=table["extensions"], runtime=runtime
+                )
 
         @dataclass
         class GPU:
@@ -357,8 +344,7 @@ class Target:
                 table.add("model", self.model)
                 table.add("runtime", self.runtime)
                 table.add("blocks", self.blocks)
-                table.add("instruction_set_version",
-                          self.instruction_set_version)
+                table.add("instruction_set_version", self.instruction_set_version)
                 table.add("min_threads", self.min_threads)
                 table.add("min_global_memory_KB", self.min_global_memory_KB)
                 table.add("min_shared_memory_KB", self.min_shared_memory_KB)
@@ -382,7 +368,8 @@ class Target:
                     min_threads=table["min_threads"],
                     min_global_memory_KB=table["min_global_memory_KB"],
                     min_shared_memory_KB=table["min_shared_memory_KB"],
-                    min_texture_memory_KB=table["min_texture_memory_KB"])
+                    min_texture_memory_KB=table["min_texture_memory_KB"]
+                )
 
         TableName = "required"
         os: OperatingSystem = None
@@ -401,11 +388,9 @@ class Target:
         def parse_from_table(table):
             required_table_entries = ["os", Target.Required.CPU.TableName]
             _check_required_table_entries(table, required_table_entries)
-            cpu_info = Target.Required.CPU.parse_from_table(
-                table[Target.Required.CPU.TableName])
+            cpu_info = Target.Required.CPU.parse_from_table(table[Target.Required.CPU.TableName])
             if Target.Required.GPU.TableName in table:
-                gpu_info = Target.Required.GPU.parse_from_table(
-                    table[Target.Required.GPU.TableName])
+                gpu_info = Target.Required.GPU.parse_from_table(table[Target.Required.GPU.TableName])
             else:
                 gpu_info = Target.Required.GPU()
             return Target.Required(os=table["os"], cpu=cpu_info, gpu=gpu_info)
@@ -429,19 +414,16 @@ class Target:
         table = tomlkit.table()
         table.add(Target.Required.TableName, self.required.to_table())
         if self.optimized_for is not None:
-            table.add(Target.OptimizedFor.TableName,
-                      self.optimized_for.to_table())
+            table.add(Target.OptimizedFor.TableName, self.optimized_for.to_table())
         return table
 
     @staticmethod
     def parse_from_table(target_table):
         required_table_entries = [Target.Required.TableName]
         _check_required_table_entries(target_table, required_table_entries)
-        required_data = Target.Required.parse_from_table(
-            target_table[Target.Required.TableName])
+        required_data = Target.Required.parse_from_table(target_table[Target.Required.TableName])
         if Target.OptimizedFor.TableName in target_table:
-            optimized_for_data = Target.OptimizedFor.parse_from_table(
-                target_table[Target.OptimizedFor.TableName])
+            optimized_for_data = Target.OptimizedFor.parse_from_table(target_table[Target.OptimizedFor.TableName])
         else:
             optimized_for_data = Target.OptimizedFor()
         return Target(required=required_data, optimized_for=optimized_for_data)
@@ -462,9 +444,7 @@ class LibraryReference:
 
     @staticmethod
     def parse_from_table(table):
-        return LibraryReference(name=table["name"],
-                                version=table["version"],
-                                target_file=table["target_file"])
+        return LibraryReference(name=table["name"], version=table["version"], target_file=table["target_file"])
 
 
 @dataclass
@@ -490,17 +470,14 @@ class Dependencies(AuxiliarySupportedTable):
     @staticmethod
     def parse_from_table(dependencies_table):
         required_table_entries = ["link_target", "deploy_files", "dynamic"]
-        _check_required_table_entries(dependencies_table,
-                                      required_table_entries)
-        dynamic = [
-            LibraryReference.parse_from_table(lib_ref_table)
-            for lib_ref_table in dependencies_table["dynamic"]
-        ]
-        return Dependencies(link_target=dependencies_table["link_target"],
-                            deploy_files=dependencies_table["deploy_files"],
-                            dynamic=dynamic,
-                            auxiliary=AuxiliarySupportedTable.parse_auxiliary(
-                                dependencies_table))
+        _check_required_table_entries(dependencies_table, required_table_entries)
+        dynamic = [LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in dependencies_table["dynamic"]]
+        return Dependencies(
+            link_target=dependencies_table["link_target"],
+            deploy_files=dependencies_table["deploy_files"],
+            dynamic=dynamic,
+            auxiliary=AuxiliarySupportedTable.parse_auxiliary(dependencies_table)
+        )
 
 
 @dataclass
@@ -527,16 +504,16 @@ class CompiledWith:
     @staticmethod
     def parse_from_table(compiled_with_table):
         required_table_entries = ["compiler", "flags", "crt", "libraries"]
-        _check_required_table_entries(compiled_with_table,
-                                      required_table_entries)
+        _check_required_table_entries(compiled_with_table, required_table_entries)
         libraries = [
-            LibraryReference.parse_from_table(lib_ref_table)
-            for lib_ref_table in compiled_with_table["libraries"]
+            LibraryReference.parse_from_table(lib_ref_table) for lib_ref_table in compiled_with_table["libraries"]
         ]
-        return CompiledWith(compiler=compiled_with_table["compiler"],
-                            flags=compiled_with_table["flags"],
-                            crt=compiled_with_table["crt"],
-                            libraries=libraries)
+        return CompiledWith(
+            compiler=compiled_with_table["compiler"],
+            flags=compiled_with_table["flags"],
+            crt=compiled_with_table["crt"],
+            libraries=libraries
+        )
 
 
 @dataclass
@@ -552,8 +529,7 @@ class Declaration:
     @staticmethod
     def parse_from_table(declaration_table):
         required_table_entries = ["code"]
-        _check_required_table_entries(declaration_table,
-                                      required_table_entries)
+        _check_required_table_entries(declaration_table, required_table_entries)
         return Declaration(code=declaration_table["code"])
 
 
@@ -573,8 +549,8 @@ class HATFile:
     _device_function_table: DeviceFunctionTable = None
     functions: list = field(default_factory=list)
     device_functions: list = field(default_factory=list)
-    function_map: dict = field(default_factory=dict)
-    device_function_map: list = field(default_factory=list)
+    function_map: Dict[str, Function] = field(default_factory=dict)
+    device_function_map: Dict[str, Function] = field(default_factory=dict)
     target: Target = None
     dependencies: Dependencies = None
     compiled_with: CompiledWith = None
@@ -589,8 +565,7 @@ class HATFile:
         self.function_map = self._function_table.function_map
         for func in self.functions:
             func.hat_file = self
-            func.link_target = Path(self.path).resolve(
-            ).parent / self.dependencies.link_target
+            func.link_target = Path(self.path).resolve().parent / self.dependencies.link_target
 
         if not self._device_function_table:
             self._device_function_table = DeviceFunctionTable({})
@@ -604,11 +579,9 @@ class HATFile:
         filepath = self.path
         root_table = tomlkit.table()
         root_table.add(Description.TableName, self.description.to_table())
-        root_table.add(FunctionTable.TableName,
-                       self._function_table.to_table())
+        root_table.add(FunctionTable.TableName, self._function_table.to_table())
         if self.device_function_map:
-            root_table.add(DeviceFunctionTable.TableName,
-                           self._device_function_table.to_table())
+            root_table.add(DeviceFunctionTable.TableName, self._device_function_table.to_table())
         root_table.add(Target.TableName, self.target.to_table())
         root_table.add(Dependencies.TableName, self.dependencies.to_table())
         root_table.add(CompiledWith.TableName, self.compiled_with.to_table())
@@ -626,29 +599,22 @@ class HATFile:
         hat_toml = _read_toml_file(filepath)
         name = os.path.splitext(os.path.basename(filepath))[0]
         required_entries = [
-            Description.TableName, FunctionTable.TableName, Target.TableName,
-            Dependencies.TableName, CompiledWith.TableName,
-            Declaration.TableName
+            Description.TableName, FunctionTable.TableName, Target.TableName, Dependencies.TableName,
+            CompiledWith.TableName, Declaration.TableName
         ]
         _check_required_table_entries(hat_toml, required_entries)
         device_function_table = None
         if DeviceFunctionTable.TableName in hat_toml:
-            device_function_table = DeviceFunctionTable.parse_from_table(
-                hat_toml[DeviceFunctionTable.TableName])
+            device_function_table = DeviceFunctionTable.parse_from_table(hat_toml[DeviceFunctionTable.TableName])
         hat_file = HATFile(
             name=name,
-            description=Description.parse_from_table(
-                hat_toml[Description.TableName]),
-            _function_table=FunctionTable.parse_from_table(
-                hat_toml[FunctionTable.TableName]),
+            description=Description.parse_from_table(hat_toml[Description.TableName]),
+            _function_table=FunctionTable.parse_from_table(hat_toml[FunctionTable.TableName]),
             _device_function_table=device_function_table,
-            target=Target.parse_from_table(
-                hat_toml[Target.TableName]),
-            dependencies=Dependencies.parse_from_table(
-                hat_toml[Dependencies.TableName]),
-            compiled_with=CompiledWith.parse_from_table(
-                hat_toml[CompiledWith.TableName]),
-            declaration=Declaration.parse_from_table(
-                hat_toml[Declaration.TableName]),
-            path=Path(filepath).resolve())
+            target=Target.parse_from_table(hat_toml[Target.TableName]),
+            dependencies=Dependencies.parse_from_table(hat_toml[Dependencies.TableName]),
+            compiled_with=CompiledWith.parse_from_table(hat_toml[CompiledWith.TableName]),
+            declaration=Declaration.parse_from_table(hat_toml[Declaration.TableName]),
+            path=Path(filepath).resolve()
+        )
        return hat_file
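As a hedged sketch of the object model the rest of this commit leans on (path hypothetical):

    from hatlib.hat_file import HATFile

    hat_file = HATFile.Deserialize("my_package.hat")   # parse TOML into the OM
    for name, func in hat_file.function_map.items():   # Dict[str, Function]
        arg_types = [p.declared_type for p in func.arguments]
        print(name, arg_types, func.calling_convention)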
hatlib/hat_package.py:

@@ -2,11 +2,19 @@
 
 # Utility to parse and validate a HAT package
 
-from .hat_file import HATFile
+import ctypes
+from typing import List, Union
+from collections import OrderedDict
+from functools import partial
+
+from .hat_file import HATFile, Function, Parameter
+from .arg_info import ArgInfo, verify_args
 
 import os
 
 
 class HATPackage:
+
     def __init__(self, hat_file_path):
         """A HAT Package is defined to be a HAT file and corresponding binary file, located in the same directory.
         The binary file is specified in the HAT file's link_target attribute.
@@ -18,16 +26,18 @@ class HATPackage:
         self.hat_file = HATFile.Deserialize(hat_file_path)
 
         self.link_target = self.hat_file.dependencies.link_target
-        self.link_target_path = os.path.join(os.path.split(self.hat_file_path)[0], self.hat_file.dependencies.link_target)
-        if not os.path.isfile(self.link_target_path):
-            raise ValueError(f"HAT file {self.hat_file_path} references link_target {self.hat_file.dependencies.link_target} which is not found in same directory as HAT file (expecting it to be in {os.path.split(self.hat_file_path)[0]}")
+        self.link_target_path = os.path.join(
+            os.path.split(self.hat_file_path)[0], self.hat_file.dependencies.link_target
+        )
 
         self.functions = self.hat_file.functions
 
     def get_functions(self):
         return self.hat_file.functions
 
-    def get_functions_for_target(self, os: str, arch: str, required_extensions:list = []):
+    def get_functions_for_target(self, os: str, arch: str, required_extensions: list = []):
         all_functions = self.get_functions()
 
         def matches_target(hat_function):
             hat_file = hat_function.hat_file
             if hat_file.target.required.os != os or hat_file.target.required.cpu.architecture != arch:
@@ -38,3 +48,140 @@ class HATPackage:
                 return True
 
         return list(filter(matches_target, all_functions))
+
+
+class AttributeDict(OrderedDict):
+    """ Dictionary that allows entries to be accessed like attributes
+    """
+    __getattr__ = OrderedDict.__getitem__
+
+    @property
+    def names(self):
+        return list(self.keys())
+
+    def __getitem__(self, key):
+        for k, v in self.items():
+            if k.startswith(key):
+                return v
+        return OrderedDict.__getitem__(key)
+
+
+def _make_cpu_func(shared_lib: ctypes.CDLL, function_name: str, arg_infos: List[Parameter]):
+    arg_infos = [ArgInfo(d) for d in arg_infos]
+    fn = shared_lib[function_name]
+
+    def f(*args):
+        # verify that the (numpy) input args match the description in
+        # the hat file
+        verify_args(args, arg_infos, function_name)
+
+        # prepare the args to the hat package
+        hat_args = [arg.ctypes.data_as(arg_info.ctypes_pointer_type) for arg, arg_info in zip(args, arg_infos)]
+
+        # call the function in the hat package
+        fn(*hat_args)
+
+    return f
+
+
+def _make_device_func(func_runtime: str, hat_dir_path: str, func: Function):
+    if func_runtime == "CUDA":
+        from . import cuda_loader
+        return cuda_loader.create_loader_for_device_function(func, hat_dir_path)
+    elif func_runtime == "ROCM":
+        from . import rocm_loader
+        return rocm_loader.create_loader_for_device_function(func, hat_dir_path)
+
+
+def _load_pkg_binary_module(hat_pkg: HATPackage):
+    shared_lib = None
+    if os.path.isfile(hat_pkg.link_target_path):
+
+        supported_extensions = [".dll", ".so"]
+        _, extension = os.path.splitext(hat_pkg.link_target_path)
+
+        if extension and extension not in supported_extensions:
+            # TODO: Should this be an error? Maybe just move on to the
+            # device function section?
+            raise RuntimeError(f"Unsupported HAT library extension: {extension}")
+
+        hat_binary_path = os.path.abspath(hat_pkg.link_target_path)
+
+        # load the hat_library:
+        hat_library = ctypes.cdll.LoadLibrary(hat_binary_path) if extension else None
+        shared_lib = hat_library
+
+    return shared_lib
+
+
+def hat_package_to_func_dict(hat_pkg: HATPackage) -> AttributeDict:
+
+    try:
+        try:
+            from . import cuda_loader
+        except ModuleNotFoundError:
+            import cuda_loader
+    except:
+        CUDA_AVAILABLE = False
+    else:
+        CUDA_AVAILABLE = True
+
+    try:
+        try:
+            from . import rocm_loader
+        except ModuleNotFoundError:
+            import rocm_loader
+    except:
+        ROCM_AVAILABLE = False
+    else:
+        ROCM_AVAILABLE = True
+
+    NOTIFY_ABOUT_CUDA = not CUDA_AVAILABLE
+    NOTIFY_ABOUT_ROCM = not ROCM_AVAILABLE
+
+    # check that the HAT library has a supported file extension
+    func_dict = AttributeDict()
+    shared_lib = _load_pkg_binary_module(hat_pkg)
+    hat_dir_path, _ = os.path.split(hat_pkg.hat_file_path)
+
+    for func_name, func_desc in hat_pkg.hat_file.function_map.items():
+
+        launches = func_desc.launches
+        if not launches and shared_lib:
+            func_dict[func_name] = _make_cpu_func(shared_lib, func_desc.name, func_desc.arguments)
+        else:
+            device_func = hat_pkg.hat_file.device_function_map.get(launches)
+
+            func_runtime = func_desc.runtime
+            if not device_func:
+                raise RuntimeError(f"Couldn't find device function for loader: " + launches)
+            if not func_runtime:
+                raise RuntimeError(f"Couldn't find runtime for loader: " + launches)
+
+            # TODO: Generalize this concept to work so it's not CUDA/ROCM specific
+            if func_runtime == "CUDA" and not CUDA_AVAILABLE:
+
+                # TODO: printing to stdout only makes sense in tool mode
+                if NOTIFY_ABOUT_CUDA:
+                    print(
+                        "CUDA functionality not available on this machine. Please install the cuda and pvnrtc python modules"
+                    )
+                    NOTIFY_ABOUT_CUDA = False
+
+                continue
+
+            elif func_runtime == "ROCM" and not ROCM_AVAILABLE:
+
+                # TODO: printing to stdout only makes sense in tool mode
+                if NOTIFY_ABOUT_ROCM:
+                    print("ROCm functionality not available on this machine. Please install ROCm 4.2 or higher")
+                    NOTIFY_ABOUT_ROCM = False
+
+                continue
+
+            func_dict[func_name] = _make_device_func(
+                func_runtime=func_runtime, hat_dir_path=hat_dir_path, func=device_func
+            )
+
+    return func_dict
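The AttributeDict added above supports both attribute-style access and prefix matching on keys; a small self-contained sketch (note, as a quirk of the original code, that the fallback `OrderedDict.__getitem__(key)` omits `self`, so a true miss raises TypeError rather than KeyError):

    from collections import OrderedDict

    class AttributeDict(OrderedDict):          # copy of the class above
        __getattr__ = OrderedDict.__getitem__

        @property
        def names(self):
            return list(self.keys())

        def __getitem__(self, key):
            for k, v in self.items():
                if k.startswith(key):
                    return v
            return OrderedDict.__getitem__(key)

    funcs = AttributeDict({"my_func_698b5e5c": print})
    assert funcs.names == ["my_func_698b5e5c"]
    funcs.my_func_698b5e5c("attribute access")   # via __getattr__
    funcs["my_func"]("prefix lookup works too")  # startswith match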
hatlib/hat_to_dynamic.py:

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-
 """Converts a statically-linked HAT package into a Dynamically-linked HAT package
 
 HAT packages come in two varieties: statically-linked and dynamically-linked. A statically-linked
@@ -24,21 +23,23 @@ import argparse
 import shutil
 from secrets import token_hex
 
-if __package__:
-    from .hat_file import HATFile, OperatingSystem
-    from .platform_utilities import get_platform, ensure_compiler_in_path, run_command
-else:
-    from hat_file import HATFile, OperatingSystem
-    from platform_utilities import get_platform, ensure_compiler_in_path, run_command
+from .hat_file import HATFile, OperatingSystem
+from .hat_package import HATPackage
+from .platform_utilities import get_platform, ensure_compiler_in_path, run_command
 
 
-def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=True):
+def linux_create_dynamic_package(input_hat_path,
+                                 input_hat_binary_path,
+                                 output_hat_path,
+                                 hat_file,
+                                 quiet=True):
     """Creates a dynamic HAT (.so) from a static HAT (.o/.a) on a Linux/macOS platform"""
     # Confirm that this is a static hat library
     _, extension = os.path.splitext(input_hat_binary_path)
     if extension not in [".o", ".a"]:
         sys.exit(
-            f"ERROR: Expected input library to have extension .o or .a, but received {input_hat_binary_path} instead")
+            f"ERROR: Expected input library to have extension .o or .a, but received {input_hat_binary_path} instead"
+        )
 
     # Create a C source file to resolve inline functions defined in the static HAT package
     include_path = os.path.dirname(input_hat_binary_path)
@@ -48,7 +49,8 @@ def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_h
         f.write(f"#include <{os.path.basename(input_hat_path)}>")
     # compile it separately so that we can suppress the warnings about the missing terminating ' character
     run_command(
-        f'gcc -c -w -fPIC -o "{inline_obj_path}" -I"{include_path}" "{inline_c_path}"', quiet=quiet)
+        f'gcc -c -w -fPIC -o "{inline_obj_path}" -I"{include_path}" "{inline_c_path}"',
+        quiet=quiet)
 
     # create new HAT binary
     prefix, _ = os.path.splitext(output_hat_path)
@@ -58,7 +60,8 @@ def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_h
     libraries = " ".join(
         [d.target_file for d in hat_file.dependencies.dynamic])
     run_command(
-        f'gcc -shared -fPIC -o "{output_hat_binary_path}" "{inline_obj_path}" "{input_hat_binary_path}" {libraries}', quiet=quiet)
+        f'gcc -shared -fPIC -o "{output_hat_binary_path}" "{inline_obj_path}" "{input_hat_binary_path}" {libraries}',
+        quiet=quiet)
 
     # create new HAT file
     # previous dependencies are now part of the binary
@@ -67,15 +70,22 @@ def linux_create_dynamic_package(input_hat_path, input_hat_binary_path, output_h
                               output_hat_binary_path)
     hat_file.Serialize(output_hat_path)
 
+    return HATPackage(output_hat_path)
 
-def windows_create_dynamic_package(input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=True):
+
+def windows_create_dynamic_package(input_hat_path,
+                                   input_hat_binary_path,
+                                   output_hat_path,
+                                   hat_file,
+                                   quiet=True):
     """Creates a Windows dynamic HAT package (.dll) from a static HAT package (.obj/.lib)"""
 
     # Confirm that this is a static hat library
     _, extension = os.path.splitext(input_hat_binary_path)
     if extension not in [".obj", ".lib"]:
         sys.exit(
-            f"ERROR: Expected input library to have extension .obj or .lib, but received {input_hat_binary_path} instead")
+            f"ERROR: Expected input library to have extension .obj or .lib, but received {input_hat_binary_path} instead"
+        )
 
     # Create all file in a directory named build
     if not os.path.exists("build"):

@@ -91,9 +101,11 @@ def windows_create_dynamic_package(input_hat_path, input_hat_binary_path, output
        # Resolve inline functions defined in the static HAT package
        f.write("#include <{}>\n".format(os.path.basename(input_hat_path)))
        f.write(
            "BOOL APIENTRY DllMain(HMODULE, DWORD, LPVOID) { return TRUE; }\n")
            "BOOL APIENTRY DllMain(HMODULE, DWORD, LPVOID) { return TRUE; }\n"
        )
    run_command(
        f'cl.exe /nologo /I"{os.path.dirname(input_hat_path)}" /Fodllmain.obj /c dllmain.cpp', quiet=quiet)
        f'cl.exe /nologo /I"{os.path.dirname(input_hat_path)}" /Fodllmain.obj /c dllmain.cpp',
        quiet=quiet)

    # create the new HAT binary dll
    # always create a new dll (avoids case where dll is already loaded)

@@ -121,36 +133,54 @@ def windows_create_dynamic_package(input_hat_path, input_hat_binary_path, output
    finally:
        os.chdir(cwd)    # restore the current working directory

    return HATPackage(output_hat_path)


def parse_args():
    """Parses and checks the command line arguments"""
    parser = argparse.ArgumentParser(description="Creates a dynamically-linked HAT package from a statically-linked HAT package.\n"
                                     "Example:\n"
                                     " hatlib.hat_to_dynamic input.hat output.hat\n")
    parser = argparse.ArgumentParser(
        description=
        "Creates a dynamically-linked HAT package from a statically-linked HAT package.\n"
        "Example:\n"
        " hatlib.hat_to_dynamic input.hat output.hat\n")

    parser.add_argument("input_hat_path", type=str,
                        help="Path to the existing HAT file, which represents a statically-linked HAT package")
    parser.add_argument("output_hat_path", type=str,
                        help="Path to the new HAT file, which will represent a dynamically-linked HAT package")
    parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose output")
    parser.add_argument(
        "input_hat_path",
        type=str,
        help=
        "Path to the existing HAT file, which represents a statically-linked HAT package"
    )
    parser.add_argument(
        "output_hat_path",
        type=str,
        help=
        "Path to the new HAT file, which will represent a dynamically-linked HAT package"
    )
    parser.add_argument('-v',
                        "--verbose",
                        action='store_true',
                        help="Enable verbose output")
    args = parser.parse_args()

    # check args
    if not os.path.exists(args.input_hat_path):
        sys.exit(f"ERROR: File {args.input_hat_path} not found")

    if os.path.abspath(args.input_hat_path) == os.path.abspath(args.output_hat_path):
    if os.path.abspath(args.input_hat_path) == os.path.abspath(
            args.output_hat_path):
        sys.exit("ERROR: Output file must be different from input file")

    _, extension = os.path.splitext(args.input_hat_path)
    if extension != ".hat":
        sys.exit(
            f"ERROR: Expected input file to have extension .hat, but received {extension} instead")
            f"ERROR: Expected input file to have extension .hat, but received {extension} instead"
        )

    _, extension = os.path.splitext(args.output_hat_path)
    if extension != ".hat":
        sys.exit(
            f"ERROR: Expected output file to have extension .hat, but received {extension} instead")
            f"ERROR: Expected output file to have extension .hat, but received {extension} instead"
        )

    return args

@@ -165,22 +195,30 @@ def create_dynamic_package(input_hat_path, output_hat_path, quiet=True):

    # get the absolute path to the input binary
    input_hat_binary_filename = hat_file.dependencies.link_target
    input_hat_binary_path = os.path.join(
        os.path.dirname(input_hat_path), input_hat_binary_filename)
    input_hat_binary_path = os.path.join(os.path.dirname(input_hat_path),
                                         input_hat_binary_filename)

    # create the dynamic package
    output_hat_path = os.path.abspath(output_hat_path)
    if platform == OperatingSystem.Windows:
        windows_create_dynamic_package(
            input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=quiet)
        windows_create_dynamic_package(input_hat_path,
                                       input_hat_binary_path,
                                       output_hat_path,
                                       hat_file,
                                       quiet=quiet)
    elif platform in [OperatingSystem.Linux, OperatingSystem.MacOS]:
        linux_create_dynamic_package(
            input_hat_path, input_hat_binary_path, output_hat_path, hat_file, quiet=quiet)
        linux_create_dynamic_package(input_hat_path,
                                     input_hat_binary_path,
                                     output_hat_path,
                                     hat_file,
                                     quiet=quiet)


def main():
    args = parse_args()
    create_dynamic_package(args.input_hat_path, args.output_hat_path, quiet=not args.verbose)
    create_dynamic_package(args.input_hat_path,
                           args.output_hat_path,
                           quiet=not args.verbose)


if __name__ == "__main__":
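
Taken together, a hedged sketch of driving the conversion above programmatically rather than through the hatlib.hat_to_dynamic console script (hatlib installed; the file names are illustrative):

    import hatlib

    # Convert a statically-linked package (input.hat plus a .o/.a or .obj/.lib
    # binary) into a dynamically-linked one; paths are made up for illustration.
    hatlib.create_dynamic_package("input.hat", "output.hat", quiet=False)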

@@ -23,12 +23,8 @@ import os
import argparse
import shutil

if __package__:
    from .hat_file import HATFile, OperatingSystem
    from .platform_utilities import get_platform, ensure_compiler_in_path, run_command
else:
    from hat_file import HATFile, OperatingSystem
    from platform_utilities import get_platform, ensure_compiler_in_path, run_command
from .hat_file import HATFile, OperatingSystem
from .platform_utilities import get_platform, ensure_compiler_in_path, run_command


def linux_create_static_package(input_hat_binary_path, output_hat_path, hat_file, quiet=True):

@@ -8,10 +8,7 @@ import shutil
import subprocess
import sys

if __package__:
    from .hat_file import OperatingSystem
else:
    from hat_file import OperatingSystem
from .hat_file import OperatingSystem


def _preprocess_command(command_to_run, shell):

@@ -33,13 +30,8 @@ def _dump_file_contents(iostream):


def run_command(
        command_to_run,
        working_directory=None,
        stdout=None,
        stderr=None,
        shell=False,
        pretend=False,
        quiet=True):
    command_to_run, working_directory=None, stdout=None, stderr=None, shell=False, pretend=False, quiet=True
):
    if not working_directory:
        working_directory = os.getcwd()

@@ -50,20 +42,14 @@ def run_command(
    command_to_run = _preprocess_command(command_to_run, shell)

    if not pretend:
        with subprocess.Popen(
                command_to_run,
                close_fds=(platform.system() != "Windows"),
                shell=shell,
                stdout=stdout,
                stderr=stderr,
                cwd=working_directory) as proc:
        with subprocess.Popen(command_to_run, close_fds=(platform.system() != "Windows"), shell=shell, stdout=stdout,
                              stderr=stderr, cwd=working_directory) as proc:

            proc.wait()
            if proc.returncode:
                _dump_file_contents(stderr)
                _dump_file_contents(stdout)
                raise subprocess.CalledProcessError(
                    proc.returncode, command_to_run)
                raise subprocess.CalledProcessError(proc.returncode, command_to_run)


def get_platform():
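
For reference, a minimal sketch of calling the compacted run_command helper above (the command string is illustrative; a non-zero exit code raises subprocess.CalledProcessError, matching the error path in this hunk):

    from hatlib.platform_utilities import run_command

    # Prints the compiler banner; pretend=True would skip execution entirely.
    run_command('gcc --version', quiet=False)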

@@ -84,8 +70,7 @@ def linux_ensure_compiler_in_path():
    Prompts the user if not found."""
    compiler = os.environ.get("CXX") or (os.environ.get("CC") or "gcc")
    if not shutil.which(compiler):
        sys.exit(
            'ERROR: Could not find any valid C or C++ compiler, please install gcc before continuing')
        sys.exit('ERROR: Could not find any valid C or C++ compiler, please install gcc before continuing')


def windows_ensure_compiler_in_path():

@@ -94,15 +79,14 @@ def windows_ensure_compiler_in_path():
    import vswhere
    vs_path = vswhere.get_latest_path()
    if not vs_path:
        sys.exit(
            "ERROR: Could not find Visual Studio, please ensure that you have Visual Studio installed")
        sys.exit("ERROR: Could not find Visual Studio, please ensure that you have Visual Studio installed")

    # Check if cl.exe is in PATH
    if not shutil.which("cl"):    # returns the path if found, None otherwise
        vcvars_script_path = os.path.join(
            vs_path, r"VC\Auxiliary\Build\vcvars64.bat")
    if not shutil.which("cl"):    # returns the path if found, None otherwise
        vcvars_script_path = os.path.join(vs_path, r"VC\Auxiliary\Build\vcvars64.bat")
        sys.exit(
            f'ERROR: Could not find cl.exe, please run "{vcvars_script_path}" (including quotes) to setup your command prompt')
            f'ERROR: Could not find cl.exe, please run "{vcvars_script_path}" (including quotes) to setup your command prompt'
        )


def ensure_compiler_in_path():

@@ -4,7 +4,8 @@ lifted from https://github.com/jatinx/PyHIP
TODO: move to a submodule
"""

import sys, ctypes
import ctypes
import sys

_libhip_libname = 'libamdhip64.so'

@@ -15,7 +16,7 @@ else:
    # Currently we do not support windows, mainly because I do not have a windows build of hip
    raise RuntimeError('Only linux is supported')

if _libhip == None:
if _libhip is None:
    raise OSError('hiprtc library not found')

@@ -482,7 +483,7 @@ def hipMalloc(count, ctype=None):
    ptr = ctypes.c_void_p()
    status = _libhip.hipMalloc(ctypes.byref(ptr), count)
    hipCheckStatus(status)
    if ctype != None:
    if ctype is not None:
        ptr = ctypes.cast(ptr, ctypes.POINTER(ctype))
    return ptr
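
A side note on the `== None` to `is None` changes in this file and the next: identity comparison is the form PEP 8 recommends, and it cannot be fooled by objects that override equality. A tiny standalone illustration:

    class AlwaysEqual:
        def __eq__(self, other):    # equality that always says yes
            return True

    obj = AlwaysEqual()
    assert obj == None        # misleadingly True
    assert obj is not None    # the identity check tells the truth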

@@ -4,7 +4,8 @@ lifted from https://github.com/jatinx/PyHIP
TODO: move to a submodule
"""

import sys, ctypes
import ctypes
import sys

# _libhiprtc_libname = 'libhiprtc.so' # Currently it's the same library
_libhiprtc_libname = 'libamdhip64.so'

@@ -16,7 +17,7 @@ else:
    # Currently we do not support windows, mainly because I do not have a windows build of hip
    raise RuntimeError('Only linux is supported')

if _libhiprtc == None:
if _libhiprtc is None:
    raise OSError('hiprtc library not found')

@@ -4,21 +4,15 @@ import numpy as np
from functools import reduce
from typing import List

try:
    from .arg_info import ArgInfo, verify_args
    from .gpu_headers import ROCM_HEADER_MAP
    from .pyhip_hip import *
    from .pyhip_hiprtc import *
except ModuleNotFoundError:
    from arg_info import ArgInfo, verify_args
    from gpu_headers import ROCM_HEADER_MAP
    from pyhip_hip import *
    from pyhip_hiprtc import *
from .arg_info import ArgInfo, verify_args
from .hat_file import Function
from .gpu_headers import ROCM_HEADER_MAP
from .pyhip.hip import *
from .pyhip.hiprtc import *


def _arg_size(arg_info: ArgInfo):
    return arg_info.element_num_bytes * reduce(lambda x, y: x * y,
                                               arg_info.numpy_shape)
    return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)


def initialize_rocm():

@@ -29,15 +23,14 @@ def initialize_rocm():
def compile_rocm_program(rocm_src_path: pathlib.Path, func_name):
    src = rocm_src_path.read_text()

    prog = hiprtcCreateProgram(source=src,
                               name=func_name + ".cu",
                               header_names=ROCM_HEADER_MAP.keys(),
                               header_sources=ROCM_HEADER_MAP.values())
    prog = hiprtcCreateProgram(
        source=src,
        name=func_name + ".cu",
        header_names=ROCM_HEADER_MAP.keys(),
        header_sources=ROCM_HEADER_MAP.values()
    )
    device_properties = hipGetDeviceProperties(0)
    hiprtcCompileProgram(prog, [
        f'--offload-arch={device_properties.gcnArchName}',
        '-D__HIP_PLATFORM_AMD__'
    ])
    hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}', '-D__HIP_PLATFORM_AMD__'])
    print(hiprtcGetProgramLog(prog))
    code = hiprtcGetCode(prog)

@@ -59,24 +52,16 @@ def allocate_rocm_mem(arg_infos: List[ArgInfo]):
    return device_mem


def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array],
                              arg_infos: List[ArgInfo]):
    for device_arg, host_arg, arg_info in zip(device_args, host_args,
                                              arg_infos):
        if 'input' in arg_info.usage:
            hipMemcpy_htod(dst=device_arg,
                           src=host_arg.ctypes.data,
                           count=_arg_size(arg_info))
def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
    for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
        if 'input' in arg_info.usage.value:
            hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=_arg_size(arg_info))


def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array],
                              arg_infos: List[ArgInfo]):
    for device_arg, host_arg, arg_info in zip(device_args, host_args,
                                              arg_infos):
        if 'output' in arg_info.usage:
            hipMemcpy_dtoh(dst=host_arg.ctypes.data,
                           src=device_arg,
                           count=_arg_size(arg_info))
def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
    for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
        if 'output' in arg_info.usage.value:
            hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=_arg_size(arg_info))
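
The change from `arg_info.usage` to `arg_info.usage.value` above reflects usage now being a UsageType enum rather than a raw string, so the substring test needs the enum's string value. A minimal sketch; the member values here are assumptions chosen to match the 'input'/'output' checks:

    from enum import Enum

    class UsageType(Enum):    # illustrative stand-in for hat_file.UsageType
        Input = "input"
        Output = "output"
        InputOutput = "input_output"

    usage = UsageType.InputOutput
    assert 'input' in usage.value and 'output' in usage.value
    assert 'output' not in UsageType.Input.value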

def device_args_to_ptr_list(device_args: List):

@@ -86,10 +71,12 @@ def device_args_to_ptr_list(device_args: List):
    return ptrs


def create_loader_for_device_function(device_func, hat_details):
    hat_path: pathlib.Path = hat_details.path
    rocm_src_path: pathlib.Path = hat_path.parent / device_func["provider"]
    func_name = device_func["name"]
def create_loader_for_device_function(device_func: Function, hat_dir_path: str):
    if not device_func.provider:
        raise RuntimeError("Expected a provider for the device function")

    rocm_src_path: pathlib.Path = pathlib.Path(hat_dir_path) / device_func.provider
    func_name = device_func.name

    rocm_program = compile_rocm_program(rocm_src_path, func_name)

@@ -97,33 +84,28 @@ def create_loader_for_device_function(device_func, hat_details):

    kernel = get_func_from_rocm_program(rocm_program, func_name)

    hat_arg_descriptions = device_func["arguments"]
    hat_arg_descriptions = device_func.arguments
    arg_infos = [ArgInfo(d) for d in hat_arg_descriptions]
    launch_parameters = device_func["launch_parameters"]
    launch_parameters = device_func.launch_parameters

    class DataStruct(ctypes.Structure):
        _fields_ = [(f"arg{i}", ctypes.c_void_p)
                    for i in range(len(arg_infos))]
        _fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(len(arg_infos))]

    def f(*args):
        verify_args(args, arg_infos, func_name)
        device_mem = allocate_rocm_mem(arg_infos)
        transfer_mem_host_to_rocm(device_args=device_mem,
                                  host_args=args,
                                  arg_infos=arg_infos)
        transfer_mem_host_to_rocm(device_args=device_mem, host_args=args, arg_infos=arg_infos)
        data = DataStruct(*device_mem)

        hipModuleLaunchKernel(
            kernel,
            *launch_parameters, # [ grid[x-z], block[x-z] ]
            0, # dynamic shared memory
            0, # stream
            data, # data
            *launch_parameters,    # [ grid[x-z], block[x-z] ]
            0,    # dynamic shared memory
            0,    # stream
            data,    # data
        )
        hipDeviceSynchronize()

        transfer_mem_rocm_to_host(device_args=device_mem,
                                  host_args=args,
                                  arg_infos=arg_infos)
        transfer_mem_rocm_to_host(device_args=device_mem, host_args=args, arg_infos=arg_infos)

    return f
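
The DataStruct pattern above packs one device pointer per kernel argument into a single ctypes structure that is handed to hipModuleLaunchKernel. A standalone toy version of the same idea (no GPU required; names are illustrative):

    import ctypes

    def make_arg_pack(num_args):
        # One void* slot per kernel argument, mirroring DataStruct above.
        class ArgPack(ctypes.Structure):
            _fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(num_args)]
        return ArgPack

    pack = make_arg_pack(3)(0, 0, 0)    # real device pointers would go here
    assert ctypes.sizeof(pack) == 3 * ctypes.sizeof(ctypes.c_void_p)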

@@ -5,10 +5,7 @@ from ast import arg
import enum
import sys

if __package__:
    from . import hat
else:
    import hat
from . import hat


def verify_hat_package(hat_path):

@@ -20,29 +17,25 @@ def verify_hat_package(hat_path):

    print("Inputs before function call:")
    for i, func_input in enumerate(func_inputs):
        print(
            f"\tInput {i}: {','.join(map(str, func_input.flatten()[:32]))}"
        )
        print(f"\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")

    fn(*inputs[name])

    print("Inputs after function call:")
    for i, func_input in enumerate(func_inputs):
        print(
            f"\tInput {i}: {','.join(map(str, func_input.flatten()[:32]))}"
        )
        print(f"\tInput {i}: {','.join(map(str, func_input.ravel()[:32]))}")
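
The flatten() to ravel() switch above avoids an unconditional copy when printing a preview of each buffer: ravel returns a view whenever the layout allows, while flatten always copies. A quick standalone check:

    import numpy as np

    a = np.zeros((2, 3), dtype=np.float32)
    assert a.ravel().base is a          # a view: no copy for contiguous arrays
    assert a.flatten().base is not a    # flatten always returns a copy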

def main():
    arg_parser = argparse.ArgumentParser(
        description="Executes every available function in the hat package \
with randomized inputs. Meant for quick verification.\n"
        "Example:\n"
        " hatlib.verify_hat_package <hat_path>\n")
        description=(
            "Executes every available function in the hat package with randomized inputs. Meant for quick verification.\n"
            "Example:\n"
            " hatlib.verify_hat_package <hat_path>\n"
        )
    )

    arg_parser.add_argument("hat_path",
                            help="Path to the HAT file",
                            default=None)
    arg_parser.add_argument("hat_path", help="Path to the HAT file", default=None)

    args = vars(arg_parser.parse_args())
    verify_hat_package(args["hat_path"])

@@ -4,4 +4,21 @@ build-backend = "setuptools.build_meta"

[tool.setuptools_scm]
version_scheme = "python-simplified-semver"
local_scheme = "no-local-version"
local_scheme = "no-local-version"

[tool.yapf]
allow_multiline_dictionary_keys = false
allow_split_before_dict_value = true
based_on_style = "pep8"
coalesce_brackets = true
column_limit = 120
dedent_closing_brackets = true
each_dict_entry_on_separate_line = true
force_multiline_dict = true
indent_dictionary_value = true
spaces_before_comment = 4
split_before_first_argument = true
split_before_logical_operator = true

[tool.pydocstyle]
max-line-length = 120
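
A hedged sketch of applying the style above from Python; recent yapf releases can read the [tool.yapf] table, though exact config discovery depends on the yapf version installed:

    from yapf.yapflib.yapf_api import FormatCode

    formatted, changed = FormatCode(
        "def f( a,b ):\n    return a+b\n",
        style_config="pyproject.toml",    # assumed to resolve to the table above
    )
    print(formatted)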

@@ -36,3 +36,6 @@ console_scripts =
    hatlib.benchmark_hat = hatlib.benchmark_hat_package:main_command
    hatlib.hat_to_dynamic = hatlib.hat_to_dynamic:main
    hatlib.verify_hat = hatlib.verify_hat_package:main

[pycodestyle]
max-line-length = 120

@@ -3,13 +3,11 @@ import unittest
import os
import sys
import accera as acc

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from benchmark_hat_package import run_benchmark
from hatlib import run_benchmark


class BenchmarkHATPackage_test(unittest.TestCase):

    def test_benchmark(self):
        A = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
        B = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))

@@ -24,21 +22,23 @@ class BenchmarkHATPackage_test(unittest.TestCase):

        package = acc.Package()
        package.add(nest, args=(A, B, C), base_name="test_function")
        package.build(name="BenchmarkHATPackage_test_benchmark",
                      output_dir="test_acccgen",
                      format=acc.Package.Format.HAT_DYNAMIC)
        package.build(
            name="BenchmarkHATPackage_test_benchmark", output_dir="test_acccgen", format=acc.Package.Format.HAT_DYNAMIC
        )

        run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
                      store_in_hat=False,
                      batch_size=2,
                      min_time_in_sec=1,
                      input_sets_minimum_size_MB=1)
        run_benchmark(
            "test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
            store_in_hat=False,
            batch_size=2,
            min_time_in_sec=1,
            input_sets_minimum_size_MB=1
        )

    def test_benchmark_multiple_functions(self):
        A = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
        B = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))
        C = acc.Array(role=acc.Array.Role.INPUT_OUTPUT, shape=(256, 256))
        D = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256)) # dummy argument
        D = acc.Array(role=acc.Array.Role.INPUT, shape=(256, 256))    # dummy argument

        nest = acc.Nest(shape=(256, 256, 256))
        i, j, k = nest.get_indices()

@@ -54,15 +54,17 @@ class BenchmarkHATPackage_test(unittest.TestCase):
        # with the correct signature
        package.add(nest, args=(A, B, C, D), base_name="test_function_dummy")

        package.build(name="BenchmarkHATPackage_test_benchmark",
                      output_dir="test_acccgen",
                      format=acc.Package.Format.HAT_DYNAMIC)
        package.build(
            name="BenchmarkHATPackage_test_benchmark", output_dir="test_acccgen", format=acc.Package.Format.HAT_DYNAMIC
        )

        run_benchmark("test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
                      store_in_hat=False,
                      batch_size=2,
                      min_time_in_sec=1,
                      input_sets_minimum_size_MB=1)
        run_benchmark(
            "test_acccgen/BenchmarkHATPackage_test_benchmark.hat",
            store_in_hat=False,
            batch_size=2,
            min_time_in_sec=1,
            input_sets_minimum_size_MB=1
        )


if __name__ == '__main__':

@@ -4,12 +4,7 @@ import numpy as np
import os
import sys
import unittest

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from hat import load
from hat_to_dynamic import create_dynamic_package
from hat_to_lib import create_static_package
from hatlib import load, create_dynamic_package, create_static_package


class HAT_test(unittest.TestCase):

@@ -32,17 +27,13 @@ class HAT_test(unittest.TestCase):

        for mode in [acc.Package.Mode.RELEASE, acc.Package.Mode.DEBUG]:
            package_name = f"HAT_test_load_{mode.value}"
            package.build(name=package_name,
                          output_dir="test_acccgen",
                          format=acc.Package.Format.HAT_STATIC,
                          mode=mode)
            package.build(name=package_name, output_dir="test_acccgen", format=acc.Package.Format.HAT_STATIC, mode=mode)

            create_dynamic_package(f"test_acccgen/{package_name}.hat",
                                   f"test_acccgen/{package_name}.dyn.hat")
            create_dynamic_package(f"test_acccgen/{package_name}.hat", f"test_acccgen/{package_name}.dyn.hat")

            hat_package = load(f"test_acccgen/{package_name}.dyn.hat")
            _, func_dict = load(f"test_acccgen/{package_name}.dyn.hat")

            for name in hat_package.names:
            for name in func_dict.names:
                print(name)

            # create numpy arguments with the correct shape and dtype

@@ -51,7 +42,7 @@ class HAT_test(unittest.TestCase):
            B_ref = B + A

            # find the function by basename
            test_function = hat_package["test_function"]
            test_function = func_dict["test_function"]
            test_function(A, B)

            # check for correctness

@@ -59,7 +50,7 @@ class HAT_test(unittest.TestCase):

            # find the function by actual name
            B_ref = B + A
            test_function1 = hat_package[function.name]
            test_function1 = func_dict[function.name]
            test_function1(A, B)

            # check for correctness
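
Note the updated contract exercised by this test: load() now returns a pair, and only the second element (a dict-like map from function names to callables) is used here. A condensed sketch of the pattern (package path illustrative):

    from hatlib import load

    _, func_dict = load("test_acccgen/some_package.dyn.hat")
    for name in func_dict.names:
        print(name)    # each entry is callable via func_dict[name](...)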

@@ -4,13 +4,10 @@ import os
import sys
import unittest
from pathlib import Path

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from hat_file import (CallingConventionType, CompiledWith, Declaration,
                      Dependencies, Description, Function, FunctionTable,
                      HATFile, OperatingSystem, Parameter, ParameterType,
                      Target, UsageType)
from hatlib import (
    CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile,
    OperatingSystem, Parameter, ParameterType, Target, UsageType
)


class HATFile_test(unittest.TestCase):

@@ -22,34 +19,38 @@ class HATFile_test(unittest.TestCase):
            name="my_function",
            description="Some description",
            calling_convention=CallingConventionType.StdCall,
            return_info=Parameter(logical_type=ParameterType.RuntimeArray,
                                  declared_type="float*",
                                  element_type="float",
                                  usage=UsageType.Input,
                                  shape="[16, 16]",
                                  affine_map=[16, 1],
                                  size="16 * 16 * sizeof(float)"))
            return_info=Parameter(
                logical_type=ParameterType.RuntimeArray,
                declared_type="float*",
                element_type="float",
                usage=UsageType.Input,
                shape="[16, 16]",
                affine_map=[16, 1],
                size="16 * 16 * sizeof(float)"
            )
        )
        # Create the function table
        functions = FunctionTable({"my_function": my_function})
        # Create the HATFile object
        hat_file1 = HATFile(
            name="test_file",
            description=Description(
                version="0.0.1",
                author="me",
                license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
                version="0.0.1", author="me", license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
            ),
            _function_table=functions,
            target=Target(required=Target.Required(os=OperatingSystem.Windows,
                                                   cpu=Target.Required.CPU(
                                                       architecture="Haswell",
                                                       extensions=["AVX2"]),
                                                   gpu=None),
                          optimized_for=Target.OptimizedFor()),
            target=Target(
                required=Target.Required(
                    os=OperatingSystem.Windows,
                    cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
                    gpu=None
                ),
                optimized_for=Target.OptimizedFor()
            ),
            dependencies=Dependencies(link_target="my_lib.lib"),
            compiled_with=CompiledWith(compiler="VC++"),
            declaration=Declaration(),
            path=Path(".").resolve())
            path=Path(".").resolve()
        )
        # Serialize it to disk
        test_file_name = "test_file_serialize.hat"

@@ -65,15 +66,14 @@ class HATFile_test(unittest.TestCase):
        # specified when we created the HATFile directly
        self.assertEqual(hat_file1.description, hat_file2.description)
        self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
        self.assertEqual(hat_file1.compiled_with.to_table(),
                         hat_file2.compiled_with.to_table())
        self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
        self.assertTrue("my_function" in hat_file2.function_map)

    def test_file_basic_deserialize(self):
        # Load a HAT file from the samples directory
        hat_file1 = HATFile.Deserialize(
            os.path.join(os.path.dirname(__file__), "..", "..", "samples",
                         "sample_gemm_library.hat"))
            os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
        )
        description = {
            "author": "John Doe",
            "version": "1.2.3.5",

@@ -82,8 +82,7 @@ class HATFile_test(unittest.TestCase):

        # Do basic verification of known values in the file
        # Verify the description has entries we expect
        self.assertLessEqual(description.items(),
                             hat_file1.description.to_table().items())
        self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
        # Verify the list of functions
        self.assertTrue(len(hat_file1.functions) == 2)
        self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)

@@ -4,13 +4,10 @@ import sys
import os
import unittest
from pathlib import Path

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from hat_file import (CallingConventionType, CompiledWith, Declaration,
                      Dependencies, Description, Function, FunctionTable,
                      HATFile, OperatingSystem, Parameter, ParameterType,
                      Target, UsageType)
from hatlib import (
    CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile,
    OperatingSystem, Parameter, ParameterType, Target, UsageType
)


class HATFile_test(unittest.TestCase):

@@ -22,34 +19,38 @@ class HATFile_test(unittest.TestCase):
            name="my_function",
            description="Some description",
            calling_convention=CallingConventionType.StdCall,
            return_info=Parameter(logical_type=ParameterType.RuntimeArray,
                                  declared_type="float*",
                                  element_type="float",
                                  usage=UsageType.Input,
                                  shape="[16, 16]",
                                  affine_map=[16, 1],
                                  size="16 * 16 * sizeof(float)"))
            return_info=Parameter(
                logical_type=ParameterType.RuntimeArray,
                declared_type="float*",
                element_type="float",
                usage=UsageType.Input,
                shape="[16, 16]",
                affine_map=[16, 1],
                size="16 * 16 * sizeof(float)"
            )
        )
        # Create the function table
        functions = FunctionTable({"my_function": my_function})
        # Create the HATFile object
        hat_file1 = HATFile(
            name="test_file",
            description=Description(
                version="0.0.1",
                author="me",
                license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
                version="0.0.1", author="me", license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
            ),
            _function_table=functions,
            target=Target(required=Target.Required(os=OperatingSystem.Windows,
                                                   cpu=Target.Required.CPU(
                                                       architecture="Haswell",
                                                       extensions=["AVX2"]),
                                                   gpu=None),
                          optimized_for=Target.OptimizedFor()),
            target=Target(
                required=Target.Required(
                    os=OperatingSystem.Windows,
                    cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
                    gpu=None
                ),
                optimized_for=Target.OptimizedFor()
            ),
            dependencies=Dependencies(link_target="my_lib.lib"),
            compiled_with=CompiledWith(compiler="VC++"),
            declaration=Declaration(),
            path=Path(".").resolve())
            path=Path(".").resolve()
        )
        # Serialize it to disk
        test_file_name = "test_file_serialize.hat"

@@ -65,15 +66,14 @@ class HATFile_test(unittest.TestCase):
        # specified when we created the HATFile directly
        self.assertEqual(hat_file1.description, hat_file2.description)
        self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
        self.assertEqual(hat_file1.compiled_with.to_table(),
                         hat_file2.compiled_with.to_table())
        self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
        self.assertTrue("my_function" in hat_file2.function_map)

    def test_file_basic_deserialize(self):
        # Load a HAT file from the samples directory
        hat_file1 = HATFile.Deserialize(
            os.path.join(os.path.dirname(__file__), "..", "..", "samples",
                         "sample_gemm_library.hat"))
            os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
        )
        description = {
            "author": "John Doe",
            "version": "1.2.3.5",

@@ -82,8 +82,7 @@ class HATFile_test(unittest.TestCase):

        # Do basic verification of known values in the file
        # Verify the description has entries we expect
        self.assertLessEqual(description.items(),
                             hat_file1.description.to_table().items())
        self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
        # Verify the list of functions
        self.assertTrue(len(hat_file1.function_map) == 2)
        self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)