зеркало из https://github.com/microsoft/hat.git
Add runtime output array sample and deserialization test (#60)
* removed duplicate test * sample hat file for dynamic output arrays * verify parsing of range sample Co-authored-by: Lisa Ong <onglisa@microsoft.com>
This commit is contained in:
Родитель
004632c406
Коммит
fa6d727cf6
|
@ -0,0 +1,108 @@
|
|||
#ifndef __Range_library__
|
||||
#define __Range_library__
|
||||
|
||||
#ifdef TOML
|
||||
|
||||
[description]
|
||||
comment = "John Doe's Range Library"
|
||||
author = "John Doe"
|
||||
version = "1.2.3.5"
|
||||
license_url = "https://www.apache.org/licenses/LICENSE-2.0.html"
|
||||
|
||||
[functions]
|
||||
[functions.Range_0911ac6519e78bff5590e40539aee0cf]
|
||||
name = "Range_0911ac6519e78bff5590e40539aee0cf"
|
||||
description = "CPU Implementation of the Range algorithm, based on https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range"
|
||||
calling_convention = "stdcall"
|
||||
arguments = [
|
||||
{name = "start", description = "First entry for the range of output values", logical_type = "affine_array", declared_type = "int32_t*", element_type = "int32_t", usage = "input", shape = [], affine_map = [], affine_offset = 0},
|
||||
{name = "limit", description = "Exclusive upper limit for the range of output values", logical_type = "affine_array", declared_type = "int32_t*", element_type = "int32_t", usage = "input", shape = [], affine_map = [], affine_offset = 0},
|
||||
{name = "delta", description = "Value to step by", logical_type = "affine_array", declared_type = "int32_t*", element_type = "int32_t", usage = "input", shape = [], affine_map = [], affine_offset = 0},
|
||||
{name = "output", description = "A 1-D array with same type as the inputs containing generated range of values", logical_type = "runtime_array", declared_type = "int32_t**", element_type = "int32_t", usage = "output", size = "output_dim"},
|
||||
{name = "output_dim", description = "Number of elements of the output array", logical_type = "element", declared_type = "uint32_t*", element_type = "uint32_t", usage = "output"}
|
||||
]
|
||||
|
||||
return = {name = "", description = "", logical_type = "void", declared_type = "void", element_type = "void", usage = "output"}
|
||||
|
||||
[[functions.Range_0911ac6519e78bff5590e40539aee0cf.auxiliary.onnx]]
|
||||
op_type = "Range"
|
||||
input_shapes = [[], [], []]
|
||||
output_shapes = [["*"]]
|
||||
|
||||
[target]
|
||||
[target.required]
|
||||
os = "linux"
|
||||
|
||||
[target.required.CPU]
|
||||
architecture = ""
|
||||
extensions = []
|
||||
|
||||
[target.optimized_for]
|
||||
os = "linux"
|
||||
|
||||
[target.optimized_for.CPU]
|
||||
architecture = ""
|
||||
extensions = []
|
||||
|
||||
# The dependencies table provides information that a consumer of this .hat file
|
||||
# will need to act on in order to properly consume the package, such as library
|
||||
# files to link and dynamic libraries to make available at runtime
|
||||
[dependencies]
|
||||
link_target = "Range_model_s.a"
|
||||
deploy_files = []
|
||||
dynamic = []
|
||||
|
||||
[dependencies.auxiliary]
|
||||
dynamic = "Range_model_d.so"
|
||||
static = "Range_model_s.a"
|
||||
|
||||
# The compiled_with table provides information that a consumer of this .hat file
|
||||
# may find useful but may not necessarily need to act on in order to successfully
|
||||
# consume this package
|
||||
[compiled_with]
|
||||
compiler = ""
|
||||
flags = ""
|
||||
crt = ""
|
||||
libraries = []
|
||||
|
||||
[declaration]
|
||||
code = """
|
||||
#endif // TOML
|
||||
//
|
||||
// Header for Range library
|
||||
//
|
||||
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C"
|
||||
{
|
||||
#endif // defined(__cplusplus)
|
||||
|
||||
//
|
||||
// Functions
|
||||
//
|
||||
|
||||
void Range_0911ac6519e78bff5590e40539aee0cf(const int32_t start[1], const int32_t limit[1], const int32_t delta[1], int32_t** output, uint32_t* output_dim );
|
||||
|
||||
#ifndef __Range_0911ac6519e78bff5590e40539aee0cf_DEFINED__
|
||||
#define __Range_0911ac6519e78bff5590e40539aee0cf_DEFINED__
|
||||
void (*Range_0911ac6519e78bff5590e40539aee0cf)(int32_t*, int32_t*, int32_t*, int32_t**, uint32_t*) = Range_0911ac6519e78bff5590e40539aee0cf;
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(__cplusplus)
|
||||
} // extern "C"
|
||||
#endif // defined(__cplusplus)
|
||||
|
||||
#ifdef TOML
|
||||
|
||||
"""
|
||||
|
||||
#endif // TOML
|
||||
|
||||
#endif // __Range_library__
|
|
@ -12,7 +12,7 @@ from hatlib import (
|
|||
|
||||
class HATFile_test(unittest.TestCase):
|
||||
|
||||
def test_file_basic_serialize(self):
|
||||
def test_basic_serialize(self):
|
||||
# Construct a HAT file from scratch
|
||||
# Start with a function definition
|
||||
my_function = Function(
|
||||
|
@ -69,7 +69,7 @@ class HATFile_test(unittest.TestCase):
|
|||
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
|
||||
self.assertTrue("my_function" in hat_file2.function_map)
|
||||
|
||||
def test_file_basic_deserialize(self):
|
||||
def test_sample_gemm_deserialize(self):
|
||||
# Load a HAT file from the samples directory
|
||||
hat_file1 = HATFile.Deserialize(
|
||||
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
|
||||
|
@ -88,6 +88,26 @@ class HATFile_test(unittest.TestCase):
|
|||
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
|
||||
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
|
||||
|
||||
def test_sample_range_deserialize(self):
|
||||
hat_file1 = HATFile.Deserialize(
|
||||
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_range_library.hat")
|
||||
)
|
||||
|
||||
function1 = hat_file1.function_map["Range_0911ac6519e78bff5590e40539aee0cf"]
|
||||
output_arg = function1.arguments[-2]
|
||||
output_dim_arg = function1.arguments[-1]
|
||||
|
||||
self.assertEqual(output_dim_arg.logical_type, ParameterType.Element)
|
||||
self.assertEqual(output_dim_arg.declared_type, "uint32_t*")
|
||||
self.assertEqual(output_dim_arg.element_type, "uint32_t")
|
||||
self.assertEqual(output_arg.usage, UsageType.Output)
|
||||
|
||||
self.assertEqual(output_arg.logical_type, ParameterType.RuntimeArray)
|
||||
self.assertEqual(output_arg.declared_type, "int32_t**")
|
||||
self.assertEqual(output_arg.element_type, "int32_t")
|
||||
self.assertEqual(output_arg.usage, UsageType.Output)
|
||||
self.assertEqual(output_arg.size, output_dim_arg.name) # references the output_dim arg for the size
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from hatlib import (
|
||||
CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile,
|
||||
OperatingSystem, Parameter, ParameterType, Target, UsageType
|
||||
)
|
||||
|
||||
|
||||
class HATFile_test(unittest.TestCase):
|
||||
|
||||
def test_file_basic_serialize(self):
|
||||
# Construct a HAT file from scratch
|
||||
# Start with a function definition
|
||||
my_function = Function(
|
||||
name="my_function",
|
||||
description="Some description",
|
||||
calling_convention=CallingConventionType.StdCall,
|
||||
return_info=Parameter(
|
||||
logical_type=ParameterType.RuntimeArray,
|
||||
declared_type="float*",
|
||||
element_type="float",
|
||||
usage=UsageType.Input,
|
||||
shape="[16, 16]",
|
||||
affine_map=[16, 1],
|
||||
size="16 * 16 * sizeof(float)"
|
||||
)
|
||||
)
|
||||
# Create the function table
|
||||
functions = FunctionTable({"my_function": my_function})
|
||||
# Create the HATFile object
|
||||
hat_file1 = HATFile(
|
||||
name="test_file",
|
||||
description=Description(
|
||||
version="0.0.1", author="me", license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
|
||||
),
|
||||
_function_table=functions,
|
||||
target=Target(
|
||||
required=Target.Required(
|
||||
os=OperatingSystem.Windows,
|
||||
cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
|
||||
gpu=None
|
||||
),
|
||||
optimized_for=Target.OptimizedFor()
|
||||
),
|
||||
dependencies=Dependencies(link_target="my_lib.lib"),
|
||||
compiled_with=CompiledWith(compiler="VC++"),
|
||||
declaration=Declaration(),
|
||||
path=Path(".").resolve()
|
||||
)
|
||||
# Serialize it to disk
|
||||
test_file_name = "test_file_serialize.hat"
|
||||
|
||||
try:
|
||||
hat_file1.Serialize(test_file_name)
|
||||
# Deserialize it and verify it has what we expect
|
||||
hat_file2 = HATFile.Deserialize(test_file_name)
|
||||
finally:
|
||||
# Remove the file
|
||||
os.remove(test_file_name)
|
||||
|
||||
# Do basic verification that the deserialized HatFile contains what we
|
||||
# specified when we created the HATFile directly
|
||||
self.assertEqual(hat_file1.description, hat_file2.description)
|
||||
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
|
||||
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
|
||||
self.assertTrue("my_function" in hat_file2.function_map)
|
||||
|
||||
def test_file_basic_deserialize(self):
|
||||
# Load a HAT file from the samples directory
|
||||
hat_file1 = HATFile.Deserialize(
|
||||
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
|
||||
)
|
||||
description = {
|
||||
"author": "John Doe",
|
||||
"version": "1.2.3.5",
|
||||
"license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
|
||||
}
|
||||
|
||||
# Do basic verification of known values in the file
|
||||
# Verify the description has entries we expect
|
||||
self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
|
||||
# Verify the list of functions
|
||||
self.assertTrue(len(hat_file1.function_map) == 2)
|
||||
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
|
||||
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче