Allow use of property-based row classes in ML.NET (#616)
Now schema comprehension works with public properties as well as public fields.
This commit is contained in:
Родитель
89dfc82f5e
Коммит
f6934a0705
|
@ -97,6 +97,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer",
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}"
|
||||
EndProject
|
||||
Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
|
||||
|
@ -333,6 +335,14 @@ Global
|
|||
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
|
||||
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
|
@ -387,6 +397,7 @@ Global
|
|||
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
|
||||
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
|
||||
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
|
||||
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
|
||||
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
|
||||
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
|
||||
EndGlobalSection
|
||||
|
|
|
@ -51,14 +51,31 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// </summary>
|
||||
internal static Delegate GeneratePeek<TOwn, TRow>(InternalSchemaDefinition.Column column)
|
||||
{
|
||||
var fieldInfo = column.FieldInfo;
|
||||
Type fieldType = fieldInfo.FieldType;
|
||||
switch (column.MemberInfo)
|
||||
{
|
||||
case FieldInfo fieldInfo:
|
||||
Type fieldType = fieldInfo.FieldType;
|
||||
|
||||
var assignmentOpCode = GetAssignmentOpCode(fieldType);
|
||||
Func<FieldInfo, OpCode, Delegate> func = GeneratePeek<TOwn, TRow, int>;
|
||||
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
|
||||
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
|
||||
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
|
||||
var assignmentOpCode = GetAssignmentOpCode(fieldType);
|
||||
Func<FieldInfo, OpCode, Delegate> func = GeneratePeek<TOwn, TRow, int>;
|
||||
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
|
||||
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
|
||||
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
|
||||
|
||||
case PropertyInfo propertyInfo:
|
||||
Type propertyType = propertyInfo.PropertyType;
|
||||
|
||||
var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
|
||||
Func<PropertyInfo, OpCode, Delegate> funcProp = GeneratePeek<TOwn, TRow, int>;
|
||||
var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
|
||||
.MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
|
||||
return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp });
|
||||
|
||||
default:
|
||||
Contracts.Assert(false);
|
||||
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private static Delegate GeneratePeek<TOwn, TRow, TValue>(FieldInfo fieldInfo, OpCode assignmentOpCode)
|
||||
|
@ -81,6 +98,28 @@ namespace Microsoft.ML.Runtime.Api
|
|||
return mb.CreateDelegate(typeof(Peek<TRow, TValue>));
|
||||
}
|
||||
|
||||
private static Delegate GeneratePeek<TOwn, TRow, TValue>(PropertyInfo propertyInfo, OpCode assignmentOpCode)
|
||||
{
|
||||
// REVIEW: It seems like we really should cache these, instead of generating them per cursor.
|
||||
Type[] args = { typeof(TOwn), typeof(TRow), typeof(long), typeof(TValue).MakeByRefType() };
|
||||
var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true);
|
||||
var il = mb.GetILGenerator();
|
||||
var minfo = propertyInfo.GetGetMethod();
|
||||
var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;
|
||||
|
||||
il.Emit(OpCodes.Ldarg_3); // push arg3
|
||||
il.Emit(OpCodes.Ldarg_1); // push arg1
|
||||
il.Emit(opcode, minfo); // call [stack top].get_[propertyInfo]()
|
||||
// Stobj needs to coupled with a type.
|
||||
if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top]
|
||||
il.Emit(assignmentOpCode, propertyInfo.PropertyType);
|
||||
else
|
||||
il.Emit(assignmentOpCode);
|
||||
il.Emit(OpCodes.Ret); // ret
|
||||
|
||||
return mb.CreateDelegate(typeof(Peek<TRow, TValue>));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Each of the specialized 'poke' methods sets the appropriate field value of an instance of T
|
||||
/// to the provided value. So, the call is 'peek(userObject, providedValue)' and the logic is
|
||||
|
@ -88,14 +127,30 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// </summary>
|
||||
internal static Delegate GeneratePoke<TOwn, TRow>(InternalSchemaDefinition.Column column)
|
||||
{
|
||||
var fieldInfo = column.FieldInfo;
|
||||
Type fieldType = fieldInfo.FieldType;
|
||||
switch (column.MemberInfo)
|
||||
{
|
||||
case FieldInfo fieldInfo:
|
||||
Type fieldType = fieldInfo.FieldType;
|
||||
|
||||
var assignmentOpCode = GetAssignmentOpCode(fieldType);
|
||||
Func<FieldInfo, OpCode, Delegate> func = GeneratePoke<TOwn, TRow, int>;
|
||||
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
|
||||
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
|
||||
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
|
||||
var assignmentOpCode = GetAssignmentOpCode(fieldType);
|
||||
Func<FieldInfo, OpCode, Delegate> func = GeneratePoke<TOwn, TRow, int>;
|
||||
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
|
||||
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
|
||||
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
|
||||
|
||||
case PropertyInfo propertyInfo:
|
||||
Type propertyType = propertyInfo.PropertyType;
|
||||
|
||||
var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
|
||||
Func<PropertyInfo, Delegate> funcProp = GeneratePoke<TOwn, TRow, int>;
|
||||
var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
|
||||
.MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
|
||||
return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo });
|
||||
|
||||
default:
|
||||
Contracts.Assert(false);
|
||||
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
|
||||
}
|
||||
}
|
||||
|
||||
private static Delegate GeneratePoke<TOwn, TRow, TValue>(FieldInfo fieldInfo, OpCode assignmentOpCode)
|
||||
|
@ -115,5 +170,20 @@ namespace Microsoft.ML.Runtime.Api
|
|||
il.Emit(OpCodes.Ret); // ret
|
||||
return mb.CreateDelegate(typeof(Poke<TRow, TValue>), null);
|
||||
}
|
||||
|
||||
private static Delegate GeneratePoke<TOwn, TRow, TValue>(PropertyInfo propertyInfo)
|
||||
{
|
||||
Type[] args = { typeof(TOwn), typeof(TRow), typeof(TValue) };
|
||||
var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true);
|
||||
var il = mb.GetILGenerator();
|
||||
var minfo = propertyInfo.GetSetMethod();
|
||||
var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;
|
||||
|
||||
il.Emit(OpCodes.Ldarg_1); // push arg1
|
||||
il.Emit(OpCodes.Ldarg_2); // push arg2
|
||||
il.Emit(opcode, minfo); // call [stack top-1].set_[propertyInfo]([stack top])
|
||||
il.Emit(OpCodes.Ret); // ret
|
||||
return mb.CreateDelegate(typeof(Poke<TRow, TValue>), null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -118,7 +118,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
var colType = DataView.Schema.GetColumnType(index);
|
||||
|
||||
var column = DataView._schema.SchemaDefn.Columns[index];
|
||||
var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType;
|
||||
var outputType = column.OutputType;
|
||||
var genericType = outputType;
|
||||
Func<int, Delegate> del;
|
||||
|
||||
|
|
|
@ -23,21 +23,23 @@ namespace Microsoft.ML.Runtime.Api
|
|||
public class Column
|
||||
{
|
||||
public readonly string ColumnName;
|
||||
public readonly FieldInfo FieldInfo;
|
||||
public readonly MemberInfo MemberInfo;
|
||||
public readonly ParameterInfo ReturnParameterInfo;
|
||||
public readonly ColumnType ColumnType;
|
||||
public readonly bool IsComputed;
|
||||
public readonly Delegate Generator;
|
||||
private readonly Dictionary<string, MetadataInfo> _metadata;
|
||||
public Dictionary<string, MetadataInfo> Metadata { get { return _metadata; } }
|
||||
public Type ReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
|
||||
public Type ComputedReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
|
||||
public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType;
|
||||
public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType;
|
||||
|
||||
public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo) :
|
||||
this(columnName, columnType, fieldInfo, null, null) { }
|
||||
public Column(string columnName, ColumnType columnType, MemberInfo memberInfo) :
|
||||
this(columnName, columnType, memberInfo, null, null) { }
|
||||
|
||||
public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo,
|
||||
public Column(string columnName, ColumnType columnType, MemberInfo memberInfo,
|
||||
Dictionary<string, MetadataInfo> metadataInfos) :
|
||||
this(columnName, columnType, fieldInfo, null, metadataInfos) { }
|
||||
this(columnName, columnType, memberInfo, null, metadataInfos) { }
|
||||
|
||||
public Column(string columnName, ColumnType columnType, Delegate generator) :
|
||||
this(columnName, columnType, null, generator, null) { }
|
||||
|
@ -46,7 +48,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
Dictionary<string, MetadataInfo> metadataInfos) :
|
||||
this(columnName, columnType, null, generator, metadataInfos) { }
|
||||
|
||||
private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = null,
|
||||
private Column(string columnName, ColumnType columnType, MemberInfo memberInfo = null,
|
||||
Delegate generator = null, Dictionary<string, MetadataInfo> metadataInfos = null)
|
||||
{
|
||||
Contracts.AssertNonEmpty(columnName);
|
||||
|
@ -55,8 +57,8 @@ namespace Microsoft.ML.Runtime.Api
|
|||
|
||||
if (generator == null)
|
||||
{
|
||||
Contracts.AssertValue(fieldInfo);
|
||||
FieldInfo = fieldInfo;
|
||||
Contracts.AssertValue(memberInfo);
|
||||
MemberInfo = memberInfo;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -95,8 +97,8 @@ namespace Microsoft.ML.Runtime.Api
|
|||
// If Column is computed type, it must have a generator.
|
||||
Contracts.Assert(IsComputed == (Generator != null));
|
||||
|
||||
// Column must have either a generator or a fieldInfo value.
|
||||
Contracts.Assert((Generator == null) != (FieldInfo == null));
|
||||
// Column must have either a generator or a memberInfo value.
|
||||
Contracts.Assert((Generator == null) != (MemberInfo == null));
|
||||
|
||||
// Additional Checks if there is a generator.
|
||||
if (Generator == null)
|
||||
|
@ -115,9 +117,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void));
|
||||
|
||||
// Checks that the return type of the generator is compatible with ColumnType.
|
||||
bool isVector;
|
||||
DataKind datakind;
|
||||
GetVectorAndKind(ReturnType, "return type", out isVector, out datakind);
|
||||
GetVectorAndKind(ComputedReturnType, "return type", out bool isVector, out DataKind datakind);
|
||||
Contracts.Assert(isVector == ColumnType.IsVector);
|
||||
Contracts.Assert(datakind == ColumnType.ItemType.RawKind);
|
||||
}
|
||||
|
@ -131,19 +131,30 @@ namespace Microsoft.ML.Runtime.Api
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Given a field info on a type, returns whether this appears to be a vector type,
|
||||
/// Given a field or property info on a type, returns whether this appears to be a vector type,
|
||||
/// and also the associated data kind for this type. If a data kind could not
|
||||
/// be determined, this will throw.
|
||||
/// </summary>
|
||||
/// <param name="fieldInfo">The field info to inspect.</param>
|
||||
/// <param name="memberInfo">The field or property info to inspect.</param>
|
||||
/// <param name="isVector">Whether this appears to be a vector type.</param>
|
||||
/// <param name="kind">The data kind of the type, or items of this type if vector.</param>
|
||||
public static void GetVectorAndKind(FieldInfo fieldInfo, out bool isVector, out DataKind kind)
|
||||
public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, out DataKind kind)
|
||||
{
|
||||
Contracts.AssertValue(fieldInfo);
|
||||
Type rawFieldType = fieldInfo.FieldType;
|
||||
var name = fieldInfo.Name;
|
||||
GetVectorAndKind(rawFieldType, name, out isVector, out kind);
|
||||
Contracts.AssertValue(memberInfo);
|
||||
switch (memberInfo)
|
||||
{
|
||||
case FieldInfo fieldInfo:
|
||||
GetVectorAndKind(fieldInfo.FieldType, fieldInfo.Name, out isVector, out kind);
|
||||
break;
|
||||
|
||||
case PropertyInfo propertyInfo:
|
||||
GetVectorAndKind(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out kind);
|
||||
break;
|
||||
|
||||
default:
|
||||
Contracts.Assert(false);
|
||||
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -211,23 +222,27 @@ namespace Microsoft.ML.Runtime.Api
|
|||
|
||||
bool isVector;
|
||||
DataKind kind;
|
||||
FieldInfo fieldInfo = null;
|
||||
MemberInfo memberInfo = null;
|
||||
|
||||
if (!col.IsComputed)
|
||||
{
|
||||
fieldInfo = userType.GetField(col.MemberName);
|
||||
memberInfo = userType.GetField(col.MemberName);
|
||||
|
||||
if (fieldInfo == null)
|
||||
throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field with name '{0}' found in type '{1}'",
|
||||
if (memberInfo == null)
|
||||
memberInfo = userType.GetProperty(col.MemberName);
|
||||
|
||||
if (memberInfo == null)
|
||||
throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field or property with name '{0}' found in type '{1}'",
|
||||
col.MemberName,
|
||||
userType.FullName);
|
||||
|
||||
//Clause to handle the field that may be used to expose the cursor channel.
|
||||
//This field does not need a column.
|
||||
if (fieldInfo.FieldType == typeof(IChannel))
|
||||
if ( (memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) ||
|
||||
(memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel)))
|
||||
continue;
|
||||
|
||||
GetVectorAndKind(fieldInfo, out isVector, out kind);
|
||||
GetVectorAndKind(memberInfo, out isVector, out kind);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -268,7 +283,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
|
||||
dstCols[i] = col.IsComputed ?
|
||||
new Column(colName, colType, col.Generator, col.Metadata)
|
||||
: new Column(colName, colType, fieldInfo, col.Metadata);
|
||||
: new Column(colName, colType, memberInfo, col.Metadata);
|
||||
|
||||
}
|
||||
return new InternalSchemaDefinition(dstCols);
|
||||
|
|
|
@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// <summary>
|
||||
/// Attach to a member of a class to indicate that the item type should be of class key.
|
||||
/// </summary>
|
||||
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
|
||||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
|
||||
public sealed class KeyTypeAttribute : Attribute
|
||||
{
|
||||
// REVIEW: Property based, but should I just have a constructor?
|
||||
|
@ -46,7 +46,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// Allows a member to be marked as a vector valued field, primarily allowing one to set
|
||||
/// the dimensionality of the resulting array.
|
||||
/// </summary>
|
||||
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
|
||||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
|
||||
public sealed class VectorTypeAttribute : Attribute
|
||||
{
|
||||
private readonly int[] _dims;
|
||||
|
@ -66,7 +66,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// Describes column information such as name and the source columns indicies that this
|
||||
/// column encapsulates.
|
||||
/// </summary>
|
||||
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
|
||||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
|
||||
public sealed class ColumnAttribute : Attribute
|
||||
{
|
||||
public ColumnAttribute(string ordinal, string name = null)
|
||||
|
@ -97,7 +97,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// Allows a member to specify its column name directly, as opposed to the default
|
||||
/// behavior of using the member name as the column name.
|
||||
/// </summary>
|
||||
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
|
||||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
|
||||
public sealed class ColumnNameAttribute : Attribute
|
||||
{
|
||||
private readonly string _name;
|
||||
|
@ -119,7 +119,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// <summary>
|
||||
/// Mark this member as not being exposed as a column in the schema.
|
||||
/// </summary>
|
||||
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
|
||||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
|
||||
public sealed class NoColumnAttribute : Attribute
|
||||
{
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
/// Mark a member that implements exactly IChannel as being permitted to receive
|
||||
/// channel information from an external channel.
|
||||
/// </summary>
|
||||
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
|
||||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
|
||||
public sealed class CursorChannelAttribute : Attribute
|
||||
{
|
||||
/// <summary>
|
||||
|
@ -158,19 +158,40 @@ namespace Microsoft.ML.Runtime.Api
|
|||
.Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any())
|
||||
.ToArray();
|
||||
|
||||
var cursorChannelAttrProperties = typeof(T)
|
||||
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
|
||||
.Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0)
|
||||
.Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any());
|
||||
|
||||
var cursorChannelAttrMembers = (cursorChannelAttrFields as IEnumerable<MemberInfo>).Concat(cursorChannelAttrProperties).ToArray();
|
||||
|
||||
//Check that there is at most one such field.
|
||||
if (cursorChannelAttrFields.Length == 0)
|
||||
if (cursorChannelAttrMembers.Length == 0)
|
||||
return false;
|
||||
|
||||
ectx.Check(cursorChannelAttrFields.Length == 1,
|
||||
"Only one field with CursorChannel attribute is allowed.");
|
||||
ectx.Check(cursorChannelAttrMembers.Length == 1,
|
||||
"Only one public field or property with CursorChannel attribute is allowed.");
|
||||
|
||||
//Check that the marked field has type IChannel.
|
||||
var cursorChannelFieldInfo = cursorChannelAttrFields[0];
|
||||
ectx.Check(cursorChannelFieldInfo.FieldType == typeof(IChannel),
|
||||
"Field marked as CursorChannel must have type IChannel.");
|
||||
var cursorChannelAttrMemberInfo = cursorChannelAttrMembers[0];
|
||||
switch (cursorChannelAttrMemberInfo)
|
||||
{
|
||||
case FieldInfo cursorChannelAttrFieldInfo:
|
||||
ectx.Check(cursorChannelAttrFieldInfo.FieldType == typeof(IChannel),
|
||||
"Field marked as CursorChannel must have type IChannel.");
|
||||
cursorChannelAttrFieldInfo.SetValue(obj, channel);
|
||||
break;
|
||||
|
||||
cursorChannelFieldInfo.SetValue(obj, channel);
|
||||
case PropertyInfo cursorChannelAttrPropertyInfo:
|
||||
ectx.Check(cursorChannelAttrPropertyInfo.PropertyType == typeof(IChannel),
|
||||
"Property marked as CursorChannel must have type IChannel.");
|
||||
cursorChannelAttrPropertyInfo.SetValue(obj, channel);
|
||||
break;
|
||||
|
||||
default:
|
||||
Contracts.Assert(false);
|
||||
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -319,37 +340,63 @@ namespace Microsoft.ML.Runtime.Api
|
|||
|
||||
SchemaDefinition cols = new SchemaDefinition();
|
||||
HashSet<string> colNames = new HashSet<string>();
|
||||
foreach (var fieldInfo in userType.GetFields())
|
||||
|
||||
var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
|
||||
var propertyInfos =
|
||||
userType
|
||||
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
|
||||
.Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0);
|
||||
|
||||
var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
|
||||
|
||||
foreach (var memberInfo in memberInfos)
|
||||
{
|
||||
// Clause to handle the field that may be used to expose the cursor channel.
|
||||
// This field does not need a column.
|
||||
// REVIEW: maybe validate the channel attribute now, instead
|
||||
// of later at cursor creation.
|
||||
if (fieldInfo.FieldType == typeof(IChannel))
|
||||
continue;
|
||||
// Const fields do not need to be mapped.
|
||||
if (fieldInfo.IsLiteral)
|
||||
switch (memberInfo)
|
||||
{
|
||||
case FieldInfo fieldInfo:
|
||||
if (fieldInfo.FieldType == typeof(IChannel))
|
||||
continue;
|
||||
|
||||
// Const fields do not need to be mapped.
|
||||
if (fieldInfo.IsLiteral)
|
||||
continue;
|
||||
|
||||
break;
|
||||
|
||||
case PropertyInfo propertyInfo:
|
||||
if (propertyInfo.PropertyType == typeof(IChannel))
|
||||
continue;
|
||||
break;
|
||||
|
||||
default:
|
||||
Contracts.Assert(false);
|
||||
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
|
||||
}
|
||||
|
||||
if (memberInfo.GetCustomAttribute<NoColumnAttribute>() != null)
|
||||
continue;
|
||||
|
||||
if (fieldInfo.GetCustomAttribute<NoColumnAttribute>() != null)
|
||||
continue;
|
||||
var mappingAttr = fieldInfo.GetCustomAttribute<ColumnAttribute>();
|
||||
var mappingNameAttr = fieldInfo.GetCustomAttribute<ColumnNameAttribute>();
|
||||
string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name;
|
||||
var mappingAttr = memberInfo.GetCustomAttribute<ColumnAttribute>();
|
||||
var mappingNameAttr = memberInfo.GetCustomAttribute<ColumnNameAttribute>();
|
||||
string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? memberInfo.Name;
|
||||
// Disallow duplicate names, because the field enumeration order is not actually
|
||||
// well defined, so we are not gauranteed to have consistent "hiding" from run to
|
||||
// run, across different .NET versions.
|
||||
if (!colNames.Add(name))
|
||||
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);
|
||||
|
||||
InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out bool isVector, out DataKind kind);
|
||||
InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
|
||||
|
||||
PrimitiveType itemType;
|
||||
var keyAttr = fieldInfo.GetCustomAttribute<KeyTypeAttribute>();
|
||||
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
|
||||
if (keyAttr != null)
|
||||
{
|
||||
if (!KeyType.IsValidDataKind(kind))
|
||||
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", fieldInfo.Name);
|
||||
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
|
||||
itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
|
||||
}
|
||||
else
|
||||
|
@ -357,9 +404,9 @@ namespace Microsoft.ML.Runtime.Api
|
|||
|
||||
// Get the column type.
|
||||
ColumnType columnType;
|
||||
var vectorAttr = fieldInfo.GetCustomAttribute<VectorTypeAttribute>();
|
||||
var vectorAttr = memberInfo.GetCustomAttribute<VectorTypeAttribute>();
|
||||
if (vectorAttr != null && !isVector)
|
||||
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", fieldInfo.Name);
|
||||
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", memberInfo.Name);
|
||||
if (isVector)
|
||||
{
|
||||
int[] dims = vectorAttr?.Dims;
|
||||
|
@ -373,7 +420,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
else
|
||||
columnType = itemType;
|
||||
|
||||
cols.Add(new Column() { MemberName = fieldInfo.Name, ColumnName = name, ColumnType = columnType });
|
||||
cols.Add(new Column() { MemberName = memberInfo.Name, ColumnName = name, ColumnType = columnType });
|
||||
}
|
||||
return cols;
|
||||
}
|
||||
|
|
|
@ -103,11 +103,11 @@ namespace Microsoft.ML.Runtime.Api
|
|||
throw _host.Except("Column '{0}' not found in the data view", col.ColumnName);
|
||||
}
|
||||
var realColType = _data.Schema.GetColumnType(colIndex);
|
||||
if (!IsCompatibleType(realColType, col.FieldInfo))
|
||||
if (!IsCompatibleType(realColType, col.MemberInfo))
|
||||
{
|
||||
throw _host.Except(
|
||||
"Can't bind the IDataView column '{0}' of type '{1}' to field '{2}' of type '{3}'.",
|
||||
col.ColumnName, realColType, col.FieldInfo.Name, col.FieldInfo.FieldType.FullName);
|
||||
"Can't bind the IDataView column '{0}' of type '{1}' to field or property '{2}' of type '{3}'.",
|
||||
col.ColumnName, realColType, col.MemberInfo.Name, col.FieldOrPropertyType.FullName);
|
||||
}
|
||||
|
||||
acceptedCols.Add(col);
|
||||
|
@ -130,14 +130,12 @@ namespace Microsoft.ML.Runtime.Api
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether the column type <paramref name="colType"/> can be bound to field <paramref name="fieldInfo"/>.
|
||||
/// Returns whether the column type <paramref name="colType"/> can be bound to field <paramref name="memberInfo"/>.
|
||||
/// They must both be vectors or scalars, and the raw data kind should match.
|
||||
/// </summary>
|
||||
private static bool IsCompatibleType(ColumnType colType, FieldInfo fieldInfo)
|
||||
private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo)
|
||||
{
|
||||
bool isVector;
|
||||
DataKind kind;
|
||||
InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out isVector, out kind);
|
||||
InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
|
||||
if (isVector)
|
||||
return colType.IsVector && colType.ItemType.RawKind == kind;
|
||||
else
|
||||
|
@ -269,8 +267,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek)
|
||||
{
|
||||
var colType = input.Schema.GetColumnType(index);
|
||||
var fieldInfo = column.FieldInfo;
|
||||
var fieldType = fieldInfo.FieldType;
|
||||
var fieldType = column.OutputType;
|
||||
var genericType = fieldType;
|
||||
Func<IRow, int, Delegate, Delegate, Action<TRow>> del;
|
||||
if (fieldType.IsArray)
|
||||
|
@ -431,7 +428,7 @@ namespace Microsoft.ML.Runtime.Api
|
|||
else
|
||||
{
|
||||
// REVIEW: Is this even possible?
|
||||
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName);
|
||||
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", column.OutputType.FullName);
|
||||
}
|
||||
MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType);
|
||||
return (Action<TRow>)meth.Invoke(this, new object[] { input, index, poke, peek });
|
||||
|
|
|
@ -7,6 +7,7 @@ using Microsoft.ML.Runtime;
|
|||
using Microsoft.ML.Runtime.Api;
|
||||
using Microsoft.ML.Runtime.Data;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Text.RegularExpressions;
|
||||
|
@ -71,20 +72,30 @@ namespace Microsoft.ML.Data
|
|||
char separator = '\t', bool allowQuotedStrings = true,
|
||||
bool supportSparse = true, bool trimWhitespace = false)
|
||||
{
|
||||
var fields = typeof(TInput).GetFields();
|
||||
Arguments.Column = new TextLoaderColumn[fields.Length];
|
||||
for (int index = 0; index < fields.Length; index++)
|
||||
var userType = typeof(TInput);
|
||||
|
||||
var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
|
||||
|
||||
var propertyInfos =
|
||||
userType
|
||||
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
|
||||
.Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0);
|
||||
|
||||
var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
|
||||
|
||||
Arguments.Column = new TextLoaderColumn[memberInfos.Length];
|
||||
for (int index = 0; index < memberInfos.Length; index++)
|
||||
{
|
||||
var field = fields[index];
|
||||
var mappingAttr = field.GetCustomAttribute<ColumnAttribute>();
|
||||
var memberInfo = memberInfos[index];
|
||||
var mappingAttr = memberInfo.GetCustomAttribute<ColumnAttribute>();
|
||||
if (mappingAttr == null)
|
||||
throw Contracts.Except($"{field.Name} is missing ColumnAttribute");
|
||||
throw Contracts.Except($"Field or property {memberInfo.Name} is missing ColumnAttribute");
|
||||
|
||||
if (Regex.Match(mappingAttr.Ordinal, @"[^(0-9,\*\-~)]+").Success)
|
||||
throw Contracts.Except($"{mappingAttr.Ordinal} contains invalid characters. " +
|
||||
$"Valid characters are 0-9, *, - and ~");
|
||||
|
||||
var name = mappingAttr.Name ?? field.Name;
|
||||
var name = mappingAttr.Name ?? memberInfo.Name;
|
||||
|
||||
Runtime.Data.TextLoader.Range[] sources;
|
||||
if (!Runtime.Data.TextLoader.Column.TryParseSourceEx(mappingAttr.Ordinal, out sources))
|
||||
|
@ -96,8 +107,23 @@ namespace Microsoft.ML.Data
|
|||
tlc.Name = name;
|
||||
tlc.Source = new TextLoaderRange[sources.Length];
|
||||
DataKind dk;
|
||||
if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
|
||||
throw Contracts.Except($"{name} is of unsupported type.");
|
||||
switch (memberInfo)
|
||||
{
|
||||
case FieldInfo field:
|
||||
if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
|
||||
throw Contracts.Except($"Field {name} is of unsupported type.");
|
||||
|
||||
break;
|
||||
|
||||
case PropertyInfo property:
|
||||
if (!TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk))
|
||||
throw Contracts.Except($"Property {name} is of unsupported type.");
|
||||
break;
|
||||
|
||||
default:
|
||||
Contracts.Assert(false);
|
||||
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
|
||||
}
|
||||
|
||||
tlc.Type = dk;
|
||||
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFrameworks>netcoreapp2.0</TargetFrameworks>
|
||||
<TargetFrameworks Condition="'$(OS)' != 'Unix'">$(TargetFrameworks); net461</TargetFrameworks>
|
||||
<NoWarn>2003;$(NoWarn)</NoWarn>
|
||||
<PublicSign>false</PublicSign>
|
||||
<SourceLink></SourceLink>
|
||||
<PlatformTarget>x64</PlatformTarget>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Include="SmokeTests.fs" />
|
||||
<Compile Include="Program.fs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<!-- Future updates to this test will check use with F# type providers, so -->
|
||||
<!-- leaving this here for now. -->
|
||||
<!-- <PackageReference Include="FSharp.Data" Version="3.0.0-beta4" /> -->
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<!-- More projects are referenced than are currently tested. Future updates to -->
|
||||
<!-- these tests will test more of the surface area from F#, so leaving these references -->
|
||||
<!-- here for now. -->
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Api\Microsoft.ML.Api.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Parquet\Microsoft.ML.Parquet.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.ML\Microsoft.ML.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<NativeAssemblyReference Include="FastTreeNative" />
|
||||
<NativeAssemblyReference Include="CpuMathNative" />
|
||||
<NativeAssemblyReference Include="FactorizationMachineNative" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
|
@ -0,0 +1,9 @@
|
|||
namespace Microsoft.ML.FSharp.Tests
|
||||
|
||||
#if NETCOREAPP2_0
|
||||
module Program =
|
||||
|
||||
[<EntryPoint>]
|
||||
let main _ = 0
|
||||
#endif
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
|
||||
//=================================================================================================
|
||||
// This test can be run either as a compiled test with .NET Core (on any platform) or
|
||||
// manually in script form (to help debug it and also check that F# scripting works with ML.NET).
|
||||
// Running as a script requires using F# Interactive on Windows, and the explicit references below.
|
||||
// The references would normally be created by a package loader for the scripting
|
||||
// environment, e.g. see https://github.com/isaacabraham/ml-test-experiment/, but
|
||||
// here we list them explicitly to avoid the dependency on a package loader,
|
||||
//
|
||||
// You should build Microsoft.ML.FSharp.Tests in Debug mode for framework net461
|
||||
// before running this as a script with F# Interactive by editing the project
|
||||
// file to have:
|
||||
// <TargetFrameworks>netcoreapp2.0; net461</TargetFrameworks>
|
||||
|
||||
#if INTERACTIVE
|
||||
#r "netstandard"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Core.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.ResultProcessor.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PCA.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.KMeansClustering.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.FastTree.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Api.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Sweeper.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.StandardLearners.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PipelineInference.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll"
|
||||
#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll"
|
||||
#r "System"
|
||||
#r "System.ComponentModel.Composition"
|
||||
#r "System.Core"
|
||||
#r "System.Xml.Linq"
|
||||
|
||||
// Later tests will add data import using F# type providers:
|
||||
//#r @"../../packages/fsharp.data/3.0.0-beta4/lib/netstandard2.0/FSharp.Data.dll" // this must be referenced from its package location
|
||||
|
||||
#endif
|
||||
|
||||
//================================================================================
|
||||
// The tests proper start here
|
||||
|
||||
#if !INTERACTIVE
|
||||
namespace Microsoft.ML.FSharp.Tests
|
||||
#endif
|
||||
|
||||
open System
|
||||
open Microsoft.ML
|
||||
open Microsoft.ML.Data
|
||||
open Microsoft.ML.Transforms
|
||||
open Microsoft.ML.Trainers
|
||||
open Microsoft.ML.Runtime.Api
|
||||
open Xunit
|
||||
|
||||
module SmokeTest1 =
|
||||
|
||||
type SentimentData() =
|
||||
[<Column(ordinal = "0"); DefaultValue>]
|
||||
val mutable SentimentText : string
|
||||
[<Column(ordinal = "1", name = "Label"); DefaultValue>]
|
||||
val mutable Sentiment : float32
|
||||
|
||||
type SentimentPrediction() =
|
||||
[<ColumnName("PredictedLabel"); DefaultValue>]
|
||||
val mutable Sentiment : bool
|
||||
|
||||
[<Fact>]
|
||||
let ``FSharp-Sentiment-Smoke-Test`` () =
|
||||
|
||||
// See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies
|
||||
let _load =
|
||||
[ typeof<Microsoft.ML.Runtime.Transforms.TextAnalytics>;
|
||||
typeof<Microsoft.ML.Runtime.FastTree.FastTree> ]
|
||||
|
||||
let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv"
|
||||
|
||||
let pipeline = LearningPipeline()
|
||||
|
||||
pipeline.Add(
|
||||
TextLoader(testDataPath).CreateFrom<SentimentData>(
|
||||
Arguments =
|
||||
TextLoaderArguments(
|
||||
HasHeader = true,
|
||||
Column = [| TextLoaderColumn(Name = "Label",
|
||||
Source = [| TextLoaderRange(0) |],
|
||||
Type = Nullable (Data.DataKind.Num))
|
||||
TextLoaderColumn(Name = "SentimentText",
|
||||
Source = [| TextLoaderRange(1) |],
|
||||
Type = Nullable (Data.DataKind.Text)) |]
|
||||
)))
|
||||
|
||||
pipeline.Add(
|
||||
TextFeaturizer(
|
||||
"Features", [| "SentimentText" |],
|
||||
KeepDiacritics = false,
|
||||
KeepPunctuations = false,
|
||||
TextCase = TextNormalizerTransformCaseNormalizationMode.Lower,
|
||||
OutputTokens = true,
|
||||
VectorNormalizer = TextTransformTextNormKind.L2
|
||||
))
|
||||
|
||||
pipeline.Add(
|
||||
FastTreeBinaryClassifier(
|
||||
NumLeaves = 5,
|
||||
NumTrees = 5,
|
||||
MinDocumentsInLeafs = 2
|
||||
))
|
||||
|
||||
let model = pipeline.Train<SentimentData, SentimentPrediction>()
|
||||
|
||||
let predictions =
|
||||
[ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.")
|
||||
SentimentData(SentimentText = "Sort of ok")
|
||||
SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ]
|
||||
|> model.Predict
|
||||
|
||||
let predictionResults = [ for p in predictions -> p.Sentiment ]
|
||||
Assert.Equal<bool list>(predictionResults, [ false; true; true ])
|
||||
|
||||
module SmokeTest2 =
|
||||
|
||||
[<CLIMutable>]
|
||||
type SentimentData =
|
||||
{ [<Column(ordinal = "0")>]
|
||||
SentimentText : string
|
||||
|
||||
[<Column(ordinal = "1", name = "Label")>]
|
||||
Sentiment : float32 }
|
||||
|
||||
[<CLIMutable>]
|
||||
type SentimentPrediction =
|
||||
{ [<ColumnName("PredictedLabel")>]
|
||||
Sentiment : bool }
|
||||
|
||||
[<Fact>]
|
||||
let ``FSharp-Sentiment-Smoke-Test`` () =
|
||||
|
||||
// See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies
|
||||
let _load =
|
||||
[ typeof<Microsoft.ML.Runtime.Transforms.TextAnalytics>;
|
||||
typeof<Microsoft.ML.Runtime.FastTree.FastTree> ]
|
||||
|
||||
let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv"
|
||||
|
||||
let pipeline = LearningPipeline()
|
||||
|
||||
pipeline.Add(
|
||||
TextLoader(testDataPath).CreateFrom<SentimentData>(
|
||||
Arguments =
|
||||
TextLoaderArguments(
|
||||
HasHeader = true,
|
||||
Column = [| TextLoaderColumn(Name = "Label",
|
||||
Source = [| TextLoaderRange(0) |],
|
||||
Type = Nullable (Data.DataKind.Num))
|
||||
TextLoaderColumn(Name = "SentimentText",
|
||||
Source = [| TextLoaderRange(1) |],
|
||||
Type = Nullable (Data.DataKind.Text)) |]
|
||||
)))
|
||||
|
||||
pipeline.Add(
|
||||
TextFeaturizer(
|
||||
"Features", [| "SentimentText" |],
|
||||
KeepDiacritics = false,
|
||||
KeepPunctuations = false,
|
||||
TextCase = TextNormalizerTransformCaseNormalizationMode.Lower,
|
||||
OutputTokens = true,
|
||||
VectorNormalizer = TextTransformTextNormKind.L2
|
||||
))
|
||||
|
||||
pipeline.Add(
|
||||
FastTreeBinaryClassifier(
|
||||
NumLeaves = 5,
|
||||
NumTrees = 5,
|
||||
MinDocumentsInLeafs = 2
|
||||
))
|
||||
|
||||
let model = pipeline.Train<SentimentData, SentimentPrediction>()
|
||||
|
||||
let predictions =
|
||||
[ { SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition."; Sentiment = 0.0f }
|
||||
{ SentimentText = "Sort of ok"; Sentiment = 0.0f }
|
||||
{ SentimentText = "Joe versus the Volcano Coffee Company is a great film."; Sentiment = 0.0f } ]
|
||||
|> model.Predict
|
||||
|
||||
let predictionResults = [ for p in predictions -> p.Sentiment ]
|
||||
Assert.Equal<bool list>(predictionResults, [ false; true; true ])
|
||||
|
||||
module SmokeTest3 =
|
||||
|
||||
type SentimentData() =
|
||||
[<Column(ordinal = "0")>]
|
||||
member val SentimentText = "" with get, set
|
||||
|
||||
[<Column(ordinal = "1", name = "Label")>]
|
||||
member val Sentiment = 0.0 with get, set
|
||||
|
||||
type SentimentPrediction() =
|
||||
[<ColumnName("PredictedLabel")>]
|
||||
member val Sentiment = false with get, set
|
||||
|
||||
[<Fact>]
|
||||
let ``FSharp-Sentiment-Smoke-Test`` () =
|
||||
|
||||
// See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies
|
||||
let _load =
|
||||
[ typeof<Microsoft.ML.Runtime.Transforms.TextAnalytics>;
|
||||
typeof<Microsoft.ML.Runtime.FastTree.FastTree> ]
|
||||
|
||||
let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv"
|
||||
|
||||
let pipeline = LearningPipeline()
|
||||
|
||||
pipeline.Add(
|
||||
TextLoader(testDataPath).CreateFrom<SentimentData>(
|
||||
Arguments =
|
||||
TextLoaderArguments(
|
||||
HasHeader = true,
|
||||
Column = [| TextLoaderColumn(Name = "Label",
|
||||
Source = [| TextLoaderRange(0) |],
|
||||
Type = Nullable (Data.DataKind.Num))
|
||||
TextLoaderColumn(Name = "SentimentText",
|
||||
Source = [| TextLoaderRange(1) |],
|
||||
Type = Nullable (Data.DataKind.Text)) |]
|
||||
)))
|
||||
|
||||
pipeline.Add(
|
||||
TextFeaturizer(
|
||||
"Features", [| "SentimentText" |],
|
||||
KeepDiacritics = false,
|
||||
KeepPunctuations = false,
|
||||
TextCase = TextNormalizerTransformCaseNormalizationMode.Lower,
|
||||
OutputTokens = true,
|
||||
VectorNormalizer = TextTransformTextNormKind.L2
|
||||
))
|
||||
|
||||
pipeline.Add(
|
||||
FastTreeBinaryClassifier(
|
||||
NumLeaves = 5,
|
||||
NumTrees = 5,
|
||||
MinDocumentsInLeafs = 2
|
||||
))
|
||||
|
||||
let model = pipeline.Train<SentimentData, SentimentPrediction>()
|
||||
|
||||
let predictions =
|
||||
[ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.")
|
||||
SentimentData(SentimentText = "Sort of ok")
|
||||
SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ]
|
||||
|> model.Predict
|
||||
|
||||
let predictionResults = [ for p in predictions -> p.Sentiment ]
|
||||
Assert.Equal<bool list>(predictionResults, [ false; true; true ])
|
||||
|
|
@ -174,6 +174,49 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CanTrainProperties()
|
||||
{
|
||||
var pipeline = new LearningPipeline();
|
||||
var data = new List<IrisDataProperties>() {
|
||||
new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
|
||||
new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
|
||||
new IrisDataProperties { SepalLength = 1.2f, SepalWidth = 0.5f, PetalLength=0.3f, PetalWidth=5.1f, Label=0}
|
||||
};
|
||||
var collection = CollectionDataSource.Create(data);
|
||||
|
||||
pipeline.Add(collection);
|
||||
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
|
||||
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
|
||||
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
|
||||
PredictionModel<IrisDataProperties, IrisPredictionProperties> model = pipeline.Train<IrisDataProperties, IrisPredictionProperties>();
|
||||
|
||||
IrisPredictionProperties prediction = model.Predict(new IrisDataProperties()
|
||||
{
|
||||
SepalLength = 3.3f,
|
||||
SepalWidth = 1.6f,
|
||||
PetalLength = 0.2f,
|
||||
PetalWidth = 5.1f,
|
||||
});
|
||||
|
||||
pipeline = new LearningPipeline();
|
||||
collection = CollectionDataSource.Create(data.AsEnumerable());
|
||||
pipeline.Add(collection);
|
||||
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
|
||||
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
|
||||
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
|
||||
model = pipeline.Train<IrisDataProperties, IrisPredictionProperties>();
|
||||
|
||||
prediction = model.Predict(new IrisDataProperties()
|
||||
{
|
||||
SepalLength = 3.3f,
|
||||
SepalWidth = 1.6f,
|
||||
PetalLength = 0.2f,
|
||||
PetalWidth = 5.1f,
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
public class Input
|
||||
{
|
||||
[Column("0")]
|
||||
|
@ -207,6 +250,37 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
public float[] PredictedLabels;
|
||||
}
|
||||
|
||||
public class IrisDataProperties
|
||||
{
|
||||
private float _Label;
|
||||
private float _SepalLength;
|
||||
private float _SepalWidth;
|
||||
private float _PetalLength;
|
||||
private float _PetalWidth;
|
||||
|
||||
[Column("0")]
|
||||
public float Label { get { return _Label; } set { _Label = value; } }
|
||||
|
||||
[Column("1")]
|
||||
public float SepalLength { get { return _SepalLength; } set { _SepalLength = value; } }
|
||||
|
||||
[Column("2")]
|
||||
public float SepalWidth { get { return _SepalWidth; } set { _SepalWidth = value; } }
|
||||
|
||||
[Column("3")]
|
||||
public float PetalLength { get { return _PetalLength; } set { _PetalLength = value; } }
|
||||
|
||||
[Column("4")]
|
||||
public float PetalWidth { get { return _PetalWidth; } set { _PetalWidth = value; } }
|
||||
}
|
||||
|
||||
public class IrisPredictionProperties
|
||||
{
|
||||
private float[] _PredictedLabels;
|
||||
[ColumnName("Score")]
|
||||
public float[] PredictedLabels { get { return _PredictedLabels; } set { _PredictedLabels = value; } }
|
||||
}
|
||||
|
||||
public class ConversionSimpleClass
|
||||
{
|
||||
public int fInt;
|
||||
|
@ -257,7 +331,7 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
|
||||
public bool CompareThroughReflection<T>(T x, T y)
|
||||
{
|
||||
foreach (var field in typeof(T).GetFields())
|
||||
foreach (var field in typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance))
|
||||
{
|
||||
var xvalue = field.GetValue(x);
|
||||
var yvalue = field.GetValue(y);
|
||||
|
@ -272,6 +346,25 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
return false;
|
||||
}
|
||||
}
|
||||
foreach (var property in typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance))
|
||||
{
|
||||
// Don't compare properties with private getters and setters
|
||||
if (!property.CanRead || !property.CanWrite || property.GetGetMethod() == null || property.GetSetMethod() == null)
|
||||
continue;
|
||||
|
||||
var xvalue = property.GetValue(x);
|
||||
var yvalue = property.GetValue(y);
|
||||
if (property.PropertyType.IsArray)
|
||||
{
|
||||
if (!CompareArrayValues(xvalue as Array, yvalue as Array))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!CompareObjectValues(xvalue, yvalue, property.PropertyType))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -288,14 +381,6 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
return true;
|
||||
}
|
||||
|
||||
public class ClassWithConstField
|
||||
{
|
||||
public const string ConstString = "N";
|
||||
public string fString;
|
||||
public const int ConstInt = 100;
|
||||
public int fInt;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RoundTripConversionWithBasicTypes()
|
||||
{
|
||||
|
@ -489,6 +574,50 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
}
|
||||
}
|
||||
|
||||
public class ConversionLossMinValueClassProperties
|
||||
{
|
||||
private int? _fInt;
|
||||
private long? _fLong;
|
||||
private short? _fShort;
|
||||
private sbyte? _fsByte;
|
||||
public int? IntProp { get { return _fInt; } set { _fInt = value; } }
|
||||
public short? ShortProp { get { return _fShort; } set { _fShort = value; } }
|
||||
public sbyte? SByteProp { get { return _fsByte; } set { _fsByte = value; } }
|
||||
public long? LongProp { get { return _fLong; } set { _fLong = value; } }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConversionMinValueToNullBehaviorProperties()
|
||||
{
|
||||
using (var env = new TlcEnvironment())
|
||||
{
|
||||
|
||||
var data = new List<ConversionLossMinValueClassProperties>
|
||||
{
|
||||
new ConversionLossMinValueClassProperties() { SByteProp = null, IntProp = null, LongProp = null, ShortProp = null },
|
||||
new ConversionLossMinValueClassProperties() { SByteProp = sbyte.MinValue, IntProp = int.MinValue, LongProp = long.MinValue, ShortProp = short.MinValue }
|
||||
};
|
||||
foreach (var field in typeof(ConversionLossMinValueClassProperties).GetFields())
|
||||
{
|
||||
var dataView = ComponentCreation.CreateDataView(env, data);
|
||||
var enumerator = dataView.AsEnumerable<ConversionLossMinValueClassProperties>(env, false).GetEnumerator();
|
||||
while (enumerator.MoveNext())
|
||||
{
|
||||
Assert.True(enumerator.Current.IntProp == null && enumerator.Current.LongProp == null &&
|
||||
enumerator.Current.SByteProp == null && enumerator.Current.ShortProp == null);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class ClassWithConstField
|
||||
{
|
||||
public const string ConstString = "N";
|
||||
public string fString;
|
||||
public const int ConstInt = 100;
|
||||
public int fInt;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassWithConstFieldsConversion()
|
||||
{
|
||||
|
@ -510,6 +639,122 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
public class ClassWithMixOfFieldsAndProperties
|
||||
{
|
||||
public string fString;
|
||||
private int _fInt;
|
||||
public int IntProp { get { return _fInt; } set { _fInt = value; } }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassWithMixOfFieldsAndPropertiesConversion()
|
||||
{
|
||||
var data = new List<ClassWithMixOfFieldsAndProperties>()
|
||||
{
|
||||
new ClassWithMixOfFieldsAndProperties(){ IntProp=1, fString ="lala" },
|
||||
new ClassWithMixOfFieldsAndProperties(){ IntProp=-1, fString ="" },
|
||||
new ClassWithMixOfFieldsAndProperties(){ IntProp=0, fString =null }
|
||||
};
|
||||
|
||||
using (var env = new TlcEnvironment())
|
||||
{
|
||||
var dataView = ComponentCreation.CreateDataView(env, data);
|
||||
var enumeratorSimple = dataView.AsEnumerable<ClassWithMixOfFieldsAndProperties>(env, false).GetEnumerator();
|
||||
var originalEnumerator = data.GetEnumerator();
|
||||
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
|
||||
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
|
||||
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
|
||||
}
|
||||
}
|
||||
|
||||
public abstract class BaseClassWithInheritedProperties
|
||||
{
|
||||
private string _fString;
|
||||
private byte _fByte;
|
||||
public string StringProp { get { return _fString; } set { _fString = value; } }
|
||||
public abstract long LongProp { get; set; }
|
||||
public virtual byte ByteProp { get { return _fByte; } set { _fByte = value; } }
|
||||
}
|
||||
|
||||
|
||||
public class ClassWithPrivateFieldsAndProperties
|
||||
{
|
||||
public ClassWithPrivateFieldsAndProperties() { seq++; _unusedStaticField++; _unusedPrivateField1 = 100; }
|
||||
static public int seq;
|
||||
static public int _unusedStaticField;
|
||||
private int _unusedPrivateField1;
|
||||
private string _fString;
|
||||
|
||||
// This property is ignored because it has no setter
|
||||
private int UnusedReadOnlyProperty { get { return _unusedPrivateField1; } }
|
||||
|
||||
// This property is ignored because it is private
|
||||
private int UnusedPrivateProperty { get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } }
|
||||
|
||||
// This property is ignored because it has a private setter
|
||||
public int UnusedPropertyWithPrivateSetter { get { return _unusedPrivateField1; } private set { _unusedPrivateField1 = value; } }
|
||||
|
||||
// This property is ignored because it has a private getter
|
||||
public int UnusedPropertyWithPrivateGetter { private get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } }
|
||||
|
||||
public string StringProp { get { return _fString; } set { _fString = value; } }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassWithPrivateFieldsAndPropertiesConversion()
|
||||
{
|
||||
var data = new List<ClassWithPrivateFieldsAndProperties>()
|
||||
{
|
||||
new ClassWithPrivateFieldsAndProperties(){ StringProp ="lala" },
|
||||
new ClassWithPrivateFieldsAndProperties(){ StringProp ="baba" }
|
||||
};
|
||||
|
||||
using (var env = new TlcEnvironment())
|
||||
{
|
||||
var dataView = ComponentCreation.CreateDataView(env, data);
|
||||
var enumeratorSimple = dataView.AsEnumerable<ClassWithPrivateFieldsAndProperties>(env, false).GetEnumerator();
|
||||
var originalEnumerator = data.GetEnumerator();
|
||||
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
|
||||
{
|
||||
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
|
||||
Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100);
|
||||
}
|
||||
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
|
||||
}
|
||||
}
|
||||
|
||||
public class ClassWithInheritedProperties : BaseClassWithInheritedProperties
|
||||
{
|
||||
private int _fInt;
|
||||
private long _fLong;
|
||||
private byte _fByte2;
|
||||
public int IntProp { get { return _fInt; } set { _fInt = value; } }
|
||||
public override long LongProp { get { return _fLong; } set { _fLong = value; } }
|
||||
public override byte ByteProp { get { return _fByte2; } set { _fByte2 = value; } }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassWithInheritedPropertiesConversion()
|
||||
{
|
||||
var data = new List<ClassWithInheritedProperties>()
|
||||
{
|
||||
new ClassWithInheritedProperties(){ IntProp=1, StringProp ="lala", LongProp=17, ByteProp=3 },
|
||||
new ClassWithInheritedProperties(){ IntProp=-1, StringProp ="", LongProp=2, ByteProp=4 },
|
||||
new ClassWithInheritedProperties(){ IntProp=0, StringProp =null, LongProp=18, ByteProp=5 }
|
||||
};
|
||||
|
||||
using (var env = new TlcEnvironment())
|
||||
{
|
||||
var dataView = ComponentCreation.CreateDataView(env, data);
|
||||
var enumeratorSimple = dataView.AsEnumerable<ClassWithInheritedProperties>(env, false).GetEnumerator();
|
||||
var originalEnumerator = data.GetEnumerator();
|
||||
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
|
||||
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
|
||||
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
|
||||
}
|
||||
}
|
||||
|
||||
public class ClassWithArrays
|
||||
{
|
||||
public string[] fString;
|
||||
|
@ -609,5 +854,129 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext());
|
||||
}
|
||||
}
|
||||
public class ClassWithArrayProperties
|
||||
{
|
||||
private string[] _fString;
|
||||
private int[] _fInt;
|
||||
private uint[] _fuInt;
|
||||
private short[] _fShort;
|
||||
private ushort[] _fuShort;
|
||||
private sbyte[] _fsByte;
|
||||
private byte[] _fByte;
|
||||
private long[] _fLong;
|
||||
private ulong[] _fuLong;
|
||||
private float[] _fFloat;
|
||||
private double[] _fDouble;
|
||||
private bool[] _fBool;
|
||||
public string[] StringProp { get { return _fString; } set { _fString = value; } }
|
||||
public int[] IntProp { get { return _fInt; } set { _fInt = value; } }
|
||||
public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } }
|
||||
public short[] ShortProp { get { return _fShort; } set { _fShort = value; } }
|
||||
public ushort[] UShortProp { get { return _fuShort; } set { _fuShort = value; } }
|
||||
public sbyte[] SByteProp { get { return _fsByte; } set { _fsByte = value; } }
|
||||
public byte[] ByteProp { get { return _fByte; } set { _fByte = value; } }
|
||||
public long[] LongProp { get { return _fLong; } set { _fLong = value; } }
|
||||
public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } }
|
||||
public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } }
|
||||
public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } }
|
||||
public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } }
|
||||
}
|
||||
|
||||
public class ClassWithNullableArrayProperties
|
||||
{
|
||||
private string[] _fString;
|
||||
private int?[] _fInt;
|
||||
private uint?[] _fuInt;
|
||||
private short?[] _fShort;
|
||||
private ushort?[] _fuShort;
|
||||
private sbyte?[] _fsByte;
|
||||
private byte?[] _fByte;
|
||||
private long?[] _fLong;
|
||||
private ulong?[] _fuLong;
|
||||
private float?[] _fFloat;
|
||||
private double?[] _fDouble;
|
||||
private bool?[] _fBool;
|
||||
|
||||
public string[] StringProp { get { return _fString; } set { _fString = value; } }
|
||||
public int?[] IntProp { get { return _fInt; } set { _fInt = value; } }
|
||||
public uint?[] UIntProp { get { return _fuInt; } set { _fuInt = value; } }
|
||||
public short?[] ShortProp { get { return _fShort; } set { _fShort = value; } }
|
||||
public ushort?[] UShortProp { get { return _fuShort; } set { _fuShort = value; } }
|
||||
public sbyte?[] SByteProp { get { return _fsByte; } set { _fsByte = value; } }
|
||||
public byte?[] ByteProp { get { return _fByte; } set { _fByte = value; } }
|
||||
public long?[] LongProp { get { return _fLong; } set { _fLong = value; } }
|
||||
public ulong?[] ULongProp { get { return _fuLong; } set { _fuLong = value; } }
|
||||
public float?[] SingleProp { get { return _fFloat; } set { _fFloat = value; } }
|
||||
public double?[] DoubleProp { get { return _fDouble; } set { _fDouble = value; } }
|
||||
public bool?[] BoolProp { get { return _fBool; } set { _fBool = value; } }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RoundTripConversionWithArrayPropertiess()
|
||||
{
|
||||
|
||||
var data = new List<ClassWithArrayProperties>
|
||||
{
|
||||
new ClassWithArrayProperties()
|
||||
{
|
||||
IntProp = new int[3] { 0, 1, 2 },
|
||||
FloatProp = new float[3] { -0.99f, 0f, 0.99f },
|
||||
StringProp = new string[2] { "hola", "lola" },
|
||||
BoolProp = new bool[2] { true, false },
|
||||
ByteProp = new byte[3] { 0, 124, 255 },
|
||||
DobuleProp = new double[3] { -1, 0, 1 },
|
||||
LongProp = new long[] { 0, 1, 2 },
|
||||
SByteProp = new sbyte[3] { -127, 127, 0 },
|
||||
ShortProp = new short[3] { 0, 1225, 32767 },
|
||||
UIntProp = new uint[2] { 0, uint.MaxValue },
|
||||
ULongProp = new ulong[2] { ulong.MaxValue, 0 },
|
||||
UShortProp = new ushort[2] { 0, ushort.MaxValue }
|
||||
},
|
||||
new ClassWithArrayProperties() { IntProp = new int[3] { -2, 1, 0 }, FloatProp = new float[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "", null } },
|
||||
new ClassWithArrayProperties()
|
||||
};
|
||||
|
||||
var nullableData = new List<ClassWithNullableArrayProperties>
|
||||
{
|
||||
new ClassWithNullableArrayProperties()
|
||||
{
|
||||
IntProp = new int?[3] { null, -1, 1 },
|
||||
SingleProp = new float?[3] { -0.99f, null, 0.99f },
|
||||
StringProp = new string[2] { null, "" },
|
||||
BoolProp = new bool?[3] { true, null, false },
|
||||
ByteProp = new byte?[4] { 0, 125, null, 255 },
|
||||
DoubleProp = new double?[3] { -1, null, 1 },
|
||||
LongProp = new long?[] { null, -1, 1 },
|
||||
SByteProp = new sbyte?[3] { -127, 127, null },
|
||||
ShortProp = new short?[3] { 0, null, 32767 },
|
||||
UIntProp = new uint?[4] { null, 42, 0, uint.MaxValue },
|
||||
ULongProp = new ulong?[3] { ulong.MaxValue, null, 0 },
|
||||
UShortProp = new ushort?[3] { 0, null, ushort.MaxValue }
|
||||
},
|
||||
new ClassWithNullableArrayProperties() { IntProp = new int?[3] { -2, 1, 0 }, SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } },
|
||||
new ClassWithNullableArrayProperties()
|
||||
};
|
||||
|
||||
using (var env = new TlcEnvironment())
|
||||
{
|
||||
var dataView = ComponentCreation.CreateDataView(env, data);
|
||||
var enumeratorSimple = dataView.AsEnumerable<ClassWithArrayProperties>(env, false).GetEnumerator();
|
||||
var originalEnumerator = data.GetEnumerator();
|
||||
while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
|
||||
{
|
||||
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
|
||||
}
|
||||
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
|
||||
|
||||
var nullableDataView = ComponentCreation.CreateDataView(env, nullableData);
|
||||
var enumeratorNullable = nullableDataView.AsEnumerable<ClassWithNullableArrayProperties>(env, false).GetEnumerator();
|
||||
var originalNullalbleEnumerator = nullableData.GetEnumerator();
|
||||
while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext())
|
||||
{
|
||||
Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current));
|
||||
}
|
||||
Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -228,7 +228,7 @@ namespace Microsoft.ML.EntryPoints.Tests
|
|||
public void ThrowsExceptionWithPropertyName()
|
||||
{
|
||||
Exception ex = Assert.Throws<InvalidOperationException>( () => new Data.TextLoader("fakefile.txt").CreateFrom<ModelWithoutColumnAttribute>() );
|
||||
Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message);
|
||||
Assert.StartsWith("Field or property String1 is missing ColumnAttribute", ex.Message);
|
||||
}
|
||||
|
||||
public class QuoteInput
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
<ItemGroup>
|
||||
<Project Include="$(MSBuildThisFileDirectory)**\*.csproj" />
|
||||
<Project Include="$(MSBuildThisFileDirectory)**\*.fsproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<Target Name="RunTests">
|
||||
|
|
Загрузка…
Ссылка в новой задаче