Make it easier to author a hat file from python

- Instantiate more default objects for simple cases
- Enable modifying the list of functions in the hat file
- Make hat file functions and function_map wrap the same underlying
  object
This commit is contained in:
Mason Remy 2022-06-09 01:05:07 -07:00
Родитель aea0a8d3fc
Коммит efa2e7063c
4 изменённых файлов: 220 добавлений и 21 удалений

Просмотреть файл

@ -62,6 +62,14 @@ class OperatingSystem(Enum):
MacOS = "macos"
Linux = "linux"
@staticmethod
def host():
import platform
platform_name = platform.system().lower()
if platform_name == "darwin":
return OperatingSystem.MacOS
return OperatingSystem(platform_name)
@dataclass
class AuxiliarySupportedTable:
@ -120,7 +128,7 @@ class Parameter:
usage: UsageType = None
# Affine array parameter keys
shape: str = ""
shape: list = field(default_factory=list)
affine_map: list = field(default_factory=list)
affine_offset: int = -1
@ -181,6 +189,14 @@ class Parameter:
return param
@staticmethod
def void():
return Parameter(
logical_type=ParameterType.Void,
declared_type="void",
element_type="void",
usage=UsageType.Output
)
@dataclass
class Function(AuxiliarySupportedTable):
@ -264,7 +280,6 @@ class FunctionTableCommon:
def __init__(self, function_map):
self.function_map = function_map
self.functions = self.function_map.values()
def to_table(self):
func_table = tomlkit.table()
@ -372,8 +387,8 @@ class Target:
)
TableName = "required"
os: OperatingSystem = None
cpu: CPU = None
os: OperatingSystem = OperatingSystem.host()
cpu: CPU = CPU()
gpu: GPU = None
def to_table(self):
@ -407,8 +422,8 @@ class Target:
pass
TableName = "target"
required: Required = None
optimized_for: OptimizedFor = None
required: Required = Required()
optimized_for: OptimizedFor = Required()
def to_table(self):
table = tomlkit.table()
@ -523,7 +538,8 @@ class Declaration:
def to_table(self):
table = tomlkit.table()
table.add("code", self.code)
code_str = tomlkit.string(self.code, multiline=True)
table.add("code", code_str)
return table
@staticmethod
@ -544,33 +560,86 @@ class HATFile:
some_hat_file.Serialize(`someFile.hat`)
"""
name: str = ""
description: Description = None
_function_table: FunctionTable = None
_device_function_table: DeviceFunctionTable = None
description: Description = Description()
_function_table: FunctionTable = FunctionTable({})
_device_function_table: DeviceFunctionTable = DeviceFunctionTable({})
functions: list = field(default_factory=list)
device_functions: list = field(default_factory=list)
function_map: Dict[str, Function] = field(default_factory=dict)
device_function_map: Dict[str, Function] = field(default_factory=dict)
target: Target = None
dependencies: Dependencies = None
compiled_with: CompiledWith = None
declaration: Declaration = None
target: Target = Target()
dependencies: Dependencies = Dependencies()
compiled_with: CompiledWith = CompiledWith()
declaration: Declaration = Declaration()
path: Path = None
HATPrologue = "\n#ifndef __{0}__\n#define __{0}__\n\n#ifdef TOML\n"
HATEpilogue = "\n#endif // TOML\n\n#endif // __{0}__\n"
def __post_init__(self):
self.functions = self._function_table.functions
self.function_map = self._function_table.function_map
for func in self.functions:
func.hat_file = self
func.link_target = Path(self.path).resolve().parent / self.dependencies.link_target
if not self._device_function_table:
self._device_function_table = DeviceFunctionTable({})
self.device_function_map = self._device_function_table.function_map
self.device_functions = self._device_function_table.functions
@property
def functions(self):
if self._function_table is None:
return []
return self._function_table.function_map.values()
@functions.setter
def functions(self, func_list_or_dict):
if isinstance(func_list_or_dict, property):
return
if isinstance(func_list_or_dict, dict):
name_to_func_map = func_list_or_dict
else:
name_to_func_map = { func.name : func for func in func_list_or_dict }
if self._function_table is None:
self._function_table = FunctionTable()
self._function_table.function_map = name_to_func_map
@property
def function_map(self):
if self._function_table is None:
return None
return self._function_table.function_map
@function_map.setter
def function_map(self, func_map):
if isinstance(func_map, property):
return
self._function_table.function_map = func_map
@property
def device_functions(self):
if self._device_function_table is None:
return []
return self._device_function_table.function_map.value()
@device_functions.setter
def device_functions(self, func_list_or_dict):
if isinstance(func_list_or_dict, property):
return
if isinstance(func_list_or_dict, dict):
name_to_func_map = func_list_or_dict
else:
name_to_func_map = { func.name : func for func in func_list_or_dict }
if self._device_function_table is None:
self._device_function_table = FunctionTable()
self._device_function_table.function_map = name_to_func_map
@property
def device_function_map(self):
if self._device_function_table is None:
return None
return self._device_function_table.function_map
@device_function_map.setter
def device_function_map(self, func_map):
if isinstance(func_map, property):
return
self._device_function_table.function_map = func_map
def Serialize(self, filepath=None):
"""Serilizes the HATFile to disk using the file location specified by `filepath`.

Просмотреть файл

@ -1,5 +1,5 @@
numpy
pandas
toml
tomlkit
tomlkit>=0.11.0
vswhere; sys_platform == 'win32'

1
test/.gitignore поставляемый Normal file
Просмотреть файл

@ -0,0 +1 @@
*.hat

Просмотреть файл

@ -0,0 +1,129 @@
#!/usr/bin/env python3
import unittest
import os
import hatlib as hat
SAMPLE_MATMUL_DECL_CODE = '''
#endif // TOML
#pragma once
#include <stdint.h>
#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)
//
// Functions
//
void MatMul(const float* A, const float* B, float* C);
#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)
#ifdef TOML
'''
class CreateSimpleHatFile_test(unittest.TestCase):
def test_create_simple_hat_file(self):
a_shape = (1024, 512)
a_strides = (a_shape[1], 1) # "first major" / "row major"
b_shape = (512, 256)
b_strides = (1, b_shape[0]) # "last major" / "column major"
c_shape = (1024, 256)
c_strides = (c_shape[1], 1) # "first major" / "row major"
param_A = hat.Parameter(
name="A",
description="the A input matrix argument",
logical_type=hat.ParameterType.AffineArray,
declared_type="float*",
element_type="float",
usage=hat.UsageType.Input,
# Affine array parameter keys
shape=a_shape,
affine_map=a_strides,
affine_offset=0
)
param_B = hat.Parameter(
name="B",
description="the B input matrix argument",
logical_type=hat.ParameterType.AffineArray,
declared_type="float*",
element_type="float",
usage=hat.UsageType.Input,
# Affine array parameter keys
shape=b_shape,
affine_map=b_strides,
affine_offset=0
)
param_C = hat.Parameter(
name="C",
description="the C input matrix argument",
logical_type=hat.ParameterType.AffineArray,
declared_type="float*",
element_type="float",
usage=hat.UsageType.InputOutput,
# Affine array parameter keys
shape=c_shape,
affine_map=c_strides,
affine_offset=0
)
arguments = [param_A, param_B, param_C]
return_arg = hat.Parameter.void()
func_name = "MatMul"
hat_function = hat.Function(
arguments=arguments,
calling_convention=hat.CallingConventionType.StdCall,
description="Sample matmul hat declaration",
name=func_name,
return_info=return_arg
)
auxiliary_key_name = "test_auxiliary_key"
hat_function.auxiliary[auxiliary_key_name] = { "name" : "matmul" }
link_target_path = "./fake_link_target.lib"
hat_file_path = "./test_simple_hat_path.hat"
new_hat_file = hat.HATFile(
name="simple_hat_file",
functions=[hat_function],
dependencies=hat.Dependencies(link_target=link_target_path),
declaration=hat.Declaration(code=SAMPLE_MATMUL_DECL_CODE),
path=hat_file_path
)
if os.path.exists(hat_file_path):
os.remove(hat_file_path)
new_hat_file.Serialize(hat_file_path)
self.assertTrue(os.path.exists(hat_file_path))
parsed_hat_file = hat.HATFile.Deserialize(hat_file_path)
self.assertTrue(func_name in parsed_hat_file.function_map)
self.assertEqual(parsed_hat_file.dependencies.link_target, link_target_path)
self.assertEqual(len(parsed_hat_file.function_map[func_name].arguments), 3)
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[0].name, param_A.name)
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[1].name, param_B.name)
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[2].name, param_C.name)
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[0].shape, list(param_A.shape))
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[1].shape, list(param_B.shape))
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[2].shape, list(param_C.shape))
if __name__ == '__main__':
unittest.main()