Add runtime output array sample and deserialization test (#60)

* removed duplicate test

* sample hat file for dynamic output arrays

* verify parsing of range sample

Co-authored-by: Lisa Ong <onglisa@microsoft.com>
This commit is contained in:
Lisa Ong 2022-08-08 11:53:53 +08:00 коммит произвёл GitHub
Родитель 004632c406
Коммит fa6d727cf6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 130 добавлений и 95 удалений

Просмотреть файл

@ -0,0 +1,108 @@
#ifndef __Range_library__
#define __Range_library__
#ifdef TOML
[description]
comment = "John Doe's Range Library"
author = "John Doe"
version = "1.2.3.5"
license_url = "https://www.apache.org/licenses/LICENSE-2.0.html"
[functions]
[functions.Range_0911ac6519e78bff5590e40539aee0cf]
name = "Range_0911ac6519e78bff5590e40539aee0cf"
description = "CPU Implementation of the Range algorithm, based on https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range"
calling_convention = "stdcall"
arguments = [
{name = "start", description = "First entry for the range of output values", logical_type = "affine_array", declared_type = "int32_t*", element_type = "int32_t", usage = "input", shape = [], affine_map = [], affine_offset = 0},
{name = "limit", description = "Exclusive upper limit for the range of output values", logical_type = "affine_array", declared_type = "int32_t*", element_type = "int32_t", usage = "input", shape = [], affine_map = [], affine_offset = 0},
{name = "delta", description = "Value to step by", logical_type = "affine_array", declared_type = "int32_t*", element_type = "int32_t", usage = "input", shape = [], affine_map = [], affine_offset = 0},
{name = "output", description = "A 1-D array with same type as the inputs containing generated range of values", logical_type = "runtime_array", declared_type = "int32_t**", element_type = "int32_t", usage = "output", size = "output_dim"},
{name = "output_dim", description = "Number of elements of the output array", logical_type = "element", declared_type = "uint32_t*", element_type = "uint32_t", usage = "output"}
]
return = {name = "", description = "", logical_type = "void", declared_type = "void", element_type = "void", usage = "output"}
[[functions.Range_0911ac6519e78bff5590e40539aee0cf.auxiliary.onnx]]
op_type = "Range"
input_shapes = [[], [], []]
output_shapes = [["*"]]
[target]
[target.required]
os = "linux"
[target.required.CPU]
architecture = ""
extensions = []
[target.optimized_for]
os = "linux"
[target.optimized_for.CPU]
architecture = ""
extensions = []
# The dependencies table provides information that a consumer of this .hat file
# will need to act on in order to properly consume the package, such as library
# files to link and dynamic libraries to make available at runtime
[dependencies]
link_target = "Range_model_s.a"
deploy_files = []
dynamic = []
[dependencies.auxiliary]
dynamic = "Range_model_d.so"
static = "Range_model_s.a"
# The compiled_with table provides information that a consumer of this .hat file
# may find useful but may not necessarily need to act on in order to successfully
# consume this package
[compiled_with]
compiler = ""
flags = ""
crt = ""
libraries = []
[declaration]
code = """
#endif // TOML
//
// Header for Range library
//
#include <float.h>
#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#if defined(__cplusplus)
extern "C"
{
#endif // defined(__cplusplus)
//
// Functions
//
void Range_0911ac6519e78bff5590e40539aee0cf(const int32_t start[1], const int32_t limit[1], const int32_t delta[1], int32_t** output, uint32_t* output_dim );
#ifndef __Range_0911ac6519e78bff5590e40539aee0cf_DEFINED__
#define __Range_0911ac6519e78bff5590e40539aee0cf_DEFINED__
void (*Range_0911ac6519e78bff5590e40539aee0cf)(int32_t*, int32_t*, int32_t*, int32_t**, uint32_t*) = Range_0911ac6519e78bff5590e40539aee0cf;
#endif
#if defined(__cplusplus)
} // extern "C"
#endif // defined(__cplusplus)
#ifdef TOML
"""
#endif // TOML
#endif // __Range_library__

Просмотреть файл

@ -12,7 +12,7 @@ from hatlib import (
class HATFile_test(unittest.TestCase):
def test_file_basic_serialize(self):
def test_basic_serialize(self):
# Construct a HAT file from scratch
# Start with a function definition
my_function = Function(
@ -69,7 +69,7 @@ class HATFile_test(unittest.TestCase):
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
def test_sample_gemm_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
@ -88,6 +88,26 @@ class HATFile_test(unittest.TestCase):
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
def test_sample_range_deserialize(self):
hat_file1 = HATFile.Deserialize(
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_range_library.hat")
)
function1 = hat_file1.function_map["Range_0911ac6519e78bff5590e40539aee0cf"]
output_arg = function1.arguments[-2]
output_dim_arg = function1.arguments[-1]
self.assertEqual(output_dim_arg.logical_type, ParameterType.Element)
self.assertEqual(output_dim_arg.declared_type, "uint32_t*")
self.assertEqual(output_dim_arg.element_type, "uint32_t")
self.assertEqual(output_arg.usage, UsageType.Output)
self.assertEqual(output_arg.logical_type, ParameterType.RuntimeArray)
self.assertEqual(output_arg.declared_type, "int32_t**")
self.assertEqual(output_arg.element_type, "int32_t")
self.assertEqual(output_arg.usage, UsageType.Output)
self.assertEqual(output_arg.size, output_dim_arg.name) # references the output_dim arg for the size
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -1,93 +0,0 @@
#!/usr/bin/env python3
import sys
import os
import unittest
from pathlib import Path
from hatlib import (
CallingConventionType, CompiledWith, Declaration, Dependencies, Description, Function, FunctionTable, HATFile,
OperatingSystem, Parameter, ParameterType, Target, UsageType
)
class HATFile_test(unittest.TestCase):
def test_file_basic_serialize(self):
# Construct a HAT file from scratch
# Start with a function definition
my_function = Function(
name="my_function",
description="Some description",
calling_convention=CallingConventionType.StdCall,
return_info=Parameter(
logical_type=ParameterType.RuntimeArray,
declared_type="float*",
element_type="float",
usage=UsageType.Input,
shape="[16, 16]",
affine_map=[16, 1],
size="16 * 16 * sizeof(float)"
)
)
# Create the function table
functions = FunctionTable({"my_function": my_function})
# Create the HATFile object
hat_file1 = HATFile(
name="test_file",
description=Description(
version="0.0.1", author="me", license_url="https://www.apache.org/licenses/LICENSE-2.0.html"
),
_function_table=functions,
target=Target(
required=Target.Required(
os=OperatingSystem.Windows,
cpu=Target.Required.CPU(architecture="Haswell", extensions=["AVX2"]),
gpu=None
),
optimized_for=Target.OptimizedFor()
),
dependencies=Dependencies(link_target="my_lib.lib"),
compiled_with=CompiledWith(compiler="VC++"),
declaration=Declaration(),
path=Path(".").resolve()
)
# Serialize it to disk
test_file_name = "test_file_serialize.hat"
try:
hat_file1.Serialize(test_file_name)
# Deserialize it and verify it has what we expect
hat_file2 = HATFile.Deserialize(test_file_name)
finally:
# Remove the file
os.remove(test_file_name)
# Do basic verification that the deserialized HatFile contains what we
# specified when we created the HATFile directly
self.assertEqual(hat_file1.description, hat_file2.description)
self.assertEqual(hat_file1.dependencies, hat_file2.dependencies)
self.assertEqual(hat_file1.compiled_with.to_table(), hat_file2.compiled_with.to_table())
self.assertTrue("my_function" in hat_file2.function_map)
def test_file_basic_deserialize(self):
# Load a HAT file from the samples directory
hat_file1 = HATFile.Deserialize(
os.path.join(os.path.dirname(__file__), "..", "samples", "sample_gemm_library.hat")
)
description = {
"author": "John Doe",
"version": "1.2.3.5",
"license_url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
# Do basic verification of known values in the file
# Verify the description has entries we expect
self.assertLessEqual(description.items(), hat_file1.description.to_table().items())
# Verify the list of functions
self.assertTrue(len(hat_file1.function_map) == 2)
self.assertTrue("GEMM_B94D27B9934D3E08" in hat_file1.function_map)
self.assertTrue("blas_sgemm_row_major" in hat_file1.function_map)
if __name__ == '__main__':
unittest.main()