зеркало из https://github.com/microsoft/hat.git
Make it easier to author a hat file from python
- Instantiate more default objects for simple cases - Enable modifying the list of functions in the hat file - Make hat file functions and function_map wrap the same underlying object
This commit is contained in:
Родитель
aea0a8d3fc
Коммит
efa2e7063c
|
@ -62,6 +62,14 @@ class OperatingSystem(Enum):
|
|||
MacOS = "macos"
|
||||
Linux = "linux"
|
||||
|
||||
@staticmethod
|
||||
def host():
|
||||
import platform
|
||||
platform_name = platform.system().lower()
|
||||
if platform_name == "darwin":
|
||||
return OperatingSystem.MacOS
|
||||
return OperatingSystem(platform_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuxiliarySupportedTable:
|
||||
|
@ -120,7 +128,7 @@ class Parameter:
|
|||
usage: UsageType = None
|
||||
|
||||
# Affine array parameter keys
|
||||
shape: str = ""
|
||||
shape: list = field(default_factory=list)
|
||||
affine_map: list = field(default_factory=list)
|
||||
affine_offset: int = -1
|
||||
|
||||
|
@ -181,6 +189,14 @@ class Parameter:
|
|||
|
||||
return param
|
||||
|
||||
@staticmethod
|
||||
def void():
|
||||
return Parameter(
|
||||
logical_type=ParameterType.Void,
|
||||
declared_type="void",
|
||||
element_type="void",
|
||||
usage=UsageType.Output
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class Function(AuxiliarySupportedTable):
|
||||
|
@ -264,7 +280,6 @@ class FunctionTableCommon:
|
|||
|
||||
def __init__(self, function_map):
|
||||
self.function_map = function_map
|
||||
self.functions = self.function_map.values()
|
||||
|
||||
def to_table(self):
|
||||
func_table = tomlkit.table()
|
||||
|
@ -372,8 +387,8 @@ class Target:
|
|||
)
|
||||
|
||||
TableName = "required"
|
||||
os: OperatingSystem = None
|
||||
cpu: CPU = None
|
||||
os: OperatingSystem = OperatingSystem.host()
|
||||
cpu: CPU = CPU()
|
||||
gpu: GPU = None
|
||||
|
||||
def to_table(self):
|
||||
|
@ -407,8 +422,8 @@ class Target:
|
|||
pass
|
||||
|
||||
TableName = "target"
|
||||
required: Required = None
|
||||
optimized_for: OptimizedFor = None
|
||||
required: Required = Required()
|
||||
optimized_for: OptimizedFor = Required()
|
||||
|
||||
def to_table(self):
|
||||
table = tomlkit.table()
|
||||
|
@ -523,7 +538,8 @@ class Declaration:
|
|||
|
||||
def to_table(self):
|
||||
table = tomlkit.table()
|
||||
table.add("code", self.code)
|
||||
code_str = tomlkit.string(self.code, multiline=True)
|
||||
table.add("code", code_str)
|
||||
return table
|
||||
|
||||
@staticmethod
|
||||
|
@ -544,33 +560,86 @@ class HATFile:
|
|||
some_hat_file.Serialize(`someFile.hat`)
|
||||
"""
|
||||
name: str = ""
|
||||
description: Description = None
|
||||
_function_table: FunctionTable = None
|
||||
_device_function_table: DeviceFunctionTable = None
|
||||
description: Description = Description()
|
||||
_function_table: FunctionTable = FunctionTable({})
|
||||
_device_function_table: DeviceFunctionTable = DeviceFunctionTable({})
|
||||
functions: list = field(default_factory=list)
|
||||
device_functions: list = field(default_factory=list)
|
||||
function_map: Dict[str, Function] = field(default_factory=dict)
|
||||
device_function_map: Dict[str, Function] = field(default_factory=dict)
|
||||
target: Target = None
|
||||
dependencies: Dependencies = None
|
||||
compiled_with: CompiledWith = None
|
||||
declaration: Declaration = None
|
||||
target: Target = Target()
|
||||
dependencies: Dependencies = Dependencies()
|
||||
compiled_with: CompiledWith = CompiledWith()
|
||||
declaration: Declaration = Declaration()
|
||||
path: Path = None
|
||||
|
||||
HATPrologue = "\n#ifndef __{0}__\n#define __{0}__\n\n#ifdef TOML\n"
|
||||
HATEpilogue = "\n#endif // TOML\n\n#endif // __{0}__\n"
|
||||
|
||||
def __post_init__(self):
|
||||
self.functions = self._function_table.functions
|
||||
self.function_map = self._function_table.function_map
|
||||
for func in self.functions:
|
||||
func.hat_file = self
|
||||
func.link_target = Path(self.path).resolve().parent / self.dependencies.link_target
|
||||
|
||||
if not self._device_function_table:
|
||||
self._device_function_table = DeviceFunctionTable({})
|
||||
self.device_function_map = self._device_function_table.function_map
|
||||
self.device_functions = self._device_function_table.functions
|
||||
@property
|
||||
def functions(self):
|
||||
if self._function_table is None:
|
||||
return []
|
||||
return self._function_table.function_map.values()
|
||||
|
||||
@functions.setter
|
||||
def functions(self, func_list_or_dict):
|
||||
if isinstance(func_list_or_dict, property):
|
||||
return
|
||||
if isinstance(func_list_or_dict, dict):
|
||||
name_to_func_map = func_list_or_dict
|
||||
else:
|
||||
name_to_func_map = { func.name : func for func in func_list_or_dict }
|
||||
if self._function_table is None:
|
||||
self._function_table = FunctionTable()
|
||||
self._function_table.function_map = name_to_func_map
|
||||
|
||||
@property
|
||||
def function_map(self):
|
||||
if self._function_table is None:
|
||||
return None
|
||||
return self._function_table.function_map
|
||||
|
||||
@function_map.setter
|
||||
def function_map(self, func_map):
|
||||
if isinstance(func_map, property):
|
||||
return
|
||||
self._function_table.function_map = func_map
|
||||
|
||||
@property
|
||||
def device_functions(self):
|
||||
if self._device_function_table is None:
|
||||
return []
|
||||
return self._device_function_table.function_map.value()
|
||||
|
||||
@device_functions.setter
|
||||
def device_functions(self, func_list_or_dict):
|
||||
if isinstance(func_list_or_dict, property):
|
||||
return
|
||||
if isinstance(func_list_or_dict, dict):
|
||||
name_to_func_map = func_list_or_dict
|
||||
else:
|
||||
name_to_func_map = { func.name : func for func in func_list_or_dict }
|
||||
if self._device_function_table is None:
|
||||
self._device_function_table = FunctionTable()
|
||||
self._device_function_table.function_map = name_to_func_map
|
||||
|
||||
@property
|
||||
def device_function_map(self):
|
||||
if self._device_function_table is None:
|
||||
return None
|
||||
return self._device_function_table.function_map
|
||||
|
||||
@device_function_map.setter
|
||||
def device_function_map(self, func_map):
|
||||
if isinstance(func_map, property):
|
||||
return
|
||||
self._device_function_table.function_map = func_map
|
||||
|
||||
def Serialize(self, filepath=None):
|
||||
"""Serilizes the HATFile to disk using the file location specified by `filepath`.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
numpy
|
||||
pandas
|
||||
toml
|
||||
tomlkit
|
||||
tomlkit>=0.11.0
|
||||
vswhere; sys_platform == 'win32'
|
|
@ -0,0 +1 @@
|
|||
*.hat
|
|
@ -0,0 +1,129 @@
|
|||
#!/usr/bin/env python3
|
||||
import unittest
|
||||
import os
|
||||
|
||||
import hatlib as hat
|
||||
|
||||
SAMPLE_MATMUL_DECL_CODE = '''
|
||||
#endif // TOML
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C"
|
||||
{
|
||||
#endif // defined(__cplusplus)
|
||||
//
|
||||
// Functions
|
||||
//
|
||||
|
||||
void MatMul(const float* A, const float* B, float* C);
|
||||
|
||||
#if defined(__cplusplus)
|
||||
} // extern "C"
|
||||
#endif // defined(__cplusplus)
|
||||
|
||||
#ifdef TOML
|
||||
'''
|
||||
|
||||
class CreateSimpleHatFile_test(unittest.TestCase):
|
||||
|
||||
def test_create_simple_hat_file(self):
|
||||
a_shape = (1024, 512)
|
||||
a_strides = (a_shape[1], 1) # "first major" / "row major"
|
||||
|
||||
b_shape = (512, 256)
|
||||
b_strides = (1, b_shape[0]) # "last major" / "column major"
|
||||
|
||||
c_shape = (1024, 256)
|
||||
c_strides = (c_shape[1], 1) # "first major" / "row major"
|
||||
|
||||
param_A = hat.Parameter(
|
||||
name="A",
|
||||
description="the A input matrix argument",
|
||||
logical_type=hat.ParameterType.AffineArray,
|
||||
declared_type="float*",
|
||||
element_type="float",
|
||||
usage=hat.UsageType.Input,
|
||||
|
||||
# Affine array parameter keys
|
||||
shape=a_shape,
|
||||
affine_map=a_strides,
|
||||
affine_offset=0
|
||||
)
|
||||
|
||||
param_B = hat.Parameter(
|
||||
name="B",
|
||||
description="the B input matrix argument",
|
||||
logical_type=hat.ParameterType.AffineArray,
|
||||
declared_type="float*",
|
||||
element_type="float",
|
||||
usage=hat.UsageType.Input,
|
||||
|
||||
# Affine array parameter keys
|
||||
shape=b_shape,
|
||||
affine_map=b_strides,
|
||||
affine_offset=0
|
||||
)
|
||||
|
||||
param_C = hat.Parameter(
|
||||
name="C",
|
||||
description="the C input matrix argument",
|
||||
logical_type=hat.ParameterType.AffineArray,
|
||||
declared_type="float*",
|
||||
element_type="float",
|
||||
usage=hat.UsageType.InputOutput,
|
||||
|
||||
# Affine array parameter keys
|
||||
shape=c_shape,
|
||||
affine_map=c_strides,
|
||||
affine_offset=0
|
||||
)
|
||||
|
||||
arguments = [param_A, param_B, param_C]
|
||||
return_arg = hat.Parameter.void()
|
||||
|
||||
func_name = "MatMul"
|
||||
hat_function = hat.Function(
|
||||
arguments=arguments,
|
||||
calling_convention=hat.CallingConventionType.StdCall,
|
||||
description="Sample matmul hat declaration",
|
||||
name=func_name,
|
||||
return_info=return_arg
|
||||
)
|
||||
auxiliary_key_name = "test_auxiliary_key"
|
||||
hat_function.auxiliary[auxiliary_key_name] = { "name" : "matmul" }
|
||||
|
||||
link_target_path = "./fake_link_target.lib"
|
||||
hat_file_path = "./test_simple_hat_path.hat"
|
||||
new_hat_file = hat.HATFile(
|
||||
name="simple_hat_file",
|
||||
functions=[hat_function],
|
||||
dependencies=hat.Dependencies(link_target=link_target_path),
|
||||
declaration=hat.Declaration(code=SAMPLE_MATMUL_DECL_CODE),
|
||||
path=hat_file_path
|
||||
)
|
||||
|
||||
if os.path.exists(hat_file_path):
|
||||
os.remove(hat_file_path)
|
||||
|
||||
new_hat_file.Serialize(hat_file_path)
|
||||
|
||||
self.assertTrue(os.path.exists(hat_file_path))
|
||||
|
||||
parsed_hat_file = hat.HATFile.Deserialize(hat_file_path)
|
||||
|
||||
self.assertTrue(func_name in parsed_hat_file.function_map)
|
||||
self.assertEqual(parsed_hat_file.dependencies.link_target, link_target_path)
|
||||
self.assertEqual(len(parsed_hat_file.function_map[func_name].arguments), 3)
|
||||
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[0].name, param_A.name)
|
||||
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[1].name, param_B.name)
|
||||
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[2].name, param_C.name)
|
||||
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[0].shape, list(param_A.shape))
|
||||
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[1].shape, list(param_B.shape))
|
||||
self.assertEqual(parsed_hat_file.function_map[func_name].arguments[2].shape, list(param_C.shape))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче