Allow use of property-based row classes in ML.NET (#616)

Now schema comprehension works with public properties as well as public fields.
This commit is contained in:
Don Syme 2018-08-02 16:57:51 +01:00 коммит произвёл Pete Luferenko
Родитель 89dfc82f5e
Коммит f6934a0705
13 изменённых файлов: 962 добавлений и 102 удалений

Просмотреть файл

@ -97,6 +97,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}"
EndProject
Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
@ -333,6 +335,14 @@ Global
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.Build.0 = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.Build.0 = Debug|Any CPU
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
@ -387,6 +397,7 @@ Global
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection

Просмотреть файл

@ -51,14 +51,31 @@ namespace Microsoft.ML.Runtime.Api
/// </summary>
internal static Delegate GeneratePeek<TOwn, TRow>(InternalSchemaDefinition.Column column)
{
var fieldInfo = column.FieldInfo;
Type fieldType = fieldInfo.FieldType;
switch (column.MemberInfo)
{
case FieldInfo fieldInfo:
Type fieldType = fieldInfo.FieldType;
var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePeek<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePeek<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
case PropertyInfo propertyInfo:
Type propertyType = propertyInfo.PropertyType;
var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
Func<PropertyInfo, OpCode, Delegate> funcProp = GeneratePeek<TOwn, TRow, int>;
var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp });
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
}
private static Delegate GeneratePeek<TOwn, TRow, TValue>(FieldInfo fieldInfo, OpCode assignmentOpCode)
@ -81,6 +98,28 @@ namespace Microsoft.ML.Runtime.Api
return mb.CreateDelegate(typeof(Peek<TRow, TValue>));
}
private static Delegate GeneratePeek<TOwn, TRow, TValue>(PropertyInfo propertyInfo, OpCode assignmentOpCode)
{
// REVIEW: It seems like we really should cache these, instead of generating them per cursor.
Type[] args = { typeof(TOwn), typeof(TRow), typeof(long), typeof(TValue).MakeByRefType() };
var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true);
var il = mb.GetILGenerator();
var minfo = propertyInfo.GetGetMethod();
var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;
il.Emit(OpCodes.Ldarg_3); // push arg3
il.Emit(OpCodes.Ldarg_1); // push arg1
il.Emit(opcode, minfo); // call [stack top].get_[propertyInfo]()
// Stobj needs to coupled with a type.
if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top]
il.Emit(assignmentOpCode, propertyInfo.PropertyType);
else
il.Emit(assignmentOpCode);
il.Emit(OpCodes.Ret); // ret
return mb.CreateDelegate(typeof(Peek<TRow, TValue>));
}
/// <summary>
/// Each of the specialized 'poke' methods sets the appropriate field value of an instance of T
/// to the provided value. So, the call is 'peek(userObject, providedValue)' and the logic is
@ -88,14 +127,30 @@ namespace Microsoft.ML.Runtime.Api
/// </summary>
internal static Delegate GeneratePoke<TOwn, TRow>(InternalSchemaDefinition.Column column)
{
var fieldInfo = column.FieldInfo;
Type fieldType = fieldInfo.FieldType;
switch (column.MemberInfo)
{
case FieldInfo fieldInfo:
Type fieldType = fieldInfo.FieldType;
var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePoke<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePoke<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
case PropertyInfo propertyInfo:
Type propertyType = propertyInfo.PropertyType;
var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
Func<PropertyInfo, Delegate> funcProp = GeneratePoke<TOwn, TRow, int>;
var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo });
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
}
private static Delegate GeneratePoke<TOwn, TRow, TValue>(FieldInfo fieldInfo, OpCode assignmentOpCode)
@ -115,5 +170,20 @@ namespace Microsoft.ML.Runtime.Api
il.Emit(OpCodes.Ret); // ret
return mb.CreateDelegate(typeof(Poke<TRow, TValue>), null);
}
private static Delegate GeneratePoke<TOwn, TRow, TValue>(PropertyInfo propertyInfo)
{
Type[] args = { typeof(TOwn), typeof(TRow), typeof(TValue) };
var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true);
var il = mb.GetILGenerator();
var minfo = propertyInfo.GetSetMethod();
var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;
il.Emit(OpCodes.Ldarg_1); // push arg1
il.Emit(OpCodes.Ldarg_2); // push arg2
il.Emit(opcode, minfo); // call [stack top-1].set_[propertyInfo]([stack top])
il.Emit(OpCodes.Ret); // ret
return mb.CreateDelegate(typeof(Poke<TRow, TValue>), null);
}
}
}

Просмотреть файл

@ -118,7 +118,7 @@ namespace Microsoft.ML.Runtime.Api
var colType = DataView.Schema.GetColumnType(index);
var column = DataView._schema.SchemaDefn.Columns[index];
var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType;
var outputType = column.OutputType;
var genericType = outputType;
Func<int, Delegate> del;

Просмотреть файл

@ -23,21 +23,23 @@ namespace Microsoft.ML.Runtime.Api
public class Column
{
public readonly string ColumnName;
public readonly FieldInfo FieldInfo;
public readonly MemberInfo MemberInfo;
public readonly ParameterInfo ReturnParameterInfo;
public readonly ColumnType ColumnType;
public readonly bool IsComputed;
public readonly Delegate Generator;
private readonly Dictionary<string, MetadataInfo> _metadata;
public Dictionary<string, MetadataInfo> Metadata { get { return _metadata; } }
public Type ReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
public Type ComputedReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType;
public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType;
public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo) :
this(columnName, columnType, fieldInfo, null, null) { }
public Column(string columnName, ColumnType columnType, MemberInfo memberInfo) :
this(columnName, columnType, memberInfo, null, null) { }
public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo,
public Column(string columnName, ColumnType columnType, MemberInfo memberInfo,
Dictionary<string, MetadataInfo> metadataInfos) :
this(columnName, columnType, fieldInfo, null, metadataInfos) { }
this(columnName, columnType, memberInfo, null, metadataInfos) { }
public Column(string columnName, ColumnType columnType, Delegate generator) :
this(columnName, columnType, null, generator, null) { }
@ -46,7 +48,7 @@ namespace Microsoft.ML.Runtime.Api
Dictionary<string, MetadataInfo> metadataInfos) :
this(columnName, columnType, null, generator, metadataInfos) { }
private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = null,
private Column(string columnName, ColumnType columnType, MemberInfo memberInfo = null,
Delegate generator = null, Dictionary<string, MetadataInfo> metadataInfos = null)
{
Contracts.AssertNonEmpty(columnName);
@ -55,8 +57,8 @@ namespace Microsoft.ML.Runtime.Api
if (generator == null)
{
Contracts.AssertValue(fieldInfo);
FieldInfo = fieldInfo;
Contracts.AssertValue(memberInfo);
MemberInfo = memberInfo;
}
else
{
@ -95,8 +97,8 @@ namespace Microsoft.ML.Runtime.Api
// If Column is computed type, it must have a generator.
Contracts.Assert(IsComputed == (Generator != null));
// Column must have either a generator or a fieldInfo value.
Contracts.Assert((Generator == null) != (FieldInfo == null));
// Column must have either a generator or a memberInfo value.
Contracts.Assert((Generator == null) != (MemberInfo == null));
// Additional Checks if there is a generator.
if (Generator == null)
@ -115,9 +117,7 @@ namespace Microsoft.ML.Runtime.Api
Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void));
// Checks that the return type of the generator is compatible with ColumnType.
bool isVector;
DataKind datakind;
GetVectorAndKind(ReturnType, "return type", out isVector, out datakind);
GetVectorAndKind(ComputedReturnType, "return type", out bool isVector, out DataKind datakind);
Contracts.Assert(isVector == ColumnType.IsVector);
Contracts.Assert(datakind == ColumnType.ItemType.RawKind);
}
@ -131,19 +131,30 @@ namespace Microsoft.ML.Runtime.Api
}
/// <summary>
/// Given a field info on a type, returns whether this appears to be a vector type,
/// Given a field or property info on a type, returns whether this appears to be a vector type,
/// and also the associated data kind for this type. If a data kind could not
/// be determined, this will throw.
/// </summary>
/// <param name="fieldInfo">The field info to inspect.</param>
/// <param name="memberInfo">The field or property info to inspect.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="kind">The data kind of the type, or items of this type if vector.</param>
public static void GetVectorAndKind(FieldInfo fieldInfo, out bool isVector, out DataKind kind)
public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, out DataKind kind)
{
Contracts.AssertValue(fieldInfo);
Type rawFieldType = fieldInfo.FieldType;
var name = fieldInfo.Name;
GetVectorAndKind(rawFieldType, name, out isVector, out kind);
Contracts.AssertValue(memberInfo);
switch (memberInfo)
{
case FieldInfo fieldInfo:
GetVectorAndKind(fieldInfo.FieldType, fieldInfo.Name, out isVector, out kind);
break;
case PropertyInfo propertyInfo:
GetVectorAndKind(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out kind);
break;
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
}
/// <summary>
@ -211,23 +222,27 @@ namespace Microsoft.ML.Runtime.Api
bool isVector;
DataKind kind;
FieldInfo fieldInfo = null;
MemberInfo memberInfo = null;
if (!col.IsComputed)
{
fieldInfo = userType.GetField(col.MemberName);
memberInfo = userType.GetField(col.MemberName);
if (fieldInfo == null)
throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field with name '{0}' found in type '{1}'",
if (memberInfo == null)
memberInfo = userType.GetProperty(col.MemberName);
if (memberInfo == null)
throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field or property with name '{0}' found in type '{1}'",
col.MemberName,
userType.FullName);
//Clause to handle the field that may be used to expose the cursor channel.
//This field does not need a column.
if (fieldInfo.FieldType == typeof(IChannel))
if ( (memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) ||
(memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel)))
continue;
GetVectorAndKind(fieldInfo, out isVector, out kind);
GetVectorAndKind(memberInfo, out isVector, out kind);
}
else
{
@ -268,7 +283,7 @@ namespace Microsoft.ML.Runtime.Api
dstCols[i] = col.IsComputed ?
new Column(colName, colType, col.Generator, col.Metadata)
: new Column(colName, colType, fieldInfo, col.Metadata);
: new Column(colName, colType, memberInfo, col.Metadata);
}
return new InternalSchemaDefinition(dstCols);

Просмотреть файл

@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Api
/// <summary>
/// Attach to a member of a class to indicate that the item type should be of class key.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class KeyTypeAttribute : Attribute
{
// REVIEW: Property based, but should I just have a constructor?
@ -46,7 +46,7 @@ namespace Microsoft.ML.Runtime.Api
/// Allows a member to be marked as a vector valued field, primarily allowing one to set
/// the dimensionality of the resulting array.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class VectorTypeAttribute : Attribute
{
private readonly int[] _dims;
@ -66,7 +66,7 @@ namespace Microsoft.ML.Runtime.Api
/// Describes column information such as name and the source columns indicies that this
/// column encapsulates.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class ColumnAttribute : Attribute
{
public ColumnAttribute(string ordinal, string name = null)
@ -97,7 +97,7 @@ namespace Microsoft.ML.Runtime.Api
/// Allows a member to specify its column name directly, as opposed to the default
/// behavior of using the member name as the column name.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class ColumnNameAttribute : Attribute
{
private readonly string _name;
@ -119,7 +119,7 @@ namespace Microsoft.ML.Runtime.Api
/// <summary>
/// Mark this member as not being exposed as a column in the schema.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class NoColumnAttribute : Attribute
{
}
@ -128,7 +128,7 @@ namespace Microsoft.ML.Runtime.Api
/// Mark a member that implements exactly IChannel as being permitted to receive
/// channel information from an external channel.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class CursorChannelAttribute : Attribute
{
/// <summary>
@ -158,19 +158,40 @@ namespace Microsoft.ML.Runtime.Api
.Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any())
.ToArray();
var cursorChannelAttrProperties = typeof(T)
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
.Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0)
.Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any());
var cursorChannelAttrMembers = (cursorChannelAttrFields as IEnumerable<MemberInfo>).Concat(cursorChannelAttrProperties).ToArray();
//Check that there is at most one such field.
if (cursorChannelAttrFields.Length == 0)
if (cursorChannelAttrMembers.Length == 0)
return false;
ectx.Check(cursorChannelAttrFields.Length == 1,
"Only one field with CursorChannel attribute is allowed.");
ectx.Check(cursorChannelAttrMembers.Length == 1,
"Only one public field or property with CursorChannel attribute is allowed.");
//Check that the marked field has type IChannel.
var cursorChannelFieldInfo = cursorChannelAttrFields[0];
ectx.Check(cursorChannelFieldInfo.FieldType == typeof(IChannel),
"Field marked as CursorChannel must have type IChannel.");
var cursorChannelAttrMemberInfo = cursorChannelAttrMembers[0];
switch (cursorChannelAttrMemberInfo)
{
case FieldInfo cursorChannelAttrFieldInfo:
ectx.Check(cursorChannelAttrFieldInfo.FieldType == typeof(IChannel),
"Field marked as CursorChannel must have type IChannel.");
cursorChannelAttrFieldInfo.SetValue(obj, channel);
break;
cursorChannelFieldInfo.SetValue(obj, channel);
case PropertyInfo cursorChannelAttrPropertyInfo:
ectx.Check(cursorChannelAttrPropertyInfo.PropertyType == typeof(IChannel),
"Property marked as CursorChannel must have type IChannel.");
cursorChannelAttrPropertyInfo.SetValue(obj, channel);
break;
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
return true;
}
}
@ -319,37 +340,63 @@ namespace Microsoft.ML.Runtime.Api
SchemaDefinition cols = new SchemaDefinition();
HashSet<string> colNames = new HashSet<string>();
foreach (var fieldInfo in userType.GetFields())
var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
var propertyInfos =
userType
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
.Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0);
var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
foreach (var memberInfo in memberInfos)
{
// Clause to handle the field that may be used to expose the cursor channel.
// This field does not need a column.
// REVIEW: maybe validate the channel attribute now, instead
// of later at cursor creation.
if (fieldInfo.FieldType == typeof(IChannel))
continue;
// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
switch (memberInfo)
{
case FieldInfo fieldInfo:
if (fieldInfo.FieldType == typeof(IChannel))
continue;
// Const fields do not need to be mapped.
if (fieldInfo.IsLiteral)
continue;
break;
case PropertyInfo propertyInfo:
if (propertyInfo.PropertyType == typeof(IChannel))
continue;
break;
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
continue;
if (fieldInfo.GetCustomAttribute<NoColumnAttribute>() != null)
continue;
var mappingAttr = fieldInfo.GetCustomAttribute<ColumnAttribute>();
var mappingNameAttr = fieldInfo.GetCustomAttribute<ColumnNameAttribute>();
string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name;
var mappingAttr = memberInfo.GetCustomAttribute<ColumnAttribute>();
var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not gauranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);
InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out bool isVector, out DataKind kind);
InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
PrimitiveType itemType;
var keyAttr = fieldInfo.GetCustomAttribute<KeyTypeAttribute>();
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
if (keyAttr != null)
{
if (!KeyType.IsValidDataKind(kind))
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", fieldInfo.Name);
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
}
else
@ -357,9 +404,9 @@ namespace Microsoft.ML.Runtime.Api
// Get the column type.
ColumnType columnType;
var vectorAttr = fieldInfo.GetCustomAttribute<VectorTypeAttribute>();
var vectorAttr = memberInfo.GetCustomAttribute<VectorTypeAttribute>();
if (vectorAttr != null && !isVector)
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", fieldInfo.Name);
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", memberInfo.Name);
if (isVector)
{
int[] dims = vectorAttr?.Dims;
@ -373,7 +420,7 @@ namespace Microsoft.ML.Runtime.Api
else
columnType = itemType;
cols.Add(new Column() { MemberName = fieldInfo.Name, ColumnName = name, ColumnType = columnType });
cols.Add(new Column() { MemberName = memberInfo.Name, ColumnName = name, ColumnType = columnType });
}
return cols;
}

Просмотреть файл

@ -103,11 +103,11 @@ namespace Microsoft.ML.Runtime.Api
throw _host.Except("Column '{0}' not found in the data view", col.ColumnName);
}
var realColType = _data.Schema.GetColumnType(colIndex);
if (!IsCompatibleType(realColType, col.FieldInfo))
if (!IsCompatibleType(realColType, col.MemberInfo))
{
throw _host.Except(
"Can't bind the IDataView column '{0}' of type '{1}' to field '{2}' of type '{3}'.",
col.ColumnName, realColType, col.FieldInfo.Name, col.FieldInfo.FieldType.FullName);
"Can't bind the IDataView column '{0}' of type '{1}' to field or property '{2}' of type '{3}'.",
col.ColumnName, realColType, col.MemberInfo.Name, col.FieldOrPropertyType.FullName);
}
acceptedCols.Add(col);
@ -130,14 +130,12 @@ namespace Microsoft.ML.Runtime.Api
}
/// <summary>
/// Returns whether the column type <paramref name="colType"/> can be bound to field <paramref name="fieldInfo"/>.
/// Returns whether the column type <paramref name="colType"/> can be bound to field <paramref name="memberInfo"/>.
/// They must both be vectors or scalars, and the raw data kind should match.
/// </summary>
private static bool IsCompatibleType(ColumnType colType, FieldInfo fieldInfo)
private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo)
{
bool isVector;
DataKind kind;
InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out isVector, out kind);
InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
if (isVector)
return colType.IsVector && colType.ItemType.RawKind == kind;
else
@ -269,8 +267,7 @@ namespace Microsoft.ML.Runtime.Api
private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek)
{
var colType = input.Schema.GetColumnType(index);
var fieldInfo = column.FieldInfo;
var fieldType = fieldInfo.FieldType;
var fieldType = column.OutputType;
var genericType = fieldType;
Func<IRow, int, Delegate, Delegate, Action<TRow>> del;
if (fieldType.IsArray)
@ -431,7 +428,7 @@ namespace Microsoft.ML.Runtime.Api
else
{
// REVIEW: Is this even possible?
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName);
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", column.OutputType.FullName);
}
MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType);
return (Action<TRow>)meth.Invoke(this, new object[] { input, index, poke, peek });

Просмотреть файл

@ -7,6 +7,7 @@ using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
@ -71,20 +72,30 @@ namespace Microsoft.ML.Data
char separator = '\t', bool allowQuotedStrings = true,
bool supportSparse = true, bool trimWhitespace = false)
{
var fields = typeof(TInput).GetFields();
Arguments.Column = new TextLoaderColumn[fields.Length];
for (int index = 0; index < fields.Length; index++)
var userType = typeof(TInput);
var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
var propertyInfos =
userType
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
.Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0);
var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
Arguments.Column = new TextLoaderColumn[memberInfos.Length];
for (int index = 0; index < memberInfos.Length; index++)
{
var field = fields[index];
var mappingAttr = field.GetCustomAttribute<ColumnAttribute>();
var memberInfo = memberInfos[index];
var mappingAttr = memberInfo.GetCustomAttribute<ColumnAttribute>();
if (mappingAttr == null)
throw Contracts.Except($"{field.Name} is missing ColumnAttribute");
throw Contracts.Except($"Field or property {memberInfo.Name} is missing ColumnAttribute");
if (Regex.Match(mappingAttr.Ordinal, @"[^(0-9,\*\-~)]+").Success)
throw Contracts.Except($"{mappingAttr.Ordinal} contains invalid characters. " +
$"Valid characters are 0-9, *, - and ~");
var name = mappingAttr.Name ?? field.Name;
var name = mappingAttr.Name ?? memberInfo.Name;
Runtime.Data.TextLoader.Range[] sources;
if (!Runtime.Data.TextLoader.Column.TryParseSourceEx(mappingAttr.Ordinal, out sources))
@ -96,8 +107,23 @@ namespace Microsoft.ML.Data
tlc.Name = name;
tlc.Source = new TextLoaderRange[sources.Length];
DataKind dk;
if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
throw Contracts.Except($"{name} is of unsupported type.");
switch (memberInfo)
{
case FieldInfo field:
if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
throw Contracts.Except($"Field {name} is of unsupported type.");
break;
case PropertyInfo property:
if (!TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk))
throw Contracts.Except($"Property {name} is of unsupported type.");
break;
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
tlc.Type = dk;

Просмотреть файл

@ -0,0 +1,52 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>netcoreapp2.0</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' != 'Unix'">$(TargetFrameworks); net461</TargetFrameworks>
<NoWarn>2003;$(NoWarn)</NoWarn>
<PublicSign>false</PublicSign>
<SourceLink></SourceLink>
<PlatformTarget>x64</PlatformTarget>
</PropertyGroup>
<ItemGroup>
<Compile Include="SmokeTests.fs" />
<Compile Include="Program.fs" />
</ItemGroup>
<ItemGroup>
<!-- Future updates to this test will check use with F# type providers, so -->
<!-- leaving this here for now. -->
<!-- <PackageReference Include="FSharp.Data" Version="3.0.0-beta4" /> -->
</ItemGroup>
<ItemGroup>
<!-- More projects are referenced than are currently tested. Future updates to -->
<!-- these tests will test more of the surface area from F#, so leaving these references -->
<!-- here for now. -->
<ProjectReference Include="..\..\src\Microsoft.ML.Api\Microsoft.ML.Api.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Parquet\Microsoft.ML.Parquet.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML\Microsoft.ML.csproj" />
</ItemGroup>
<ItemGroup>
<NativeAssemblyReference Include="FastTreeNative" />
<NativeAssemblyReference Include="CpuMathNative" />
<NativeAssemblyReference Include="FactorizationMachineNative" />
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,9 @@
namespace Microsoft.ML.FSharp.Tests
#if NETCOREAPP2_0
module Program =
[<EntryPoint>]
let main _ = 0
#endif

Просмотреть файл

@ -0,0 +1,263 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
//=================================================================================================
// This test can be run either as a compiled test with .NET Core (on any platform) or
// manually in script form (to help debug it and also check that F# scripting works with ML.NET).
// Running as a script requires using F# Interactive on Windows, and the explicit references below.
// The references would normally be created by a package loader for the scripting
// environment, e.g. see https://github.com/isaacabraham/ml-test-experiment/, but
// here we list them explicitly to avoid the dependency on a package loader,
//
// You should build Microsoft.ML.FSharp.Tests in Debug mode for framework net461
// before running this as a script with F# Interactive by editing the project
// file to have:
// <TargetFrameworks>netcoreapp2.0; net461</TargetFrameworks>
#if INTERACTIVE
#r "netstandard"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Core.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.ResultProcessor.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PCA.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.KMeansClustering.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.FastTree.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Api.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Sweeper.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.StandardLearners.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PipelineInference.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll"
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll"
#r "System"
#r "System.ComponentModel.Composition"
#r "System.Core"
#r "System.Xml.Linq"
// Later tests will add data import using F# type providers:
//#r @"../../packages/fsharp.data/3.0.0-beta4/lib/netstandard2.0/FSharp.Data.dll" // this must be referenced from its package location
#endif
//================================================================================
// The tests proper start here
#if !INTERACTIVE
namespace Microsoft.ML.FSharp.Tests
#endif
open System
open Microsoft.ML
open Microsoft.ML.Data
open Microsoft.ML.Transforms
open Microsoft.ML.Trainers
open Microsoft.ML.Runtime.Api
open Xunit
module SmokeTest1 =
type SentimentData() =
[<Column(ordinal = "0"); DefaultValue>]
val mutable SentimentText : string
[<Column(ordinal = "1", name = "Label"); DefaultValue>]
val mutable Sentiment : float32
type SentimentPrediction() =
[<ColumnName("PredictedLabel"); DefaultValue>]
val mutable Sentiment : bool
[<Fact>]
let ``FSharp-Sentiment-Smoke-Test`` () =
// See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies
let _load =
[ typeof<Microsoft.ML.Runtime.Transforms.TextAnalytics>;
typeof<Microsoft.ML.Runtime.FastTree.FastTree> ]
let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv"
let pipeline = LearningPipeline()
pipeline.Add(
TextLoader(testDataPath).CreateFrom<SentimentData>(
Arguments =
TextLoaderArguments(
HasHeader = true,
Column = [| TextLoaderColumn(Name = "Label",
Source = [| TextLoaderRange(0) |],
Type = Nullable (Data.DataKind.Num))
TextLoaderColumn(Name = "SentimentText",
Source = [| TextLoaderRange(1) |],
Type = Nullable (Data.DataKind.Text)) |]
)))
pipeline.Add(
TextFeaturizer(
"Features", [| "SentimentText" |],
KeepDiacritics = false,
KeepPunctuations = false,
TextCase = TextNormalizerTransformCaseNormalizationMode.Lower,
OutputTokens = true,
VectorNormalizer = TextTransformTextNormKind.L2
))
pipeline.Add(
FastTreeBinaryClassifier(
NumLeaves = 5,
NumTrees = 5,
MinDocumentsInLeafs = 2
))
let model = pipeline.Train<SentimentData, SentimentPrediction>()
let predictions =
[ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.")
SentimentData(SentimentText = "Sort of ok")
SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ]
|> model.Predict
let predictionResults = [ for p in predictions -> p.Sentiment ]
Assert.Equal<bool list>(predictionResults, [ false; true; true ])
module SmokeTest2 =
[<CLIMutable>]
type SentimentData =
{ [<Column(ordinal = "0")>]
SentimentText : string
[<Column(ordinal = "1", name = "Label")>]
Sentiment : float32 }
[<CLIMutable>]
type SentimentPrediction =
{ [<ColumnName("PredictedLabel")>]
Sentiment : bool }
[<Fact>]
let ``FSharp-Sentiment-Smoke-Test`` () =
// See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies
let _load =
[ typeof<Microsoft.ML.Runtime.Transforms.TextAnalytics>;
typeof<Microsoft.ML.Runtime.FastTree.FastTree> ]
let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv"
let pipeline = LearningPipeline()
pipeline.Add(
TextLoader(testDataPath).CreateFrom<SentimentData>(
Arguments =
TextLoaderArguments(
HasHeader = true,
Column = [| TextLoaderColumn(Name = "Label",
Source = [| TextLoaderRange(0) |],
Type = Nullable (Data.DataKind.Num))
TextLoaderColumn(Name = "SentimentText",
Source = [| TextLoaderRange(1) |],
Type = Nullable (Data.DataKind.Text)) |]
)))
pipeline.Add(
TextFeaturizer(
"Features", [| "SentimentText" |],
KeepDiacritics = false,
KeepPunctuations = false,
TextCase = TextNormalizerTransformCaseNormalizationMode.Lower,
OutputTokens = true,
VectorNormalizer = TextTransformTextNormKind.L2
))
pipeline.Add(
FastTreeBinaryClassifier(
NumLeaves = 5,
NumTrees = 5,
MinDocumentsInLeafs = 2
))
let model = pipeline.Train<SentimentData, SentimentPrediction>()
let predictions =
[ { SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition."; Sentiment = 0.0f }
{ SentimentText = "Sort of ok"; Sentiment = 0.0f }
{ SentimentText = "Joe versus the Volcano Coffee Company is a great film."; Sentiment = 0.0f } ]
|> model.Predict
let predictionResults = [ for p in predictions -> p.Sentiment ]
Assert.Equal<bool list>(predictionResults, [ false; true; true ])
module SmokeTest3 =
type SentimentData() =
[<Column(ordinal = "0")>]
member val SentimentText = "" with get, set
[<Column(ordinal = "1", name = "Label")>]
member val Sentiment = 0.0 with get, set
type SentimentPrediction() =
[<ColumnName("PredictedLabel")>]
member val Sentiment = false with get, set
[<Fact>]
let ``FSharp-Sentiment-Smoke-Test`` () =
// See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies
let _load =
[ typeof<Microsoft.ML.Runtime.Transforms.TextAnalytics>;
typeof<Microsoft.ML.Runtime.FastTree.FastTree> ]
let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv"
let pipeline = LearningPipeline()
pipeline.Add(
TextLoader(testDataPath).CreateFrom<SentimentData>(
Arguments =
TextLoaderArguments(
HasHeader = true,
Column = [| TextLoaderColumn(Name = "Label",
Source = [| TextLoaderRange(0) |],
Type = Nullable (Data.DataKind.Num))
TextLoaderColumn(Name = "SentimentText",
Source = [| TextLoaderRange(1) |],
Type = Nullable (Data.DataKind.Text)) |]
)))
pipeline.Add(
TextFeaturizer(
"Features", [| "SentimentText" |],
KeepDiacritics = false,
KeepPunctuations = false,
TextCase = TextNormalizerTransformCaseNormalizationMode.Lower,
OutputTokens = true,
VectorNormalizer = TextTransformTextNormKind.L2
))
pipeline.Add(
FastTreeBinaryClassifier(
NumLeaves = 5,
NumTrees = 5,
MinDocumentsInLeafs = 2
))
let model = pipeline.Train<SentimentData, SentimentPrediction>()
let predictions =
[ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.")
SentimentData(SentimentText = "Sort of ok")
SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ]
|> model.Predict
let predictionResults = [ for p in predictions -> p.Sentiment ]
Assert.Equal<bool list>(predictionResults, [ false; true; true ])

Просмотреть файл

@ -174,6 +174,49 @@ namespace Microsoft.ML.EntryPoints.Tests
}
[Fact]
public void CanTrainProperties()
{
var pipeline = new LearningPipeline();
var data = new List<IrisDataProperties>() {
new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
new IrisDataProperties { SepalLength = 1.2f, SepalWidth = 0.5f, PetalLength=0.3f, PetalWidth=5.1f, Label=0}
};
var collection = CollectionDataSource.Create(data);
pipeline.Add(collection);
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
PredictionModel<IrisDataProperties, IrisPredictionProperties> model = pipeline.Train<IrisDataProperties, IrisPredictionProperties>();
IrisPredictionProperties prediction = model.Predict(new IrisDataProperties()
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth = 5.1f,
});
pipeline = new LearningPipeline();
collection = CollectionDataSource.Create(data.AsEnumerable());
pipeline.Add(collection);
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
model = pipeline.Train<IrisDataProperties, IrisPredictionProperties>();
prediction = model.Predict(new IrisDataProperties()
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth = 5.1f,
});
}
public class Input
{
[Column("0")]
@ -207,6 +250,37 @@ namespace Microsoft.ML.EntryPoints.Tests
public float[] PredictedLabels;
}
public class IrisDataProperties
{
private float _Label;
private float _SepalLength;
private float _SepalWidth;
private float _PetalLength;
private float _PetalWidth;
[Column("0")]
public float Label { get { return _Label; } set { _Label = value; } }
[Column("1")]
public float SepalLength { get { return _SepalLength; } set { _SepalLength = value; } }
[Column("2")]
public float SepalWidth { get { return _SepalWidth; } set { _SepalWidth = value; } }
[Column("3")]
public float PetalLength { get { return _PetalLength; } set { _PetalLength = value; } }
[Column("4")]
public float PetalWidth { get { return _PetalWidth; } set { _PetalWidth = value; } }
}
public class IrisPredictionProperties
{
private float[] _PredictedLabels;
[ColumnName("Score")]
public float[] PredictedLabels { get { return _PredictedLabels; } set { _PredictedLabels = value; } }
}
public class ConversionSimpleClass
{
public int fInt;
@ -257,7 +331,7 @@ namespace Microsoft.ML.EntryPoints.Tests
public bool CompareThroughReflection<T>(T x, T y)
{
foreach (var field in typeof(T).GetFields())
foreach (var field in typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance))
{
var xvalue = field.GetValue(x);
var yvalue = field.GetValue(y);
@ -272,6 +346,25 @@ namespace Microsoft.ML.EntryPoints.Tests
return false;
}
}
foreach (var property in typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance))
{
// Don't compare properties with private getters and setters
if (!property.CanRead || !property.CanWrite || property.GetGetMethod() == null || property.GetSetMethod() == null)
continue;
var xvalue = property.GetValue(x);
var yvalue = property.GetValue(y);
if (property.PropertyType.IsArray)
{
if (!CompareArrayValues(xvalue as Array, yvalue as Array))
return false;
}
else
{
if (!CompareObjectValues(xvalue, yvalue, property.PropertyType))
return false;
}
}
return true;
}
@ -288,14 +381,6 @@ namespace Microsoft.ML.EntryPoints.Tests
return true;
}
public class ClassWithConstField
{
public const string ConstString = "N";
public string fString;
public const int ConstInt = 100;
public int fInt;
}
[Fact]
public void RoundTripConversionWithBasicTypes()
{
@ -489,6 +574,50 @@ namespace Microsoft.ML.EntryPoints.Tests
}
}
public class ConversionLossMinValueClassProperties
{
private int? _fInt;
private long? _fLong;
private short? _fShort;
private sbyte? _fsByte;
public int? IntProp { get { return _fInt; } set { _fInt = value; } }
public short? ShortProp { get { return _fShort; } set { _fShort = value; } }
public sbyte? SByteProp { get { return _fsByte; } set { _fsByte = value; } }
public long? LongProp { get { return _fLong; } set { _fLong = value; } }
}
[Fact]
public void ConversionMinValueToNullBehaviorProperties()
{
using (var env = new TlcEnvironment())
{
var data = new List<ConversionLossMinValueClassProperties>
{
new ConversionLossMinValueClassProperties() { SByteProp = null, IntProp = null, LongProp = null, ShortProp = null },
new ConversionLossMinValueClassProperties() { SByteProp = sbyte.MinValue, IntProp = int.MinValue, LongProp = long.MinValue, ShortProp = short.MinValue }
};
foreach (var field in typeof(ConversionLossMinValueClassProperties).GetFields())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var enumerator = dataView.AsEnumerable<ConversionLossMinValueClassProperties>(env, false).GetEnumerator();
while (enumerator.MoveNext())
{
Assert.True(enumerator.Current.IntProp == null && enumerator.Current.LongProp == null &&
enumerator.Current.SByteProp == null && enumerator.Current.ShortProp == null);
}
}
}
}
public class ClassWithConstField
{
public const string ConstString = "N";
public string fString;
public const int ConstInt = 100;
public int fInt;
}
[Fact]
public void ClassWithConstFieldsConversion()
{
@ -510,6 +639,122 @@ namespace Microsoft.ML.EntryPoints.Tests
}
}
public class ClassWithMixOfFieldsAndProperties
{
public string fString;
private int _fInt;
public int IntProp { get { return _fInt; } set { _fInt = value; } }
}
[Fact]
public void ClassWithMixOfFieldsAndPropertiesConversion()
{
var data = new List<ClassWithMixOfFieldsAndProperties>()
{
new ClassWithMixOfFieldsAndProperties(){ IntProp=1, fString ="lala" },
new ClassWithMixOfFieldsAndProperties(){ IntProp=-1, fString ="" },
new ClassWithMixOfFieldsAndProperties(){ IntProp=0, fString =null }
};
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var enumeratorSimple = dataView.AsEnumerable<ClassWithMixOfFieldsAndProperties>(env, false).GetEnumerator();
var originalEnumerator = data.GetEnumerator();
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
}
public abstract class BaseClassWithInheritedProperties
{
private string _fString;
private byte _fByte;
public string StringProp { get { return _fString; } set { _fString = value; } }
public abstract long LongProp { get; set; }
public virtual byte ByteProp { get { return _fByte; } set { _fByte = value; } }
}
public class ClassWithPrivateFieldsAndProperties
{
public ClassWithPrivateFieldsAndProperties() { seq++; _unusedStaticField++; _unusedPrivateField1 = 100; }
static public int seq;
static public int _unusedStaticField;
private int _unusedPrivateField1;
private string _fString;
// This property is ignored because it has no setter
private int UnusedReadOnlyProperty { get { return _unusedPrivateField1; } }
// This property is ignored because it is private
private int UnusedPrivateProperty { get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } }
// This property is ignored because it has a private setter
public int UnusedPropertyWithPrivateSetter { get { return _unusedPrivateField1; } private set { _unusedPrivateField1 = value; } }
// This property is ignored because it has a private getter
public int UnusedPropertyWithPrivateGetter { private get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } }
public string StringProp { get { return _fString; } set { _fString = value; } }
}
[Fact]
public void ClassWithPrivateFieldsAndPropertiesConversion()
{
var data = new List<ClassWithPrivateFieldsAndProperties>()
{
new ClassWithPrivateFieldsAndProperties(){ StringProp ="lala" },
new ClassWithPrivateFieldsAndProperties(){ StringProp ="baba" }
};
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var enumeratorSimple = dataView.AsEnumerable<ClassWithPrivateFieldsAndProperties>(env, false).GetEnumerator();
var originalEnumerator = data.GetEnumerator();
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
{
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100);
}
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
}
public class ClassWithInheritedProperties : BaseClassWithInheritedProperties
{
private int _fInt;
private long _fLong;
private byte _fByte2;
public int IntProp { get { return _fInt; } set { _fInt = value; } }
public override long LongProp { get { return _fLong; } set { _fLong = value; } }
public override byte ByteProp { get { return _fByte2; } set { _fByte2 = value; } }
}
[Fact]
public void ClassWithInheritedPropertiesConversion()
{
var data = new List<ClassWithInheritedProperties>()
{
new ClassWithInheritedProperties(){ IntProp=1, StringProp ="lala", LongProp=17, ByteProp=3 },
new ClassWithInheritedProperties(){ IntProp=-1, StringProp ="", LongProp=2, ByteProp=4 },
new ClassWithInheritedProperties(){ IntProp=0, StringProp =null, LongProp=18, ByteProp=5 }
};
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var enumeratorSimple = dataView.AsEnumerable<ClassWithInheritedProperties>(env, false).GetEnumerator();
var originalEnumerator = data.GetEnumerator();
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
}
public class ClassWithArrays
{
public string[] fString;
@ -609,5 +854,129 @@ namespace Microsoft.ML.EntryPoints.Tests
Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext());
}
}
public class ClassWithArrayProperties
{
private string[] _fString;
private int[] _fInt;
private uint[] _fuInt;
private short[] _fShort;
private ushort[] _fuShort;
private sbyte[] _fsByte;
private byte[] _fByte;
private long[] _fLong;
private ulong[] _fuLong;
private float[] _fFloat;
private double[] _fDouble;
private bool[] _fBool;
public string[] StringProp { get { return _fString; } set { _fString = value; } }
public int[] IntProp { get { return _fInt; } set { _fInt = value; } }
public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } }
public short[] ShortProp { get { return _fShort; } set { _fShort = value; } }
public ushort[] UShortProp { get { return _fuShort; } set { _fuShort = value; } }
public sbyte[] SByteProp { get { return _fsByte; } set { _fsByte = value; } }
public byte[] ByteProp { get { return _fByte; } set { _fByte = value; } }
public long[] LongProp { get { return _fLong; } set { _fLong = value; } }
public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } }
public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } }
public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } }
public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } }
}
public class ClassWithNullableArrayProperties
{
private string[] _fString;
private int?[] _fInt;
private uint?[] _fuInt;
private short?[] _fShort;
private ushort?[] _fuShort;
private sbyte?[] _fsByte;
private byte?[] _fByte;
private long?[] _fLong;
private ulong?[] _fuLong;
private float?[] _fFloat;
private double?[] _fDouble;
private bool?[] _fBool;
public string[] StringProp { get { return _fString; } set { _fString = value; } }
public int?[] IntProp { get { return _fInt; } set { _fInt = value; } }
public uint?[] UIntProp { get { return _fuInt; } set { _fuInt = value; } }
public short?[] ShortProp { get { return _fShort; } set { _fShort = value; } }
public ushort?[] UShortProp { get { return _fuShort; } set { _fuShort = value; } }
public sbyte?[] SByteProp { get { return _fsByte; } set { _fsByte = value; } }
public byte?[] ByteProp { get { return _fByte; } set { _fByte = value; } }
public long?[] LongProp { get { return _fLong; } set { _fLong = value; } }
public ulong?[] ULongProp { get { return _fuLong; } set { _fuLong = value; } }
public float?[] SingleProp { get { return _fFloat; } set { _fFloat = value; } }
public double?[] DoubleProp { get { return _fDouble; } set { _fDouble = value; } }
public bool?[] BoolProp { get { return _fBool; } set { _fBool = value; } }
}
[Fact]
public void RoundTripConversionWithArrayPropertiess()
{
var data = new List<ClassWithArrayProperties>
{
new ClassWithArrayProperties()
{
IntProp = new int[3] { 0, 1, 2 },
FloatProp = new float[3] { -0.99f, 0f, 0.99f },
StringProp = new string[2] { "hola", "lola" },
BoolProp = new bool[2] { true, false },
ByteProp = new byte[3] { 0, 124, 255 },
DobuleProp = new double[3] { -1, 0, 1 },
LongProp = new long[] { 0, 1, 2 },
SByteProp = new sbyte[3] { -127, 127, 0 },
ShortProp = new short[3] { 0, 1225, 32767 },
UIntProp = new uint[2] { 0, uint.MaxValue },
ULongProp = new ulong[2] { ulong.MaxValue, 0 },
UShortProp = new ushort[2] { 0, ushort.MaxValue }
},
new ClassWithArrayProperties() { IntProp = new int[3] { -2, 1, 0 }, FloatProp = new float[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "", null } },
new ClassWithArrayProperties()
};
var nullableData = new List<ClassWithNullableArrayProperties>
{
new ClassWithNullableArrayProperties()
{
IntProp = new int?[3] { null, -1, 1 },
SingleProp = new float?[3] { -0.99f, null, 0.99f },
StringProp = new string[2] { null, "" },
BoolProp = new bool?[3] { true, null, false },
ByteProp = new byte?[4] { 0, 125, null, 255 },
DoubleProp = new double?[3] { -1, null, 1 },
LongProp = new long?[] { null, -1, 1 },
SByteProp = new sbyte?[3] { -127, 127, null },
ShortProp = new short?[3] { 0, null, 32767 },
UIntProp = new uint?[4] { null, 42, 0, uint.MaxValue },
ULongProp = new ulong?[3] { ulong.MaxValue, null, 0 },
UShortProp = new ushort?[3] { 0, null, ushort.MaxValue }
},
new ClassWithNullableArrayProperties() { IntProp = new int?[3] { -2, 1, 0 }, SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } },
new ClassWithNullableArrayProperties()
};
using (var env = new TlcEnvironment())
{
var dataView = ComponentCreation.CreateDataView(env, data);
var enumeratorSimple = dataView.AsEnumerable<ClassWithArrayProperties>(env, false).GetEnumerator();
var originalEnumerator = data.GetEnumerator();
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
{
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
}
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
var nullableDataView = ComponentCreation.CreateDataView(env, nullableData);
var enumeratorNullable = nullableDataView.AsEnumerable<ClassWithNullableArrayProperties>(env, false).GetEnumerator();
var originalNullalbleEnumerator = nullableData.GetEnumerator();
while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext())
{
Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current));
}
Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext());
}
}
}
}

Просмотреть файл

@ -228,7 +228,7 @@ namespace Microsoft.ML.EntryPoints.Tests
public void ThrowsExceptionWithPropertyName()
{
Exception ex = Assert.Throws<InvalidOperationException>( () => new Data.TextLoader("fakefile.txt").CreateFrom<ModelWithoutColumnAttribute>() );
Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message);
Assert.StartsWith("Field or property String1 is missing ColumnAttribute", ex.Message);
}
public class QuoteInput

Просмотреть файл

@ -3,6 +3,7 @@
<ItemGroup>
<Project Include="$(MSBuildThisFileDirectory)**\*.csproj" />
<Project Include="$(MSBuildThisFileDirectory)**\*.fsproj" />
</ItemGroup>
<Target Name="RunTests">