Merged PR 4653: Added simple mapping for C++ to C# types. Added the includes for the C++ wrapper.

Added a mapping from some C++ to C# types. This list will have to be updated/modified after we confirm the types we support. It will also need to be modified when we support structs.

Prior, the C++ wrapper didn't include any of the `#includes` from the C++ files. Now, it parses the includes list and includes them as well.

NOTE: The includes list does not say whether it was `#includes "file"` or `#includes <file>` and after talking with @<Teo Magnino Chaban>, determining that is not trivial. According to the [C standard](http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf#page=182), section 6.10.2 the only real difference is where the compiler looks initially for the file. If it is `#includes "file"` the compiler will first look somewhere (specific to each compiler, but usually in that directory) and then if it can't find the file will reprocess that line as if it were `#includes <file>`. Due to this, I have decided to do all includes as `#includes "file"`. The only potential problem I see with this is if there is a file in the local directory with the exact same name as in the system directory. Open to discussion on this point.
This commit is contained in:
Michael Sharp 2019-07-03 21:23:18 +00:00
Родитель 5a2026eba4
Коммит 80e38d5d56
5 изменённых файлов: 431 добавлений и 47 удалений

Просмотреть файл

@ -113,19 +113,21 @@ class Plugin(PluginBase):
if dm.result != 0:
return dm.result
current_filename_index = 0
for value in plugin_context.values():
function_list = value["function_list"]
for function, file_name in zip(function_list, cs_filenames):
status_stream.write("'{}'...".format(file_name))
for function in function_list:
status_stream.write("'{}'...".format(cs_filenames[current_filename_index]))
with status_stream.DoneManager() as dm:
dm.result = CreateCsFile(
function,
file_name,
cs_filenames[current_filename_index],
context["output_name"],
cls._GenerateFileHeader,
)
if dm.result != 0:
return dm.result
current_filename_index += 1
status_stream.write("'{}'...".format(cmake_filename))
with status_stream.DoneManager() as dm:

Просмотреть файл

@ -20,9 +20,11 @@ def CreateCppWrapper(output_filename, data, generate_header_func):
output = [generate_header_func("// ")]
output.append(_GeneratePreProcessorCommands())
for key, value in data.items():
function_list = value["function_list"]
for function in function_list:
for value in data.values():
for include in value["include_list"]:
output.insert(1, "#include \"{}\"\n".format(include.split(os.sep)[-1]))
for function in value["function_list"]:
function_name = function["func_name"]
var_names = function["var_names"]

Просмотреть файл

@ -44,8 +44,8 @@ def GenerateMLNetWrapper(function, output_name):
function_name = function["func_name"]
variable_names = function["var_names"]
variable_types = function["simple_var_types"]
return_type = function["simple_return_type"]
variable_types = list(_CppToCSharpVariableMapping(function["simple_var_types"]))
return_type = next(_CppToCSharpVariableMapping([function["simple_return_type"]]))
extension_class_name = "{0}ExtensionClass".format(function_name)
estimator_class_name = "{0}Estimator".format(function_name)
@ -72,6 +72,39 @@ def GenerateMLNetWrapper(function, output_name):
return "".join(code)
def _CppToCSharpVariableMapping(variable_types):
"""
Takes a list of C++ variable types and returns an equal size list of
C# variable types
"""
# Initialize the dictionary of type mappings
mapping = {
"std::int8_t" : "sbyte",
"int8_t" : "sbyte",
"std::int16_t" : "short",
"int16_t" : "short",
"std::int32_t" : "int",
"int32_t" : "int",
"std::int64_t" : "long",
"int64_t" : "long",
"std::uint8_t" : "byte",
"uint8_t" : "byte",
"std::uint16_t" : "ushort",
"uint16_t" : "ushort",
"std::uint32_t" : "uint",
"uint32_t" : "uint",
"std::uint64_t" : "ulong",
"uint64_t" : "ulong",
"int" : "int",
"float" : "float",
"double" : "double",
"char" : "char",
"bool" : "bool"
}
for variable_type in variable_types:
yield mapping[variable_type]
def _FormatAsCSharpAndWriteFile(output_filename, code):
"""
Does some simple C# Formatting

Просмотреть файл

@ -34,7 +34,7 @@ class StandardSuite(unittest.TestCase):
mocked.return_value = sink
function_input = self.CreateSingleFunctionInput()
function_input = self._CreateSingleFunctionInput()
result = CreateCppWrapper.CreateCppWrapper("ignored", function_input, lambda prefix: "{}The file header!\n".format(prefix))
sink = sink.getvalue()
@ -71,7 +71,7 @@ class StandardSuite(unittest.TestCase):
mocked.return_value = sink
function_input = self.CreateTwoFunctionInput()
function_input = self._CreateTwoFunctionInput()
result = CreateCppWrapper.CreateCppWrapper("ignored", function_input, lambda prefix: "{}The file header!\n".format(prefix))
sink = sink.getvalue()
@ -115,7 +115,7 @@ class StandardSuite(unittest.TestCase):
mocked.return_value = sink
function_input = self.CreateFunctionOneInputParameter()
function_input = self._CreateFunctionOneInputParameter()
result = CreateCppWrapper.CreateCppWrapper("ignored", function_input, lambda prefix: "{}The file header!\n".format(prefix))
sink = sink.getvalue()
@ -152,7 +152,7 @@ class StandardSuite(unittest.TestCase):
mocked.return_value = sink
function_input = self.CreateFunctionNoInputParameters()
function_input = self._CreateFunctionNoInputParameters()
result = CreateCppWrapper.CreateCppWrapper("ignored", function_input, lambda prefix: "{}The file header!\n".format(prefix))
sink = sink.getvalue()
@ -182,65 +182,112 @@ class StandardSuite(unittest.TestCase):
),
)
def test_FunctionWithMultipleTypesAndIncludesParameters(self):
with unittest.mock.patch("MlNetPluginImpl.CreateCppWrapper.open") as mocked:
sink = six.moves.StringIO()
sink.close = lambda: None
mocked.return_value = sink
function_input = self._CreateMultipleTypesWithIncludesFunctionInput()
result = CreateCppWrapper.CreateCppWrapper("ignored", function_input, lambda prefix: "{}The file header!\n".format(prefix))
sink = sink.getvalue()
self.assertEqual(result, 0)
self.assertEqual(
sink,
textwrap.dedent(
"""\
// The file header!
#include "cstdint"
#if defined(_MSC_VER)
# define EXPORT __declspec(dllexport)
#elif defined(__GNUC__)
# define EXPORT __attribute__((visibility("default")))
#else
# error unsupported!
#endif
std::int64_t Add(std::uint32_t, double, char);
extern "C" {
EXPORT std::int64_t AddProxy(std::uint32_t a, double b, char c) {
return Add(a, b, c);
}
}
"""
),
)
"""
TESTING HELPER METHODS
"""
def CreateFunctionOneInputParameter(self):
result = {}
function_list = []
intermediate_object = {}
function = self.CreateFunction("Add",
def _CreateFunctionOneInputParameter(self):
function = self._CreateFunction("Add",
["a"],
"int",
["int"],
["int"],
"int"
)
function_list.append(function)
intermediate_object["function_list"] = function_list
result["file_name"] = intermediate_object
function_list = [function]
intermediate_object = {
"function_list" : function_list,
"include_list" : []
}
result = {
"file_name" : intermediate_object
}
return result
def CreateFunctionNoInputParameters(self):
result = {}
function_list = []
intermediate_object = {}
function = self.CreateFunction("Add",
def _CreateFunctionNoInputParameters(self):
function = self._CreateFunction("Add",
[],
"int",
[],
[],
"int"
)
function_list.append(function)
intermediate_object["function_list"] = function_list
result["file_name"] = intermediate_object
function_list = [function]
intermediate_object = {
"function_list" : function_list,
"include_list" : []
}
result = {
"file_name" : intermediate_object
}
return result
def CreateSingleFunctionInput(self):
result = {}
function_list = []
intermediate_object = {}
function = self.CreateFunction("Add",
def _CreateSingleFunctionInput(self):
function = self._CreateFunction("Add",
["a", "b"],
"int",
["int", "int"],
["int", "int"],
"int"
)
function_list.append(function)
intermediate_object["function_list"] = function_list
result["file_name"] = intermediate_object
function_list = [function]
intermediate_object = {
"function_list" : function_list,
"include_list" : []
}
result = {
"file_name" : intermediate_object
}
return result
def CreateTwoFunctionInput(self):
result = {}
function_list = []
intermediate_object = {}
first_function = self.CreateFunction("Add",
def _CreateTwoFunctionInput(self):
first_function = self._CreateFunction("Add",
["a", "b"],
"int",
["int", "int"],
@ -248,7 +295,7 @@ class StandardSuite(unittest.TestCase):
"int"
)
second_function = self.CreateFunction("Subtract",
second_function = self._CreateFunction("Subtract",
["a", "b"],
"int",
["int", "int"],
@ -256,14 +303,41 @@ class StandardSuite(unittest.TestCase):
"int"
)
function_list.append(first_function)
function_list.append(second_function)
intermediate_object["function_list"] = function_list
result["file_name"] = intermediate_object
function_list = [first_function, second_function]
intermediate_object = {
"function_list" : function_list,
"include_list" : []
}
result = {
"file_name" : intermediate_object
}
return result
def CreateFunction(self, func_name, var_names, simple_return_type, raw_var_type, simple_var_types, raw_return_type):
def _CreateMultipleTypesWithIncludesFunctionInput(self):
function = self._CreateFunction("Add",
["a", "b", "c"],
"std::int64_t",
["std::uint32_t", "double", "char"],
["std::uint32_t", "double", "char"],
"std::int64_t"
)
function_list = [function]
intermediate_object = {
"function_list" : function_list,
"include_list" : ["{0}some{0}file{0}path{0}cstdint".format(os.sep)]
}
result = {
"file_name" : intermediate_object
}
return result
def _CreateFunction(self, func_name, var_names, simple_return_type, raw_var_type, simple_var_types, raw_return_type):
function = {}
function["func_name"] = func_name
function["var_names"] = var_names

Просмотреть файл

@ -684,6 +684,268 @@ class StandardSuite(unittest.TestCase):
),
)
def test_FullDifferentTypesParametersFile(self):
with unittest.mock.patch("MlNetPluginImpl.CreateCsFile.open") as mocked:
sink = six.moves.StringIO()
sink.close = lambda: None
mocked.return_value = sink
function = self._CreateFunctionDifferentTypesInputParameters()
result = CreateCsFile.CreateCsFile(function, "ignored", "Add", lambda prefix: "{}The file header!\n".format(prefix))
sink = sink.getvalue()
self.assertEqual(result, 0)
self.assertEqual(
textwrap.dedent(sink),
textwrap.dedent(
"""\
// The file header!
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using System;
using System.Linq;
using System.Runtime.InteropServices;
namespace Microsoft.ML.Autogen
{
public static class AddExtensionClass
{
public class TransformParameter<T>
{
private readonly T _rawValue;
public readonly DataViewSchema.Column? Column;
public T GetValue(DataViewRow row)
{
if (Column.HasValue)
{
var column = row.Schema[Column.Value.Name];
var getter = row.GetGetter<T>(column);
T value = default;
getter(ref value);
return value;
}
else
{
return _rawValue;
}
}
public TransformParameter(T value)
{
_rawValue = value;
Column = null;
}
public TransformParameter(DataViewSchema.Column column)
{
_rawValue = default;
Column = column;
}
}
public static AddEstimator Add(this TransformsCatalog catalog, uint a, double b, char c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
public static AddEstimator Add(this TransformsCatalog catalog, uint a, double b, DataViewSchema.Column c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
public static AddEstimator Add(this TransformsCatalog catalog, uint a, DataViewSchema.Column b, char c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
public static AddEstimator Add(this TransformsCatalog catalog, uint a, DataViewSchema.Column b, DataViewSchema.Column c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
public static AddEstimator Add(this TransformsCatalog catalog, DataViewSchema.Column a, double b, char c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
public static AddEstimator Add(this TransformsCatalog catalog, DataViewSchema.Column a, double b, DataViewSchema.Column c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
public static AddEstimator Add(this TransformsCatalog catalog, DataViewSchema.Column a, DataViewSchema.Column b, char c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
public static AddEstimator Add(this TransformsCatalog catalog, DataViewSchema.Column a, DataViewSchema.Column b, DataViewSchema.Column c, string outputColumn)
=> AddEstimator.Create(CatalogUtils.GetEnvironment(catalog), a, b, c, outputColumn);
}
public class AddEstimator : IEstimator<AddTransformer>
{
private AddExtensionClass.TransformParameter<uint> _a = default;
private AddExtensionClass.TransformParameter<double> _b = default;
private AddExtensionClass.TransformParameter<char> _c = default;
private string _outputColumn = default;
private readonly IHost _host;
private static Type GetParameterType(object obj)
{
if (obj.GetType() == typeof(DataViewSchema.Column))
{
return ((DataViewSchema.Column)obj).Type.RawType;
}
else
{
return obj.GetType();
}
}
private static Type GetParameterClassType(object obj)
{
var type = GetParameterType(obj);
return typeof(AddExtensionClass.TransformParameter<>).MakeGenericType(type);
}
public static AddEstimator Create(IHostEnvironment env, object a, object b, object c, string outputColumn)
{
var aParam = new object[] { a };
var aType = GetParameterClassType(a);
var aInstance = Activator.CreateInstance(aType, aParam);
var bParam = new object[] { b };
var bType = GetParameterClassType(b);
var bInstance = Activator.CreateInstance(bType, bParam);
var cParam = new object[] { c };
var cType = GetParameterClassType(c);
var cInstance = Activator.CreateInstance(cType, cParam);
var param = new object[] { env, aInstance, bInstance, cInstance, outputColumn };
var estimator = Activator.CreateInstance(typeof(AddEstimator), param);
return (AddEstimator)estimator;
}
public AddEstimator(IHostEnvironment env, AddExtensionClass.TransformParameter<uint> a, AddExtensionClass.TransformParameter<double> b, AddExtensionClass.TransformParameter<char> c, string outputColumn)
{
_a = a;
_b = b;
_c = c;
_outputColumn = outputColumn;
_host = env.Register(nameof(AddEstimator));
}
public AddTransformer Fit(IDataView input)
{
return new AddTransformer(_host, _a, _b, _c, _outputColumn);
}
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
var columns = inputSchema.ToDictionary(x => x.Name);
SchemaShape.Column col;
col = new SchemaShape.Column(_outputColumn, SchemaShape.Column.VectorKind.Scalar,
ColumnTypeExtensions.PrimitiveTypeFromType(typeof(long)), false, null);
columns[_outputColumn] = col;
return new SchemaShape(columns.Values);
}
}
public class AddTransformer : RowToRowTransformerBase
{
private AddExtensionClass.TransformParameter<uint> _a = default;
private AddExtensionClass.TransformParameter<double> _b = default;
private AddExtensionClass.TransformParameter<char> _c = default;
private string _outputColumn = default;
public AddTransformer(IHost host, AddExtensionClass.TransformParameter<uint> a, AddExtensionClass.TransformParameter<double> b, AddExtensionClass.TransformParameter<char> c, string outputColumn) :
base(host.Register(nameof(AddTransformer)))
{
_a = a;
_b = b;
_c = c;
_outputColumn = outputColumn;
}
protected class Mapper : MapperBase
{
private readonly AddTransformer _parent;
[DllImport("Add.dll", EntryPoint = "AddProxy")]
extern static long Add(uint a, double b, char c);
public Mapper(AddTransformer parent, DataViewSchema inputSchema) :
base(parent.Host.Register(nameof(Mapper)), inputSchema, parent)
{
_parent = parent;
}
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var output = new DataViewSchema.DetachedColumn(_parent._outputColumn,
ColumnTypeExtensions.PrimitiveTypeFromType(typeof(long)));
return new DataViewSchema.DetachedColumn[] { output };
}
private Delegate MakeGetter(DataViewRow input, int iinfo)
{
ValueGetter<long> result = (ref long dst) =>
{
var aVal = _parent._a.GetValue(input);
var bVal = _parent._b.GetValue(input);
var cVal = _parent._c.GetValue(input);
dst = Add(aVal, bVal, cVal);
};
return result;
}
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
disposer = null;
return MakeGetter(input, iinfo);
}
protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
var active = new bool[InputSchema.Count];
for (int i = 0; i < InputSchema.Count; i++)
{
if (_parent._a.Column.HasValue && InputSchema[i].Name.Equals(_parent._a.Column.Value.Name))
{
active[i] = true;
}
if (_parent._b.Column.HasValue && InputSchema[i].Name.Equals(_parent._b.Column.Value.Name))
{
active[i] = true;
}
if (_parent._c.Column.HasValue && InputSchema[i].Name.Equals(_parent._c.Column.Value.Name))
{
active[i] = true;
}
}
return col => active[col];
}
protected override void SaveModel(ModelSaveContext ctx)
{
throw new NotImplementedException();
}
}
protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
protected override void SaveModel(ModelSaveContext ctx)
{
throw new NotImplementedException();
}
}
}
""",
),
)
"""
TESTING HELPER METHODS
"""
@ -728,6 +990,17 @@ class StandardSuite(unittest.TestCase):
return function
def _CreateFunctionDifferentTypesInputParameters(self):
function = self.CreateFunction("Add",
["a", "b", "c"],
"std::int64_t",
["std::uint32_t", "double", "char"],
["std::uint32_t", "double", "char"],
"std::int64_t"
)
return function
def CreateFunction(self, func_name, var_names, simple_return_type, raw_var_type, simple_var_types, raw_return_type):
function = {}
function["func_name"] = func_name