2021-03-16 21:01:10 +03:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
# Utility to parse the TOML metadata from HAT files
|
2021-11-03 19:50:32 +03:00
|
|
|
import os
|
2021-03-16 21:01:10 +03:00
|
|
|
import tomlkit
|
2022-03-30 23:50:00 +03:00
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from enum import Enum
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Dict, List
|
2021-03-16 21:01:10 +03:00
|
|
|
|
|
|
|
# TODO : type-checking on leaf node values
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
def _read_toml_file(filepath):
    """Read the file at `filepath` and return its contents parsed by tomlkit.

    The path is made absolute before opening; the parsed document is returned
    as-is.
    """
    absolute_path = os.path.abspath(filepath)
    with open(absolute_path, "r") as toml_file:
        return tomlkit.parse(toml_file.read())
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
def _check_required_table_entry(table, key):
|
|
|
|
if key not in table:
|
|
|
|
# TODO : add more context to this error message
|
2021-03-17 01:05:57 +03:00
|
|
|
raise ValueError(f"Invalid HAT file: missing required key {key}")
|
2021-03-16 21:01:10 +03:00
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
def _check_required_table_entries(table, keys):
    """Ensure every key in `keys` is present in `table`.

    Keys are checked in iteration order, so the error names the first missing one.

    Raises:
        ValueError: from _check_required_table_entry, for the first absent key.
    """
    for required_key in keys:
        _check_required_table_entry(table, required_key)
|
|
|
|
|
|
|
|
|
|
|
|
class ParameterType(Enum):
    """Logical kind of a function parameter; values are the strings stored in HAT files."""

    # Array described by shape / affine_map / affine_offset (see Parameter)
    AffineArray = "affine_array"
    # Array whose extent is given by Parameter.size
    RuntimeArray = "runtime_array"
    # A single scalar element
    Element = "element"
    # No value (used for void return types)
    Void = "void"
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
class UsageType(Enum):
    """How a function uses a parameter; values are the strings stored in HAT files."""

    # Read by the function
    Input = "input"
    # Written by the function
    Output = "output"
    # Both read and written
    InputOutput = "input_output"
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
class CallingConventionType(Enum):
    """Calling convention of a HAT function; values are the strings stored in HAT files."""

    StdCall = "stdcall"
    CDecl = "cdecl"
    FastCall = "fastcall"
    VectorCall = "vectorcall"
    # Device-side functions (note the member name differs from the *Call pattern)
    Device = "devicecall"
|
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
|
|
|
|
class TargetType(Enum):
    """Hardware target kinds; the values double as table names under [target.required]."""

    CPU = "CPU"
    GPU = "GPU"
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
class OperatingSystem(Enum):
    """Operating systems a HAT package can target; values are the strings stored in HAT files."""

    Windows = "windows"
    MacOS = "macos"
    Linux = "linux"

    @staticmethod
    def host():
        """Return the member matching the OS this code is running on.

        ``platform.system()`` is lower-cased; "darwin" is mapped to MacOS and any
        other name is looked up by value, raising ValueError if it is not a member.
        """
        import platform

        system_name = platform.system().lower()
        return OperatingSystem.MacOS if system_name == "darwin" else OperatingSystem(system_name)
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
@dataclass
class AuxiliarySupportedTable:
    """Base dataclass for HAT tables that may carry a free-form ``auxiliary`` sub-table."""

    # Unannotated, so this is a class constant rather than a dataclass field.
    AuxiliaryKey = "auxiliary"

    auxiliary: dict = field(default_factory=dict)

    def add_auxiliary_table(self, table):
        """Add ``self.auxiliary`` to `table` under AuxiliaryKey, but only if it is non-empty."""
        if len(self.auxiliary) == 0:
            return
        table.add(self.AuxiliaryKey, self.auxiliary)

    @staticmethod
    def parse_auxiliary(table):
        """Return `table`'s auxiliary entry, or a new empty dict when it has none."""
        key = AuxiliarySupportedTable.AuxiliaryKey
        return table[key] if key in table else {}
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
@dataclass
class Description(AuxiliarySupportedTable):
    """The ``[description]`` table of a HAT file: free-form metadata about the package."""

    TableName: str = "description"
    comment: str = ""
    author: str = ""
    version: str = ""
    license_url: str = ""

    def to_table(self):
        """Build a tomlkit table with every description field plus any auxiliary data."""
        description_table = tomlkit.table()
        description_table.add("comment", self.comment)
        description_table.add("author", self.author)
        description_table.add("version", self.version)
        description_table.add("license_url", self.license_url)

        self.add_auxiliary_table(description_table)

        return description_table

    @staticmethod
    def parse_from_table(table):
        """Create a Description from a parsed ``[description]`` table.

        ``comment`` is written by to_table and is read back here. It is optional,
        defaulting to "", so that files without it still load. ``author``,
        ``version`` and ``license_url`` are read directly, and a missing one raises
        from the table lookup.
        """
        return Description(
            comment=table.get("comment", ""),
            author=table["author"],
            version=table["version"],
            license_url=table["license_url"],
            auxiliary=AuxiliarySupportedTable.parse_auxiliary(table)
        )
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
|
|
|
|
@dataclass
class Parameter:
    """One function argument (or return value) entry of a HAT function.

    Which array-specific fields are serialized depends on ``logical_type``:
    AffineArray uses shape / affine_map / affine_offset, RuntimeArray uses size.
    """

    # All parameter keys
    name: str = ""
    description: str = ""
    logical_type: ParameterType = None
    declared_type: str = ""
    element_type: str = ""
    usage: UsageType = None

    # Affine array parameter keys
    shape: list = field(default_factory=list)
    affine_map: list = field(default_factory=list)
    affine_offset: int = -1

    # Runtime array parameter keys
    size: str = ""

    def to_table(self):
        """Serialize as a tomlkit inline table: common keys first, then type-specific keys."""
        entries = [
            ("name", self.name),
            ("description", self.description),
            ("logical_type", self.logical_type.value),
            ("declared_type", self.declared_type),
            ("element_type", self.element_type),
            ("usage", self.usage.value),
        ]

        if self.logical_type == ParameterType.AffineArray:
            entries += [
                ("shape", self.shape),
                ("affine_map", self.affine_map),
                ("affine_offset", self.affine_offset),
            ]
        elif self.logical_type == ParameterType.RuntimeArray:
            entries.append(("size", self.size))

        inline = tomlkit.inline_table()
        for entry_key, entry_value in entries:
            inline.append(entry_key, entry_value)
        return inline

    # TODO : change "usage" to "role" in schema
    @staticmethod
    def parse_from_table(param_table):
        """Create a Parameter from a parsed parameter table.

        Raises:
            ValueError: if a common key is missing, if a key required by the
                parameter's logical type is missing, or if ``logical_type`` /
                ``usage`` is not a valid enum value.
        """
        _check_required_table_entries(
            param_table, ["name", "description", "logical_type", "declared_type", "element_type", "usage"]
        )

        # Converted before usage, so an invalid logical_type is reported first.
        kind = ParameterType(param_table["logical_type"])
        result = Parameter(
            name=param_table["name"],
            description=param_table["description"],
            logical_type=kind,
            declared_type=param_table["declared_type"],
            element_type=param_table["element_type"],
            usage=UsageType(param_table["usage"])
        )

        if kind == ParameterType.AffineArray:
            _check_required_table_entries(param_table, ["shape", "affine_map", "affine_offset"])
            result.shape = param_table["shape"]
            result.affine_map = param_table["affine_map"]
            result.affine_offset = param_table["affine_offset"]
        elif kind == ParameterType.RuntimeArray:
            _check_required_table_entries(param_table, ["size"])
            result.size = param_table["size"]

        return result

    @staticmethod
    def void():
        """Return a Parameter describing a ``void`` return value."""
        return Parameter(
            logical_type=ParameterType.Void,
            declared_type="void",
            element_type="void",
            usage=UsageType.Output
        )
|
2021-03-16 21:01:10 +03:00
|
|
|
|
|
|
|
@dataclass
class Function(AuxiliarySupportedTable):
    """One entry of a HAT ``[functions]`` or ``[device_functions]`` table."""

    # required
    arguments: List[Parameter] = field(default_factory=list)
    calling_convention: CallingConventionType = None
    description: str = ""
    # Back-reference assigned by HATFile.__post_init__; not serialized by to_table.
    hat_file: "HATFile" = None
    # Assigned by HATFile.__post_init__; not serialized by to_table.
    link_target: Path = None
    name: str = ""
    return_info: Parameter = None

    # optional
    launch_parameters: list = field(default_factory=list)
    launches: str = ""
    provider: str = ""
    runtime: str = ""

    def to_table(self):
        """Serialize this function to a tomlkit table.

        Optional entries (launch_parameters, launches, provider, runtime) are
        written only when they are non-empty.
        """
        table = tomlkit.table()
        table.add("name", self.name)
        table.add("description", self.description)
        table.add("calling_convention", self.calling_convention.value)

        arg_array = tomlkit.array()
        for arg in self.arguments:
            arg_array.append(arg.to_table())
        table.add(
            "arguments", arg_array
        )  # TODO : figure out why this isn't indenting after serialization in some cases

        for optional_key, optional_value in (
            ("launch_parameters", self.launch_parameters),
            ("launches", self.launches),
            ("provider", self.provider),
            ("runtime", self.runtime),
        ):
            if optional_value:
                table.add(optional_key, optional_value)

        table.add("return", self.return_info.to_table())

        self.add_auxiliary_table(table)

        return table

    @staticmethod
    def parse_from_table(function_table):
        """Create a Function from a parsed function table.

        Raises:
            ValueError: if a required key is missing or calling_convention is invalid.
        """
        required_table_entries = ["name", "description", "calling_convention", "arguments", "return"]
        _check_required_table_entries(function_table, required_table_entries)

        arguments = [Parameter.parse_from_table(param_table) for param_table in function_table["arguments"]]
        launch_parameters = function_table.get("launch_parameters", [])
        launches = function_table.get("launches", "")
        provider = function_table.get("provider", "")
        runtime = function_table.get("runtime", "")
        return_info = Parameter.parse_from_table(function_table["return"])

        return Function(
            name=function_table["name"],
            description=function_table["description"],
            calling_convention=CallingConventionType(function_table["calling_convention"]),
            arguments=arguments,
            return_info=return_info,
            launch_parameters=launch_parameters,
            launches=launches,
            provider=provider,
            runtime=runtime,
            auxiliary=AuxiliarySupportedTable.parse_auxiliary(function_table)
        )
|
2022-03-15 09:20:41 +03:00
|
|
|
|
|
|
|
|
|
|
|
class FunctionTableCommon:
    """Shared implementation for tables that map function names to Function objects."""

    def __init__(self, function_map):
        # Mapping of table key -> Function
        self.function_map = function_map

    def to_table(self):
        """Serialize each function as a sub-table keyed by its map key."""
        func_table = tomlkit.table()
        for function_key, function in self.function_map.items():
            func_table.add(function_key, function.to_table())
        return func_table

    @classmethod
    def parse_from_table(cls, all_functions_table):
        """Build an instance of `cls` by parsing every sub-table as a Function."""
        parsed = {}
        for function_key in all_functions_table:
            parsed[function_key] = Function.parse_from_table(all_functions_table[function_key])
        return cls(parsed)
|
|
|
|
|
|
|
|
|
|
|
|
class FunctionTable(FunctionTableCommon):
    """Host function table, stored under ``[functions]``."""

    TableName = "functions"
|
|
|
|
|
|
|
|
|
|
|
|
class DeviceFunctionTable(FunctionTableCommon):
    """Device function table, stored under ``[device_functions]``."""

    TableName = "device_functions"
|
2021-03-16 21:01:10 +03:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Target:
    """The ``[target]`` table of a HAT file.

    ``required`` describes what the package needs to run. ``optimized_for`` is a
    placeholder whose contents are not modelled yet.
    """

    @dataclass
    class Required:
        """The ``[target.required]`` table: OS, CPU and optional GPU requirements."""

        @dataclass
        class CPU:
            """CPU requirements, stored under the ``CPU`` key of ``[target.required]``."""

            TableName = TargetType.CPU.value

            # required
            architecture: str = ""
            extensions: list = field(default_factory=list)

            # optional
            runtime: str = ""

            def to_table(self):
                """Serialize to a tomlkit table; ``runtime`` is written only when non-empty."""
                table = tomlkit.table()
                table.add("architecture", self.architecture)
                table.add("extensions", self.extensions)

                if self.runtime:
                    table.add("runtime", self.runtime)

                return table

            @staticmethod
            def parse_from_table(table):
                """Create a CPU from its table.

                Raises:
                    ValueError: if ``architecture`` or ``extensions`` is missing.
                """
                required_table_entries = ["architecture", "extensions"]
                _check_required_table_entries(table, required_table_entries)

                runtime = table.get("runtime", "")

                return Target.Required.CPU(
                    architecture=table["architecture"], extensions=table["extensions"], runtime=runtime
                )

        @dataclass
        class GPU:
            """GPU requirements, stored under the ``GPU`` key of ``[target.required]``."""

            TableName = TargetType.GPU.value

            blocks: int = 0
            instruction_set_version: str = ""
            min_threads: int = 0
            min_global_memory_KB: int = 0
            min_shared_memory_KB: int = 0
            min_texture_memory_KB: int = 0
            model: str = ""
            runtime: str = ""

            def to_table(self):
                """Serialize every GPU field to a tomlkit table."""
                table = tomlkit.table()
                table.add("model", self.model)
                table.add("runtime", self.runtime)
                table.add("blocks", self.blocks)
                table.add("instruction_set_version", self.instruction_set_version)
                table.add("min_threads", self.min_threads)
                table.add("min_global_memory_KB", self.min_global_memory_KB)
                table.add("min_shared_memory_KB", self.min_shared_memory_KB)
                table.add("min_texture_memory_KB", self.min_texture_memory_KB)

                return table

            @staticmethod
            def parse_from_table(table):
                """Create a GPU from its table.

                Only ``runtime`` and ``model`` are required. The remaining keys are
                optional and fall back to the dataclass defaults when absent.

                Raises:
                    ValueError: if ``runtime`` or ``model`` is missing.
                """
                required_table_entries = [
                    "runtime",
                    "model",
                ]
                _check_required_table_entries(table, required_table_entries)

                return Target.Required.GPU(
                    runtime=table["runtime"],
                    model=table["model"],
                    blocks=table.get("blocks", 0),
                    instruction_set_version=table.get("instruction_set_version", ""),
                    min_threads=table.get("min_threads", 0),
                    min_global_memory_KB=table.get("min_global_memory_KB", 0),
                    min_shared_memory_KB=table.get("min_shared_memory_KB", 0),
                    min_texture_memory_KB=table.get("min_texture_memory_KB", 0)
                )

        TableName = "required"

        # Defaults to the OS the module is imported on (evaluated once, at class creation).
        os: OperatingSystem = OperatingSystem.host()
        # default_factory gives each instance its own CPU. A shared dataclass
        # instance as a default is mutable, and it is rejected outright by
        # dataclasses on Python 3.11+ because it is unhashable.
        cpu: CPU = field(default_factory=CPU)
        gpu: GPU = None

        def to_table(self):
            """Serialize to a tomlkit table; the GPU sub-table is written only if it has a runtime."""
            table = tomlkit.table()
            table.add("os", self.os.value)
            table.add(Target.Required.CPU.TableName, self.cpu.to_table())
            if self.gpu and self.gpu.runtime:
                table.add(Target.Required.GPU.TableName, self.gpu.to_table())
            return table

        @staticmethod
        def parse_from_table(table):
            """Create a Required from its table.

            ``os`` is converted to an OperatingSystem member, matching the field's
            annotation and the ``.value`` access in to_table.

            Raises:
                ValueError: if ``os`` or the CPU table is missing, or ``os`` is not a
                    known OperatingSystem value.
            """
            required_table_entries = ["os", Target.Required.CPU.TableName]
            _check_required_table_entries(table, required_table_entries)
            cpu_info = Target.Required.CPU.parse_from_table(table[Target.Required.CPU.TableName])
            if Target.Required.GPU.TableName in table:
                gpu_info = Target.Required.GPU.parse_from_table(table[Target.Required.GPU.TableName])
            else:
                gpu_info = Target.Required.GPU()
            return Target.Required(os=OperatingSystem(table["os"]), cpu=cpu_info, gpu=gpu_info)

    # TODO : support optimized_for table
    class OptimizedFor:
        """Placeholder for ``[target.optimized_for]``; it currently carries no data."""

        TableName = "optimized_for"

        def to_table(self):
            """Return an empty tomlkit table."""
            return tomlkit.table()

        @staticmethod
        def parse_from_table(table):
            """Return an empty OptimizedFor; the table's contents are not parsed yet."""
            return Target.OptimizedFor()

    TableName = "target"

    # default_factory: a shared Required() default is unhashable (rejected on
    # Python 3.11+). optimized_for must default to an OptimizedFor rather than
    # a Required.
    required: Required = field(default_factory=Required)
    optimized_for: OptimizedFor = field(default_factory=OptimizedFor)

    def to_table(self):
        """Serialize to a tomlkit table; ``optimized_for`` is skipped when set to None."""
        table = tomlkit.table()
        table.add(Target.Required.TableName, self.required.to_table())
        if self.optimized_for is not None:
            table.add(Target.OptimizedFor.TableName, self.optimized_for.to_table())
        return table

    @staticmethod
    def parse_from_table(target_table):
        """Create a Target from a parsed ``[target]`` table.

        Raises:
            ValueError: if the ``required`` sub-table is missing.
        """
        required_table_entries = [Target.Required.TableName]
        _check_required_table_entries(target_table, required_table_entries)
        required_data = Target.Required.parse_from_table(target_table[Target.Required.TableName])
        if Target.OptimizedFor.TableName in target_table:
            optimized_for_data = Target.OptimizedFor.parse_from_table(target_table[Target.OptimizedFor.TableName])
        else:
            optimized_for_data = Target.OptimizedFor()
        return Target(required=required_data, optimized_for=optimized_for_data)
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
@dataclass
class LibraryReference:
    """A reference to a library by name, version and file."""

    name: str = ""
    version: str = ""
    target_file: str = ""

    def to_table(self):
        """Serialize as a tomlkit inline table with name, version and target_file keys."""
        inline = tomlkit.inline_table()
        for attr_name in ("name", "version", "target_file"):
            inline.append(attr_name, getattr(self, attr_name))
        return inline

    @staticmethod
    def parse_from_table(table):
        """Create a LibraryReference from a table holding name, version and target_file.

        A missing key raises from the table lookup.
        """
        values_by_key = {key: table[key] for key in ("name", "version", "target_file")}
        return LibraryReference(**values_by_key)
|
2021-03-16 21:01:10 +03:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Dependencies(AuxiliarySupportedTable):
    """The ``[dependencies]`` table: link target, files to deploy and dynamic libraries."""

    TableName = "dependencies"

    link_target: str = ""
    deploy_files: list = field(default_factory=list)
    # Entries are serialized via their own to_table() (LibraryReference when parsed)
    dynamic: list = field(default_factory=list)

    def to_table(self):
        """Serialize to a tomlkit table; each dynamic entry becomes one element of ``dynamic``."""
        dynamic_entries = tomlkit.array()
        for library in self.dynamic:
            dynamic_entries.append(library.to_table())

        table = tomlkit.table()
        table.add("link_target", self.link_target)
        table.add("deploy_files", self.deploy_files)
        table.add("dynamic", dynamic_entries)

        self.add_auxiliary_table(table)
        return table

    @staticmethod
    def parse_from_table(dependencies_table):
        """Create a Dependencies from a parsed ``[dependencies]`` table.

        Raises:
            ValueError: if link_target, deploy_files or dynamic is missing.
        """
        _check_required_table_entries(dependencies_table, ["link_target", "deploy_files", "dynamic"])

        libraries = [LibraryReference.parse_from_table(entry) for entry in dependencies_table["dynamic"]]

        return Dependencies(
            link_target=dependencies_table["link_target"],
            deploy_files=dependencies_table["deploy_files"],
            dynamic=libraries,
            auxiliary=AuxiliarySupportedTable.parse_auxiliary(dependencies_table)
        )
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
|
|
|
|
@dataclass
class CompiledWith:
    """The ``[compiled_with]`` table: compiler, flags, C runtime and libraries used for the build."""

    TableName = "compiled_with"

    compiler: str = ""
    flags: str = ""
    crt: str = ""
    # Entries are serialized via their own to_table() (LibraryReference when parsed)
    libraries: list = field(default_factory=list)

    def to_table(self):
        """Serialize to a tomlkit table; each library becomes one element of ``libraries``."""
        table = tomlkit.table()
        for key in ("compiler", "flags", "crt"):
            table.add(key, getattr(self, key))

        library_entries = tomlkit.array()
        for library in self.libraries:
            library_entries.append(library.to_table())
        table.add("libraries", library_entries)

        return table

    @staticmethod
    def parse_from_table(compiled_with_table):
        """Create a CompiledWith from a parsed ``[compiled_with]`` table.

        Raises:
            ValueError: if compiler, flags, crt or libraries is missing.
        """
        _check_required_table_entries(compiled_with_table, ["compiler", "flags", "crt", "libraries"])

        parsed_libraries = [
            LibraryReference.parse_from_table(entry) for entry in compiled_with_table["libraries"]
        ]

        return CompiledWith(
            compiler=compiled_with_table["compiler"],
            flags=compiled_with_table["flags"],
            crt=compiled_with_table["crt"],
            libraries=parsed_libraries
        )
|
2021-03-16 21:01:10 +03:00
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
@dataclass
class Declaration:
    """The ``[declaration]`` table: source code declaring the package's functions."""

    TableName = "declaration"

    code: str = ""

    def to_table(self):
        """Serialize ``code`` as a multi-line TOML string.

        A newline is prepended to non-empty code that does not already start with
        one. Presumably this keeps the code starting on its own line after the
        opening quotes.
        """
        text = self.code
        if len(text) > 0 and not text.startswith("\n"):
            text = "\n" + text

        table = tomlkit.table()
        table.add("code", tomlkit.string(text, multiline=True))
        return table

    @staticmethod
    def parse_from_table(declaration_table):
        """Create a Declaration from a parsed ``[declaration]`` table.

        Raises:
            ValueError: if ``code`` is missing.
        """
        _check_required_table_entries(declaration_table, ["code"])
        return Declaration(code=declaration_table["code"])
|
|
|
|
|
2022-03-15 09:20:41 +03:00
|
|
|
|
2021-03-16 21:01:10 +03:00
|
|
|
@dataclass
class HATFile:
    """Encapsulates a HAT file. An instance of this class can be created by calling the
    Deserialize static method e.g.:
        some_hat_file = HATFile.Deserialize('someFile.hat')
    Similarly, HAT files can be serialized by creating/modifying a HATFile instance
    and then calling Serialize e.g.:
        some_hat_file.name = 'some new name'
        some_hat_file.Serialize('someFile.hat')
    """
    name: str = ""
    # Sub-objects use default_factory so that instances never share mutable state.
    # Unhashable dataclass instances are also rejected as plain defaults on Python 3.11+.
    description: Description = field(default_factory=Description)
    _function_table: FunctionTable = field(default_factory=lambda: FunctionTable({}))
    _device_function_table: DeviceFunctionTable = field(default_factory=lambda: DeviceFunctionTable({}))
    # The four fields below are shadowed by the properties defined later in the
    # class body. The dataclass therefore takes each property object as the field
    # default, and the setters ignore that value. Their data actually lives in
    # _function_table / _device_function_table.
    functions: list = field(default_factory=list)
    device_functions: list = field(default_factory=list)
    function_map: Dict[str, Function] = field(default_factory=dict)
    device_function_map: Dict[str, Function] = field(default_factory=dict)
    target: Target = field(default_factory=Target)
    dependencies: Dependencies = field(default_factory=Dependencies)
    compiled_with: CompiledWith = field(default_factory=CompiledWith)
    declaration: Declaration = field(default_factory=Declaration)
    path: Path = None

    HATPrologue = "\n#ifndef __{0}__\n#define __{0}__\n\n#ifdef TOML\n"
    HATEpilogue = "\n#endif // TOML\n\n#endif // __{0}__\n"

    def __post_init__(self):
        """Point each function back at this HATFile and resolve its link target.

        The link target is relative to the directory containing the HAT file, so
        it is only resolved when ``path`` is known. This allows an in-memory
        HATFile to be built from functions without a path.
        """
        link_dir = Path(self.path).resolve().parent if self.path is not None else None
        for func in self.functions:
            func.hat_file = self
            if link_dir is not None:
                func.link_target = link_dir / self.dependencies.link_target

    @property
    def functions(self):
        """View of the host functions (values of the function table's map)."""
        return self._function_table.function_map.values()

    @functions.setter
    def functions(self, func_list_or_dict):
        """Replace the host functions from a name->Function dict or an iterable of Functions."""
        # The dataclass __init__ passes the property object itself as the default.
        if isinstance(func_list_or_dict, property):
            return
        if isinstance(func_list_or_dict, dict):
            name_to_func_map = func_list_or_dict
        else:
            name_to_func_map = {func.name: func for func in func_list_or_dict}
        if self._function_table is None:
            self._function_table = FunctionTable({})
        self._function_table.function_map = name_to_func_map

    @property
    def function_map(self):
        """Mapping of host function name to Function."""
        return self._function_table.function_map

    @function_map.setter
    def function_map(self, func_map):
        # The dataclass __init__ passes the property object itself as the default.
        if isinstance(func_map, property):
            return
        self._function_table.function_map = func_map

    @property
    def device_functions(self):
        """View of the device functions (values of the device function table's map)."""
        return self._device_function_table.function_map.values()

    @device_functions.setter
    def device_functions(self, func_list_or_dict):
        """Replace the device functions from a name->Function dict or an iterable of Functions."""
        # The dataclass __init__ passes the property object itself as the default.
        if isinstance(func_list_or_dict, property):
            return
        if isinstance(func_list_or_dict, dict):
            name_to_func_map = func_list_or_dict
        else:
            name_to_func_map = {func.name: func for func in func_list_or_dict}
        if self._device_function_table is None:
            self._device_function_table = DeviceFunctionTable({})
        self._device_function_table.function_map = name_to_func_map

    @property
    def device_function_map(self):
        """Mapping of device function name to Function."""
        return self._device_function_table.function_map

    @device_function_map.setter
    def device_function_map(self, func_map):
        # The dataclass __init__ passes the property object itself as the default.
        if isinstance(func_map, property):
            return
        self._device_function_table.function_map = func_map

    def Serialize(self, filepath=None):
        """Serializes the HATFile to disk using the file location specified by `filepath`.
        If `filepath` is not specified then the object's `path` attribute is used."""
        if filepath is None:
            filepath = self.path
        root_table = tomlkit.table()
        root_table.add(Description.TableName, self.description.to_table())
        root_table.add(FunctionTable.TableName, self._function_table.to_table())
        # The device function table is only written when it has entries.
        if self.device_function_map:
            root_table.add(DeviceFunctionTable.TableName, self._device_function_table.to_table())
        root_table.add(Target.TableName, self.target.to_table())
        root_table.add(Dependencies.TableName, self.dependencies.to_table())
        root_table.add(CompiledWith.TableName, self.compiled_with.to_table())
        root_table.add(Declaration.TableName, self.declaration.to_table())
        # The TOML is wrapped in an include guard and an `#ifdef TOML` block,
        # presumably so the .hat file can also be consumed as a C/C++ header.
        with open(filepath, "w") as out_file:
            # MSVC does not allow "." in macro definitions
            name = self.name.replace(".", "_")
            out_file.write(self.HATPrologue.format(name))
            out_file.write(tomlkit.dumps(root_table))
            out_file.write(self.HATEpilogue.format(name))

    @staticmethod
    def Deserialize(filepath) -> "HATFile":
        """Creates a HATFile instance by deserializing the contents of the file at `filepath`.

        The HATFile's name is the file name without its extension.

        Raises:
            ValueError: if a required top-level table is missing.
        """
        hat_toml = _read_toml_file(filepath)
        name = os.path.splitext(os.path.basename(filepath))[0]
        required_entries = [
            Description.TableName, FunctionTable.TableName, Target.TableName, Dependencies.TableName,
            CompiledWith.TableName, Declaration.TableName
        ]
        _check_required_table_entries(hat_toml, required_entries)
        # The device function table is optional.
        device_function_table = DeviceFunctionTable({})
        if DeviceFunctionTable.TableName in hat_toml:
            device_function_table = DeviceFunctionTable.parse_from_table(hat_toml[DeviceFunctionTable.TableName])
        hat_file = HATFile(
            name=name,
            description=Description.parse_from_table(hat_toml[Description.TableName]),
            _function_table=FunctionTable.parse_from_table(hat_toml[FunctionTable.TableName]),
            _device_function_table=device_function_table,
            target=Target.parse_from_table(hat_toml[Target.TableName]),
            dependencies=Dependencies.parse_from_table(hat_toml[Dependencies.TableName]),
            compiled_with=CompiledWith.parse_from_table(hat_toml[CompiledWith.TableName]),
            declaration=Declaration.parse_from_table(hat_toml[Declaration.TableName]),
            path=Path(filepath).resolve()
        )
        return hat_file
|