File-scoped namespaces in files under `Data` (`Microsoft.ML.Core`) (#6789)
Co-authored-by: Lehonti Ramos <john@doe>
This commit is contained in:
Родитель
34389b63e5
Коммит
aaf226c7e7
|
@ -4,29 +4,28 @@
|
|||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
{
|
||||
[BestFriend]
|
||||
internal static class AnnotationBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Add slot names annotation.
|
||||
/// </summary>
|
||||
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the slot names.</param>
|
||||
/// <param name="size">The size of the slot names vector.</param>
|
||||
/// <param name="getter">The getter delegate for the slot names.</param>
|
||||
public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter)
|
||||
=> builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter);
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Add key values annotation.
|
||||
/// </summary>
|
||||
/// <typeparam name="TValue">The value type of key values.</typeparam>
|
||||
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the key values.</param>
|
||||
/// <param name="size">The size of key values vector.</param>
|
||||
/// <param name="valueType">The value type of key values. Its raw type must match <typeparamref name="TValue"/>.</param>
|
||||
/// <param name="getter">The getter delegate for the key values.</param>
|
||||
public static void AddKeyValues<TValue>(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter<VBuffer<TValue>> getter)
|
||||
=> builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter);
|
||||
}
|
||||
[BestFriend]
|
||||
internal static class AnnotationBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Add slot names annotation.
|
||||
/// </summary>
|
||||
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the slot names.</param>
|
||||
/// <param name="size">The size of the slot names vector.</param>
|
||||
/// <param name="getter">The getter delegate for the slot names.</param>
|
||||
public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter)
|
||||
=> builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter);
|
||||
|
||||
/// <summary>
|
||||
/// Add key values annotation.
|
||||
/// </summary>
|
||||
/// <typeparam name="TValue">The value type of key values.</typeparam>
|
||||
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the key values.</param>
|
||||
/// <param name="size">The size of key values vector.</param>
|
||||
/// <param name="valueType">The value type of key values. Its raw type must match <typeparamref name="TValue"/>.</param>
|
||||
/// <param name="getter">The getter delegate for the key values.</param>
|
||||
public static void AddKeyValues<TValue>(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter<VBuffer<TValue>> getter)
|
||||
=> builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter);
|
||||
}
|
||||
|
|
|
@ -9,493 +9,492 @@ using System.Threading;
|
|||
using Microsoft.ML.Internal.Utilities;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Utilities for implementing and using the annotation API of <see cref="DataViewSchema"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class AnnotationUtils
|
||||
{
|
||||
/// <summary>
|
||||
/// Utilities for implementing and using the annotation API of <see cref="DataViewSchema"/>.
|
||||
/// This class lists the canonical annotation kinds
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class AnnotationUtils
|
||||
public static class Kinds
|
||||
{
|
||||
/// <summary>
|
||||
/// This class lists the canonical annotation kinds
|
||||
/// Annotation kind for names associated with slots/positions in a vector-valued column.
|
||||
/// The associated annotation type is typically fixed-sized vector of Text.
|
||||
/// </summary>
|
||||
public static class Kinds
|
||||
{
|
||||
/// <summary>
|
||||
/// Annotation kind for names associated with slots/positions in a vector-valued column.
|
||||
/// The associated annotation type is typically fixed-sized vector of Text.
|
||||
/// </summary>
|
||||
public const string SlotNames = "SlotNames";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind for values associated with the key indices when the column type's item type
|
||||
/// is a key type. The associated annotation type is typically fixed-sized vector of a primitive
|
||||
/// type. The primitive type is frequently Text, but can be anything.
|
||||
/// </summary>
|
||||
public const string KeyValues = "KeyValues";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind for sets of score columns. The value is typically a <see cref="KeyDataViewType"/> with raw type U4.
|
||||
/// </summary>
|
||||
public const string ScoreColumnSetId = "ScoreColumnSetId";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification".
|
||||
/// The value is typically a ReadOnlyMemory<char>.
|
||||
/// </summary>
|
||||
public const string ScoreColumnKind = "ScoreColumnKind";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory.
|
||||
/// </summary>
|
||||
public const string ScoreValueKind = "ScoreValueKind";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind that indicates if a column is normalized. The value is typically a Bool.
|
||||
/// </summary>
|
||||
public const string IsNormalized = "IsNormalized";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool.
|
||||
/// Not to be confused with IsHidden() that determines if a column is masked.
|
||||
/// </summary>
|
||||
public const string IsUserVisible = "IsUserVisible";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind for the label values used in training to be used for the predicted label.
|
||||
/// The value is typically a fixed-sized vector of Text.
|
||||
/// </summary>
|
||||
public const string TrainingLabelValues = "TrainingLabelValues";
|
||||
|
||||
/// <summary>
|
||||
/// Annotation kind that indicates the ranges within a column that are categorical features.
|
||||
/// The value is a vector type of ints with dimension of two. The first dimension
|
||||
/// represents the number of categorical features and second dimension represents the range
|
||||
/// and is of size two. The range has start and end index(both inclusive) of categorical
|
||||
/// slots within that column.
|
||||
/// </summary>
|
||||
public const string CategoricalSlotRanges = "CategoricalSlotRanges";
|
||||
}
|
||||
public const string SlotNames = "SlotNames";
|
||||
|
||||
/// <summary>
|
||||
/// This class holds all pre-defined string values that can be found in canonical annotations
|
||||
/// Annotation kind for values associated with the key indices when the column type's item type
|
||||
/// is a key type. The associated annotation type is typically fixed-sized vector of a primitive
|
||||
/// type. The primitive type is frequently Text, but can be anything.
|
||||
/// </summary>
|
||||
public static class Const
|
||||
{
|
||||
public static class ScoreColumnKind
|
||||
{
|
||||
public const string BinaryClassification = "BinaryClassification";
|
||||
public const string MulticlassClassification = "MulticlassClassification";
|
||||
public const string Regression = "Regression";
|
||||
public const string Ranking = "Ranking";
|
||||
public const string Clustering = "Clustering";
|
||||
public const string MultiOutputRegression = "MultiOutputRegression";
|
||||
public const string AnomalyDetection = "AnomalyDetection";
|
||||
public const string SequenceClassification = "SequenceClassification";
|
||||
public const string QuantileRegression = "QuantileRegression";
|
||||
public const string Recommender = "Recommender";
|
||||
public const string ItemSimilarity = "ItemSimilarity";
|
||||
public const string FeatureContribution = "FeatureContribution";
|
||||
}
|
||||
|
||||
public static class ScoreValueKind
|
||||
{
|
||||
public const string Score = "Score";
|
||||
public const string PredictedLabel = "PredictedLabel";
|
||||
public const string Probability = "Probability";
|
||||
}
|
||||
}
|
||||
public const string KeyValues = "KeyValues";
|
||||
|
||||
/// <summary>
|
||||
/// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below.
|
||||
/// Annotation kind for sets of score columns. The value is typically a <see cref="KeyDataViewType"/> with raw type U4.
|
||||
/// </summary>
|
||||
public delegate void AnnotationGetter<TValue>(int col, ref TValue dst);
|
||||
public const string ScoreColumnSetId = "ScoreColumnSetId";
|
||||
|
||||
/// <summary>
|
||||
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
|
||||
/// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification".
|
||||
/// The value is typically a ReadOnlyMemory<char>.
|
||||
/// </summary>
|
||||
public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation");
|
||||
public const string ScoreColumnKind = "ScoreColumnKind";
|
||||
|
||||
/// <summary>
|
||||
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
|
||||
/// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory.
|
||||
/// </summary>
|
||||
public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation");
|
||||
public const string ScoreValueKind = "ScoreValueKind";
|
||||
|
||||
/// <summary>
|
||||
/// Helper to marshal a call to GetAnnotation{TValue} to a specific type.
|
||||
/// Annotation kind that indicates if a column is normalized. The value is typically a Bool.
|
||||
/// </summary>
|
||||
public static void Marshal<THave, TNeed>(this AnnotationGetter<THave> getter, int col, ref TNeed dst)
|
||||
{
|
||||
Contracts.CheckValue(getter, nameof(getter));
|
||||
|
||||
if (typeof(TNeed) != typeof(THave))
|
||||
throw ExceptGetAnnotation();
|
||||
var get = (AnnotationGetter<TNeed>)(Delegate)getter;
|
||||
get(col, ref dst);
|
||||
}
|
||||
public const string IsNormalized = "IsNormalized";
|
||||
|
||||
/// <summary>
|
||||
/// Returns a vector type with item type text and the given size. The size must be positive.
|
||||
/// This is a standard type for annotation consisting of multiple text values, eg SlotNames.
|
||||
/// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool.
|
||||
/// Not to be confused with IsHidden() that determines if a column is masked.
|
||||
/// </summary>
|
||||
public static VectorDataViewType GetNamesType(int size)
|
||||
{
|
||||
Contracts.CheckParam(size > 0, nameof(size), "must be known size");
|
||||
return new VectorDataViewType(TextDataViewType.Instance, size);
|
||||
}
|
||||
public const string IsUserVisible = "IsUserVisible";
|
||||
|
||||
/// <summary>
|
||||
/// Returns a vector type with item type int and the given size.
|
||||
/// The range count must be a positive integer.
|
||||
/// This is a standard type for annotation consisting of multiple int values that represent
|
||||
/// categorical slot ranges with in a column.
|
||||
/// Annotation kind for the label values used in training to be used for the predicted label.
|
||||
/// The value is typically a fixed-sized vector of Text.
|
||||
/// </summary>
|
||||
public static VectorDataViewType GetCategoricalType(int rangeCount)
|
||||
{
|
||||
Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
|
||||
return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2);
|
||||
}
|
||||
|
||||
private static volatile KeyDataViewType _scoreColumnSetIdType;
|
||||
public const string TrainingLabelValues = "TrainingLabelValues";
|
||||
|
||||
/// <summary>
|
||||
/// The type of the ScoreColumnSetId annotation.
|
||||
/// Annotation kind that indicates the ranges within a column that are categorical features.
|
||||
/// The value is a vector type of ints with dimension of two. The first dimension
|
||||
/// represents the number of categorical features and second dimension represents the range
|
||||
/// and is of size two. The range has start and end index(both inclusive) of categorical
|
||||
/// slots within that column.
|
||||
/// </summary>
|
||||
public static KeyDataViewType ScoreColumnSetIdType
|
||||
public const string CategoricalSlotRanges = "CategoricalSlotRanges";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This class holds all pre-defined string values that can be found in canonical annotations
|
||||
/// </summary>
|
||||
public static class Const
|
||||
{
|
||||
public static class ScoreColumnKind
|
||||
{
|
||||
get
|
||||
{
|
||||
return _scoreColumnSetIdType ??
|
||||
Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ??
|
||||
_scoreColumnSetIdType;
|
||||
}
|
||||
public const string BinaryClassification = "BinaryClassification";
|
||||
public const string MulticlassClassification = "MulticlassClassification";
|
||||
public const string Regression = "Regression";
|
||||
public const string Ranking = "Ranking";
|
||||
public const string Clustering = "Clustering";
|
||||
public const string MultiOutputRegression = "MultiOutputRegression";
|
||||
public const string AnomalyDetection = "AnomalyDetection";
|
||||
public const string SequenceClassification = "SequenceClassification";
|
||||
public const string QuantileRegression = "QuantileRegression";
|
||||
public const string Recommender = "Recommender";
|
||||
public const string ItemSimilarity = "ItemSimilarity";
|
||||
public const string FeatureContribution = "FeatureContribution";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col).
|
||||
/// </summary>
|
||||
public static KeyValuePair<string, DataViewType> GetSlotNamesPair(int size)
|
||||
public static class ScoreValueKind
|
||||
{
|
||||
return GetNamesType(size).GetPair(Kinds.SlotNames);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes
|
||||
/// that the values of the key type are Text.
|
||||
/// </summary>
|
||||
public static KeyValuePair<string, DataViewType> GetKeyNamesPair(int size)
|
||||
{
|
||||
return GetNamesType(size).GetPair(Kinds.KeyValues);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Given a type and annotation kind string, returns a key-value pair. This is useful when
|
||||
/// implementing GetAnnotationTypes(col).
|
||||
/// </summary>
|
||||
public static KeyValuePair<string, DataViewType> GetPair(this DataViewType type, string kind)
|
||||
{
|
||||
Contracts.CheckValue(type, nameof(type));
|
||||
return new KeyValuePair<string, DataViewType>(kind, type);
|
||||
}
|
||||
|
||||
// REVIEW: This should be in some general utility code.
|
||||
|
||||
/// <summary>
|
||||
/// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes.
|
||||
/// </summary>
|
||||
public static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] head)
|
||||
{
|
||||
return head.Concat(tail);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the max value for the specified annotation kind.
|
||||
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
|
||||
/// colMax will be set to the first column that has the max value for the specified annotation.
|
||||
/// If no column has the specified annotation, colMax is set to -1 and the method returns zero.
|
||||
/// The filter function is called for each column, passing in the schema and the column index, and returns
|
||||
/// true if the column should be considered, false if the column should be skipped.
|
||||
/// </summary>
|
||||
public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func<DataViewSchema, int, bool> filterFunc = null)
|
||||
{
|
||||
uint max = 0;
|
||||
colMax = -1;
|
||||
for (int col = 0; col < schema.Count; col++)
|
||||
{
|
||||
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
|
||||
if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint))
|
||||
continue;
|
||||
if (filterFunc != null && !filterFunc(schema, col))
|
||||
continue;
|
||||
uint value = 0;
|
||||
schema[col].Annotations.GetValue(annotationKind, ref value);
|
||||
if (max < value)
|
||||
{
|
||||
max = value;
|
||||
colMax = col;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the set of column ids which match the value of specified annotation kind.
|
||||
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
|
||||
/// </summary>
|
||||
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, uint value)
|
||||
{
|
||||
for (int col = 0; col < schema.Count; col++)
|
||||
{
|
||||
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
|
||||
if (columnType is KeyDataViewType && columnType.RawType == typeof(uint))
|
||||
{
|
||||
uint val = 0;
|
||||
schema[col].Annotations.GetValue(annotationKind, ref val);
|
||||
if (val == value)
|
||||
yield return col;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the set of column ids which match the value of specified annotation kind.
|
||||
/// The annotation type should be of type text.
|
||||
/// </summary>
|
||||
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, string value)
|
||||
{
|
||||
for (int col = 0; col < schema.Count; col++)
|
||||
{
|
||||
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
|
||||
if (columnType is TextDataViewType)
|
||||
{
|
||||
ReadOnlyMemory<char> val = default;
|
||||
schema[col].Annotations.GetValue(annotationKind, ref val);
|
||||
if (ReadOnlyMemoryUtils.EqualsStr(value, val))
|
||||
yield return col;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns <c>true</c> if the specified column:
|
||||
/// * has a SlotNames annotation
|
||||
/// * annotation type is VBuffer<ReadOnlyMemory<char>> of length <paramref name="vectorSize"/>.
|
||||
/// </summary>
|
||||
public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
|
||||
{
|
||||
if (vectorSize == 0)
|
||||
return false;
|
||||
|
||||
var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames);
|
||||
return
|
||||
metaColumn != null
|
||||
&& metaColumn.Value.Type is VectorDataViewType vectorType
|
||||
&& vectorType.Size == vectorSize
|
||||
&& vectorType.ItemType is TextDataViewType;
|
||||
}
|
||||
|
||||
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer<ReadOnlyMemory<char>> slotNames)
|
||||
{
|
||||
Contracts.CheckValueOrNull(schema);
|
||||
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
|
||||
|
||||
IReadOnlyList<DataViewSchema.Column> list = schema?.GetColumns(role);
|
||||
if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
|
||||
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
|
||||
else
|
||||
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
|
||||
}
|
||||
|
||||
public static bool NeedsSlotNames(this SchemaShape.Column col)
|
||||
{
|
||||
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
|
||||
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
|
||||
&& metaCol.ItemType is TextDataViewType;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> annotation indicated by
|
||||
/// the schema shape.
|
||||
/// </summary>
|
||||
/// <param name="column">The schema shape column to query</param>
|
||||
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> annotation
|
||||
/// of a scalar <see cref="BooleanDataViewType"/> type, which we assume, if set, should be <c>true</c>.</returns>
|
||||
public static bool IsNormalized(this SchemaShape.Column column)
|
||||
{
|
||||
Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly");
|
||||
return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol)
|
||||
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
|
||||
&& metaCol.ItemType == BooleanDataViewType.Instance;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether a column has the <see cref="Kinds.SlotNames"/> annotation indicated by
|
||||
/// the schema shape.
|
||||
/// </summary>
|
||||
/// <param name="col">The schema shape column to query</param>
|
||||
/// <returns>True if and only if the column is a definite sized vector type, has the
|
||||
/// <see cref="Kinds.SlotNames"/> annotation of definite sized vectors of text.</returns>
|
||||
public static bool HasSlotNames(this SchemaShape.Column col)
|
||||
{
|
||||
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
|
||||
return col.Kind == SchemaShape.Column.VectorKind.Vector
|
||||
&& col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol)
|
||||
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
|
||||
&& metaCol.ItemType == TextDataViewType.Instance;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to get the annotation kind of the specified type for a column.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">The raw type of the annotation, should match the PrimitiveType type</typeparam>
|
||||
/// <param name="schema">The schema</param>
|
||||
/// <param name="type">The type of the annotation</param>
|
||||
/// <param name="kind">The annotation kind</param>
|
||||
/// <param name="col">The column</param>
|
||||
/// <param name="value">The value to return, if successful</param>
|
||||
/// <returns>True if the annotation of the right type exists, false otherwise</returns>
|
||||
public static bool TryGetAnnotation<T>(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
|
||||
{
|
||||
Contracts.CheckValue(schema, nameof(schema));
|
||||
Contracts.CheckValue(type, nameof(type));
|
||||
|
||||
var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
|
||||
if (!type.Equals(annotationType))
|
||||
return false;
|
||||
schema[col].Annotations.GetValue(kind, ref value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The categoricalFeatures is a vector of the indices of categorical features slots.
|
||||
/// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
|
||||
/// So if its value is the range of numbers: 0,2,3,4,8,9
|
||||
/// look at it as [0,2],[3,4],[8,9].
|
||||
/// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
|
||||
/// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
|
||||
/// </summary>
|
||||
public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
|
||||
{
|
||||
Contracts.CheckValue(schema, nameof(schema));
|
||||
Contracts.Check(colIndex >= 0, nameof(colIndex));
|
||||
|
||||
bool isValid = false;
|
||||
categoricalFeatures = null;
|
||||
if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0))
|
||||
return isValid;
|
||||
|
||||
var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
|
||||
if (type?.RawType == typeof(VBuffer<int>))
|
||||
{
|
||||
VBuffer<int> catIndices = default(VBuffer<int>);
|
||||
schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
|
||||
VBufferUtils.Densify(ref catIndices);
|
||||
int columnSlotsCount = vecType.Size;
|
||||
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
|
||||
{
|
||||
int previousEndIndex = -1;
|
||||
isValid = true;
|
||||
var catIndicesValues = catIndices.GetValues();
|
||||
for (int i = 0; i < catIndicesValues.Length; i += 2)
|
||||
{
|
||||
if (catIndicesValues[i] > catIndicesValues[i + 1] ||
|
||||
catIndicesValues[i] <= previousEndIndex ||
|
||||
catIndicesValues[i] >= columnSlotsCount ||
|
||||
catIndicesValues[i + 1] >= columnSlotsCount)
|
||||
{
|
||||
isValid = false;
|
||||
break;
|
||||
}
|
||||
|
||||
previousEndIndex = catIndicesValues[i + 1];
|
||||
}
|
||||
if (isValid)
|
||||
categoricalFeatures = catIndicesValues.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
return isValid;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Produces sequence of columns that are generated by trainer estimators.
|
||||
/// </summary>
|
||||
/// <param name="isNormalized">whether we should also append 'IsNormalized' (typically for probability column)</param>
|
||||
public static IEnumerable<SchemaShape.Column> GetTrainerOutputAnnotation(bool isNormalized = false)
|
||||
{
|
||||
var cols = new List<SchemaShape.Column>();
|
||||
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
|
||||
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
|
||||
cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
|
||||
if (isNormalized)
|
||||
cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
|
||||
return cols;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Produces annotations for the score column generated by trainer estimators for multiclass classification.
|
||||
/// If input LabelColumn is not available it produces slotnames annotation by default.
|
||||
/// </summary>
|
||||
/// <param name="labelColumn">Label column.</param>
|
||||
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
|
||||
{
|
||||
var cols = new List<SchemaShape.Column>();
|
||||
if (labelColumn != null && labelColumn.Value.IsKey)
|
||||
{
|
||||
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
|
||||
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
|
||||
{
|
||||
if (metaCol.ItemType is TextDataViewType)
|
||||
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
|
||||
cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false));
|
||||
}
|
||||
}
|
||||
cols.AddRange(GetTrainerOutputAnnotation());
|
||||
return cols;
|
||||
}
|
||||
|
||||
private sealed class AnnotationRow : DataViewRow
|
||||
{
|
||||
private readonly DataViewSchema.Annotations _annotations;
|
||||
|
||||
public AnnotationRow(DataViewSchema.Annotations annotations)
|
||||
{
|
||||
Contracts.AssertValue(annotations);
|
||||
_annotations = annotations;
|
||||
}
|
||||
|
||||
public override DataViewSchema Schema => _annotations.Schema;
|
||||
public override long Position => 0;
|
||||
public override long Batch => 0;
|
||||
|
||||
/// <summary>
|
||||
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
|
||||
/// This throws if the column is not active in this row, or if the type
|
||||
/// <typeparamref name="TValue"/> differs from this column's type.
|
||||
/// </summary>
|
||||
/// <typeparam name="TValue"> is the column's content type.</typeparam>
|
||||
/// <param name="column"> is the output column whose getter should be returned.</param>
|
||||
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => _annotations.GetGetter<TValue>(column);
|
||||
|
||||
public override ValueGetter<DataViewRowId> GetIdGetter() => (ref DataViewRowId dst) => dst = default;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether the given column is active in this row.
|
||||
/// </summary>
|
||||
public override bool IsColumnActive(DataViewSchema.Column column) => true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Presents a <see cref="DataViewSchema.Annotations"/> as a an <see cref="DataViewRow"/>.
|
||||
/// </summary>
|
||||
/// <param name="annotations">The annotations to wrap.</param>
|
||||
/// <returns>A row that wraps an input annotations.</returns>
|
||||
[BestFriend]
|
||||
internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations)
|
||||
{
|
||||
Contracts.CheckValue(annotations, nameof(annotations));
|
||||
return new AnnotationRow(annotations);
|
||||
public const string Score = "Score";
|
||||
public const string PredictedLabel = "PredictedLabel";
|
||||
public const string Probability = "Probability";
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below.
|
||||
/// </summary>
|
||||
public delegate void AnnotationGetter<TValue>(int col, ref TValue dst);
|
||||
|
||||
/// <summary>
|
||||
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
|
||||
/// </summary>
|
||||
public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation");
|
||||
|
||||
/// <summary>
|
||||
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
|
||||
/// </summary>
|
||||
public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation");
|
||||
|
||||
/// <summary>
|
||||
/// Helper to marshal a call to GetAnnotation{TValue} to a specific type.
|
||||
/// </summary>
|
||||
public static void Marshal<THave, TNeed>(this AnnotationGetter<THave> getter, int col, ref TNeed dst)
|
||||
{
|
||||
Contracts.CheckValue(getter, nameof(getter));
|
||||
|
||||
if (typeof(TNeed) != typeof(THave))
|
||||
throw ExceptGetAnnotation();
|
||||
var get = (AnnotationGetter<TNeed>)(Delegate)getter;
|
||||
get(col, ref dst);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a vector type with item type text and the given size. The size must be positive.
|
||||
/// This is a standard type for annotation consisting of multiple text values, eg SlotNames.
|
||||
/// </summary>
|
||||
public static VectorDataViewType GetNamesType(int size)
|
||||
{
|
||||
Contracts.CheckParam(size > 0, nameof(size), "must be known size");
|
||||
return new VectorDataViewType(TextDataViewType.Instance, size);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a vector type with item type int and the given size.
|
||||
/// The range count must be a positive integer.
|
||||
/// This is a standard type for annotation consisting of multiple int values that represent
|
||||
/// categorical slot ranges with in a column.
|
||||
/// </summary>
|
||||
public static VectorDataViewType GetCategoricalType(int rangeCount)
|
||||
{
|
||||
Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
|
||||
return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2);
|
||||
}
|
||||
|
||||
private static volatile KeyDataViewType _scoreColumnSetIdType;
|
||||
|
||||
/// <summary>
|
||||
/// The type of the ScoreColumnSetId annotation.
|
||||
/// </summary>
|
||||
public static KeyDataViewType ScoreColumnSetIdType
|
||||
{
|
||||
get
|
||||
{
|
||||
return _scoreColumnSetIdType ??
|
||||
Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ??
|
||||
_scoreColumnSetIdType;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col).
|
||||
/// </summary>
|
||||
public static KeyValuePair<string, DataViewType> GetSlotNamesPair(int size)
|
||||
{
|
||||
return GetNamesType(size).GetPair(Kinds.SlotNames);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes
|
||||
/// that the values of the key type are Text.
|
||||
/// </summary>
|
||||
public static KeyValuePair<string, DataViewType> GetKeyNamesPair(int size)
|
||||
{
|
||||
return GetNamesType(size).GetPair(Kinds.KeyValues);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Given a type and annotation kind string, returns a key-value pair. This is useful when
|
||||
/// implementing GetAnnotationTypes(col).
|
||||
/// </summary>
|
||||
public static KeyValuePair<string, DataViewType> GetPair(this DataViewType type, string kind)
|
||||
{
|
||||
Contracts.CheckValue(type, nameof(type));
|
||||
return new KeyValuePair<string, DataViewType>(kind, type);
|
||||
}
|
||||
|
||||
// REVIEW: This should be in some general utility code.
|
||||
|
||||
/// <summary>
|
||||
/// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes.
|
||||
/// </summary>
|
||||
public static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] head)
|
||||
{
|
||||
return head.Concat(tail);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the max value for the specified annotation kind.
|
||||
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
|
||||
/// colMax will be set to the first column that has the max value for the specified annotation.
|
||||
/// If no column has the specified annotation, colMax is set to -1 and the method returns zero.
|
||||
/// The filter function is called for each column, passing in the schema and the column index, and returns
|
||||
/// true if the column should be considered, false if the column should be skipped.
|
||||
/// </summary>
|
||||
public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func<DataViewSchema, int, bool> filterFunc = null)
|
||||
{
|
||||
uint max = 0;
|
||||
colMax = -1;
|
||||
for (int col = 0; col < schema.Count; col++)
|
||||
{
|
||||
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
|
||||
if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint))
|
||||
continue;
|
||||
if (filterFunc != null && !filterFunc(schema, col))
|
||||
continue;
|
||||
uint value = 0;
|
||||
schema[col].Annotations.GetValue(annotationKind, ref value);
|
||||
if (max < value)
|
||||
{
|
||||
max = value;
|
||||
colMax = col;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the set of column ids which match the value of specified annotation kind.
|
||||
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
|
||||
/// </summary>
|
||||
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, uint value)
|
||||
{
|
||||
for (int col = 0; col < schema.Count; col++)
|
||||
{
|
||||
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
|
||||
if (columnType is KeyDataViewType && columnType.RawType == typeof(uint))
|
||||
{
|
||||
uint val = 0;
|
||||
schema[col].Annotations.GetValue(annotationKind, ref val);
|
||||
if (val == value)
|
||||
yield return col;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the set of column ids which match the value of specified annotation kind.
|
||||
/// The annotation type should be of type text.
|
||||
/// </summary>
|
||||
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, string value)
|
||||
{
|
||||
for (int col = 0; col < schema.Count; col++)
|
||||
{
|
||||
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
|
||||
if (columnType is TextDataViewType)
|
||||
{
|
||||
ReadOnlyMemory<char> val = default;
|
||||
schema[col].Annotations.GetValue(annotationKind, ref val);
|
||||
if (ReadOnlyMemoryUtils.EqualsStr(value, val))
|
||||
yield return col;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns <c>true</c> if the specified column:
|
||||
/// * has a SlotNames annotation
|
||||
/// * annotation type is VBuffer<ReadOnlyMemory<char>> of length <paramref name="vectorSize"/>.
|
||||
/// </summary>
|
||||
public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
|
||||
{
|
||||
if (vectorSize == 0)
|
||||
return false;
|
||||
|
||||
var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames);
|
||||
return
|
||||
metaColumn != null
|
||||
&& metaColumn.Value.Type is VectorDataViewType vectorType
|
||||
&& vectorType.Size == vectorSize
|
||||
&& vectorType.ItemType is TextDataViewType;
|
||||
}
|
||||
|
||||
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer<ReadOnlyMemory<char>> slotNames)
|
||||
{
|
||||
Contracts.CheckValueOrNull(schema);
|
||||
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
|
||||
|
||||
IReadOnlyList<DataViewSchema.Column> list = schema?.GetColumns(role);
|
||||
if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
|
||||
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
|
||||
else
|
||||
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
|
||||
}
|
||||
|
||||
public static bool NeedsSlotNames(this SchemaShape.Column col)
|
||||
{
|
||||
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
|
||||
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
|
||||
&& metaCol.ItemType is TextDataViewType;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> annotation indicated by
|
||||
/// the schema shape.
|
||||
/// </summary>
|
||||
/// <param name="column">The schema shape column to query</param>
|
||||
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> annotation
|
||||
/// of a scalar <see cref="BooleanDataViewType"/> type, which we assume, if set, should be <c>true</c>.</returns>
|
||||
public static bool IsNormalized(this SchemaShape.Column column)
|
||||
{
|
||||
Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly");
|
||||
return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol)
|
||||
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
|
||||
&& metaCol.ItemType == BooleanDataViewType.Instance;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether a column has the <see cref="Kinds.SlotNames"/> annotation indicated by
|
||||
/// the schema shape.
|
||||
/// </summary>
|
||||
/// <param name="col">The schema shape column to query</param>
|
||||
/// <returns>True if and only if the column is a definite sized vector type, has the
|
||||
/// <see cref="Kinds.SlotNames"/> annotation of definite sized vectors of text.</returns>
|
||||
public static bool HasSlotNames(this SchemaShape.Column col)
|
||||
{
|
||||
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
|
||||
return col.Kind == SchemaShape.Column.VectorKind.Vector
|
||||
&& col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol)
|
||||
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
|
||||
&& metaCol.ItemType == TextDataViewType.Instance;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to get the annotation kind of the specified type for a column.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">The raw type of the annotation, should match the PrimitiveType type</typeparam>
|
||||
/// <param name="schema">The schema</param>
|
||||
/// <param name="type">The type of the annotation</param>
|
||||
/// <param name="kind">The annotation kind</param>
|
||||
/// <param name="col">The column</param>
|
||||
/// <param name="value">The value to return, if successful</param>
|
||||
/// <returns>True if the annotation of the right type exists, false otherwise</returns>
|
||||
public static bool TryGetAnnotation<T>(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
|
||||
{
|
||||
Contracts.CheckValue(schema, nameof(schema));
|
||||
Contracts.CheckValue(type, nameof(type));
|
||||
|
||||
var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
|
||||
if (!type.Equals(annotationType))
|
||||
return false;
|
||||
schema[col].Annotations.GetValue(kind, ref value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The categoricalFeatures is a vector of the indices of categorical features slots.
|
||||
/// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
|
||||
/// So if its value is the range of numbers: 0,2,3,4,8,9
|
||||
/// look at it as [0,2],[3,4],[8,9].
|
||||
/// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
|
||||
/// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
|
||||
/// </summary>
|
||||
public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
|
||||
{
|
||||
Contracts.CheckValue(schema, nameof(schema));
|
||||
Contracts.Check(colIndex >= 0, nameof(colIndex));
|
||||
|
||||
bool isValid = false;
|
||||
categoricalFeatures = null;
|
||||
if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0))
|
||||
return isValid;
|
||||
|
||||
var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
|
||||
if (type?.RawType == typeof(VBuffer<int>))
|
||||
{
|
||||
VBuffer<int> catIndices = default(VBuffer<int>);
|
||||
schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
|
||||
VBufferUtils.Densify(ref catIndices);
|
||||
int columnSlotsCount = vecType.Size;
|
||||
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
|
||||
{
|
||||
int previousEndIndex = -1;
|
||||
isValid = true;
|
||||
var catIndicesValues = catIndices.GetValues();
|
||||
for (int i = 0; i < catIndicesValues.Length; i += 2)
|
||||
{
|
||||
if (catIndicesValues[i] > catIndicesValues[i + 1] ||
|
||||
catIndicesValues[i] <= previousEndIndex ||
|
||||
catIndicesValues[i] >= columnSlotsCount ||
|
||||
catIndicesValues[i + 1] >= columnSlotsCount)
|
||||
{
|
||||
isValid = false;
|
||||
break;
|
||||
}
|
||||
|
||||
previousEndIndex = catIndicesValues[i + 1];
|
||||
}
|
||||
if (isValid)
|
||||
categoricalFeatures = catIndicesValues.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
return isValid;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Produces sequence of columns that are generated by trainer estimators.
|
||||
/// </summary>
|
||||
/// <param name="isNormalized">whether we should also append 'IsNormalized' (typically for probability column)</param>
|
||||
public static IEnumerable<SchemaShape.Column> GetTrainerOutputAnnotation(bool isNormalized = false)
|
||||
{
|
||||
var cols = new List<SchemaShape.Column>();
|
||||
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
|
||||
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
|
||||
cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
|
||||
if (isNormalized)
|
||||
cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
|
||||
return cols;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Produces annotations for the score column generated by trainer estimators for multiclass classification.
|
||||
/// If input LabelColumn is not available it produces slotnames annotation by default.
|
||||
/// </summary>
|
||||
/// <param name="labelColumn">Label column.</param>
|
||||
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
|
||||
{
|
||||
var cols = new List<SchemaShape.Column>();
|
||||
if (labelColumn != null && labelColumn.Value.IsKey)
|
||||
{
|
||||
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
|
||||
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
|
||||
{
|
||||
if (metaCol.ItemType is TextDataViewType)
|
||||
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
|
||||
cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false));
|
||||
}
|
||||
}
|
||||
cols.AddRange(GetTrainerOutputAnnotation());
|
||||
return cols;
|
||||
}
|
||||
|
||||
private sealed class AnnotationRow : DataViewRow
|
||||
{
|
||||
private readonly DataViewSchema.Annotations _annotations;
|
||||
|
||||
public AnnotationRow(DataViewSchema.Annotations annotations)
|
||||
{
|
||||
Contracts.AssertValue(annotations);
|
||||
_annotations = annotations;
|
||||
}
|
||||
|
||||
public override DataViewSchema Schema => _annotations.Schema;
|
||||
public override long Position => 0;
|
||||
public override long Batch => 0;
|
||||
|
||||
/// <summary>
|
||||
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
|
||||
/// This throws if the column is not active in this row, or if the type
|
||||
/// <typeparamref name="TValue"/> differs from this column's type.
|
||||
/// </summary>
|
||||
/// <typeparam name="TValue"> is the column's content type.</typeparam>
|
||||
/// <param name="column"> is the output column whose getter should be returned.</param>
|
||||
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => _annotations.GetGetter<TValue>(column);
|
||||
|
||||
public override ValueGetter<DataViewRowId> GetIdGetter() => (ref DataViewRowId dst) => dst = default;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether the given column is active in this row.
|
||||
/// </summary>
|
||||
public override bool IsColumnActive(DataViewSchema.Column column) => true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Presents a <see cref="DataViewSchema.Annotations"/> as a an <see cref="DataViewRow"/>.
|
||||
/// </summary>
|
||||
/// <param name="annotations">The annotations to wrap.</param>
|
||||
/// <returns>A row that wraps an input annotations.</returns>
|
||||
[BestFriend]
|
||||
internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations)
|
||||
{
|
||||
Contracts.CheckValue(annotations, nameof(annotations));
|
||||
return new AnnotationRow(annotations);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,165 +5,164 @@
|
|||
using System;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods related to the ColumnType class.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class ColumnTypeExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Extension methods related to the ColumnType class.
|
||||
/// Whether this type is a standard scalar type completely determined by its <see cref="DataViewType.RawType"/>
|
||||
/// (not a <see cref="KeyDataViewType"/> or <see cref="StructuredDataViewType"/>, etc).
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class ColumnTypeExtensions
|
||||
public static bool IsStandardScalar(this DataViewType columnType) =>
|
||||
(columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) ||
|
||||
(columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) ||
|
||||
(columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType);
|
||||
|
||||
/// <summary>
|
||||
/// Zero return means it's not a key type.
|
||||
/// </summary>
|
||||
public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0;
|
||||
|
||||
/// <summary>
|
||||
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
|
||||
/// Zero return means it's not a key type.
|
||||
/// </summary>
|
||||
public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether this type is a standard scalar type completely determined by its <see cref="DataViewType.RawType"/>
|
||||
/// (not a <see cref="KeyDataViewType"/> or <see cref="StructuredDataViewType"/>, etc).
|
||||
/// </summary>
|
||||
public static bool IsStandardScalar(this DataViewType columnType) =>
|
||||
(columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) ||
|
||||
(columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) ||
|
||||
(columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType);
|
||||
ulong count = columnType.GetKeyCount();
|
||||
ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
|
||||
return (int)count;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Zero return means it's not a key type.
|
||||
/// </summary>
|
||||
public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0;
|
||||
/// <summary>
|
||||
/// For non-vector types, this returns the column type itself (i.e., return <paramref name="columnType"/>).
|
||||
/// For vector types, this returns the type of the items stored as values in vector.
|
||||
/// </summary>
|
||||
public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType;
|
||||
|
||||
/// <summary>
|
||||
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
|
||||
/// Zero return means it's not a key type.
|
||||
/// </summary>
|
||||
public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
|
||||
{
|
||||
ulong count = columnType.GetKeyCount();
|
||||
ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
|
||||
return (int)count;
|
||||
}
|
||||
/// <summary>
|
||||
/// Zero return means either it's not a vector or the size is unknown.
|
||||
/// </summary>
|
||||
public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0;
|
||||
|
||||
/// <summary>
|
||||
/// For non-vector types, this returns the column type itself (i.e., return <paramref name="columnType"/>).
|
||||
/// For vector types, this returns the type of the items stored as values in vector.
|
||||
/// </summary>
|
||||
public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType;
|
||||
/// <summary>
|
||||
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
|
||||
/// For known sized vectors, it returns size.
|
||||
/// </summary>
|
||||
public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1;
|
||||
|
||||
/// <summary>
|
||||
/// Zero return means either it's not a vector or the size is unknown.
|
||||
/// </summary>
|
||||
public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0;
|
||||
/// <summary>
|
||||
/// Whether this is a vector type with known size. Returns false for non-vector types.
|
||||
/// Equivalent to <c><see cref="GetVectorSize"/> > 0</c>.
|
||||
/// </summary>
|
||||
public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0;
|
||||
|
||||
/// <summary>
|
||||
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
|
||||
/// For known sized vectors, it returns size.
|
||||
/// </summary>
|
||||
public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1;
|
||||
/// <summary>
|
||||
/// Gets the equivalent <see cref="InternalDataKind"/> for the <paramref name="columnType"/>'s RawType.
|
||||
/// This can return default(<see cref="InternalDataKind"/>) if the RawType doesn't have a corresponding
|
||||
/// <see cref="InternalDataKind"/>.
|
||||
/// </summary>
|
||||
public static InternalDataKind GetRawKind(this DataViewType columnType)
|
||||
{
|
||||
columnType.RawType.TryGetDataKind(out InternalDataKind result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Whether this is a vector type with known size. Returns false for non-vector types.
|
||||
/// Equivalent to <c><see cref="GetVectorSize"/> > 0</c>.
|
||||
/// </summary>
|
||||
public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0;
|
||||
/// <summary>
|
||||
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
|
||||
/// returns true if current and other vector types have the same size and item type.
|
||||
/// </summary>
|
||||
public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
|
||||
{
|
||||
if (other == null)
|
||||
return false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the equivalent <see cref="InternalDataKind"/> for the <paramref name="columnType"/>'s RawType.
|
||||
/// This can return default(<see cref="InternalDataKind"/>) if the RawType doesn't have a corresponding
|
||||
/// <see cref="InternalDataKind"/>.
|
||||
/// </summary>
|
||||
public static InternalDataKind GetRawKind(this DataViewType columnType)
|
||||
{
|
||||
columnType.RawType.TryGetDataKind(out InternalDataKind result);
|
||||
return result;
|
||||
}
|
||||
if (columnType.Equals(other))
|
||||
return true;
|
||||
|
||||
/// <summary>
|
||||
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
|
||||
/// returns true if current and other vector types have the same size and item type.
|
||||
/// </summary>
|
||||
public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
|
||||
{
|
||||
if (other == null)
|
||||
return false;
|
||||
// For vector types, we don't care about the factoring of the dimensions.
|
||||
if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
|
||||
return false;
|
||||
if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
|
||||
return false;
|
||||
return vectorType.Size == otherVectorType.Size;
|
||||
}
|
||||
|
||||
if (columnType.Equals(other))
|
||||
return true;
|
||||
public static PrimitiveDataViewType PrimitiveTypeFromType(Type type)
|
||||
{
|
||||
if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
|
||||
return TextDataViewType.Instance;
|
||||
if (type == typeof(bool))
|
||||
return BooleanDataViewType.Instance;
|
||||
if (type == typeof(TimeSpan))
|
||||
return TimeSpanDataViewType.Instance;
|
||||
if (type == typeof(DateTime))
|
||||
return DateTimeDataViewType.Instance;
|
||||
if (type == typeof(DateTimeOffset))
|
||||
return DateTimeOffsetDataViewType.Instance;
|
||||
if (type == typeof(DataViewRowId))
|
||||
return RowIdDataViewType.Instance;
|
||||
return NumberTypeFromType(type);
|
||||
}
|
||||
|
||||
// For vector types, we don't care about the factoring of the dimensions.
|
||||
if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
|
||||
return false;
|
||||
if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
|
||||
return false;
|
||||
return vectorType.Size == otherVectorType.Size;
|
||||
}
|
||||
public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind)
|
||||
{
|
||||
if (kind == InternalDataKind.TX)
|
||||
return TextDataViewType.Instance;
|
||||
if (kind == InternalDataKind.BL)
|
||||
return BooleanDataViewType.Instance;
|
||||
if (kind == InternalDataKind.TS)
|
||||
return TimeSpanDataViewType.Instance;
|
||||
if (kind == InternalDataKind.DT)
|
||||
return DateTimeDataViewType.Instance;
|
||||
if (kind == InternalDataKind.DZ)
|
||||
return DateTimeOffsetDataViewType.Instance;
|
||||
if (kind == InternalDataKind.UG)
|
||||
return RowIdDataViewType.Instance;
|
||||
return NumberTypeFromKind(kind);
|
||||
}
|
||||
|
||||
public static PrimitiveDataViewType PrimitiveTypeFromType(Type type)
|
||||
{
|
||||
if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
|
||||
return TextDataViewType.Instance;
|
||||
if (type == typeof(bool))
|
||||
return BooleanDataViewType.Instance;
|
||||
if (type == typeof(TimeSpan))
|
||||
return TimeSpanDataViewType.Instance;
|
||||
if (type == typeof(DateTime))
|
||||
return DateTimeDataViewType.Instance;
|
||||
if (type == typeof(DateTimeOffset))
|
||||
return DateTimeOffsetDataViewType.Instance;
|
||||
if (type == typeof(DataViewRowId))
|
||||
return RowIdDataViewType.Instance;
|
||||
return NumberTypeFromType(type);
|
||||
}
|
||||
|
||||
public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind)
|
||||
{
|
||||
if (kind == InternalDataKind.TX)
|
||||
return TextDataViewType.Instance;
|
||||
if (kind == InternalDataKind.BL)
|
||||
return BooleanDataViewType.Instance;
|
||||
if (kind == InternalDataKind.TS)
|
||||
return TimeSpanDataViewType.Instance;
|
||||
if (kind == InternalDataKind.DT)
|
||||
return DateTimeDataViewType.Instance;
|
||||
if (kind == InternalDataKind.DZ)
|
||||
return DateTimeOffsetDataViewType.Instance;
|
||||
if (kind == InternalDataKind.UG)
|
||||
return RowIdDataViewType.Instance;
|
||||
public static NumberDataViewType NumberTypeFromType(Type type)
|
||||
{
|
||||
InternalDataKind kind;
|
||||
if (type.TryGetDataKind(out kind))
|
||||
return NumberTypeFromKind(kind);
|
||||
}
|
||||
|
||||
public static NumberDataViewType NumberTypeFromType(Type type)
|
||||
Contracts.Assert(false);
|
||||
throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}");
|
||||
}
|
||||
|
||||
private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
InternalDataKind kind;
|
||||
if (type.TryGetDataKind(out kind))
|
||||
return NumberTypeFromKind(kind);
|
||||
|
||||
Contracts.Assert(false);
|
||||
throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}");
|
||||
case InternalDataKind.I1:
|
||||
return NumberDataViewType.SByte;
|
||||
case InternalDataKind.U1:
|
||||
return NumberDataViewType.Byte;
|
||||
case InternalDataKind.I2:
|
||||
return NumberDataViewType.Int16;
|
||||
case InternalDataKind.U2:
|
||||
return NumberDataViewType.UInt16;
|
||||
case InternalDataKind.I4:
|
||||
return NumberDataViewType.Int32;
|
||||
case InternalDataKind.U4:
|
||||
return NumberDataViewType.UInt32;
|
||||
case InternalDataKind.I8:
|
||||
return NumberDataViewType.Int64;
|
||||
case InternalDataKind.U8:
|
||||
return NumberDataViewType.UInt64;
|
||||
case InternalDataKind.R4:
|
||||
return NumberDataViewType.Single;
|
||||
case InternalDataKind.R8:
|
||||
return NumberDataViewType.Double;
|
||||
}
|
||||
|
||||
private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case InternalDataKind.I1:
|
||||
return NumberDataViewType.SByte;
|
||||
case InternalDataKind.U1:
|
||||
return NumberDataViewType.Byte;
|
||||
case InternalDataKind.I2:
|
||||
return NumberDataViewType.Int16;
|
||||
case InternalDataKind.U2:
|
||||
return NumberDataViewType.UInt16;
|
||||
case InternalDataKind.I4:
|
||||
return NumberDataViewType.Int32;
|
||||
case InternalDataKind.U4:
|
||||
return NumberDataViewType.UInt32;
|
||||
case InternalDataKind.I8:
|
||||
return NumberDataViewType.Int64;
|
||||
case InternalDataKind.U8:
|
||||
return NumberDataViewType.UInt64;
|
||||
case InternalDataKind.R4:
|
||||
return NumberDataViewType.Single;
|
||||
case InternalDataKind.R8:
|
||||
return NumberDataViewType.Double;
|
||||
}
|
||||
|
||||
Contracts.Assert(false);
|
||||
throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}");
|
||||
}
|
||||
Contracts.Assert(false);
|
||||
throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,380 +4,379 @@
|
|||
|
||||
using System;
|
||||
using Microsoft.ML.Runtime;
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Specifies a simple data type.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <format type="text/markdown"><![CDATA[
|
||||
/// Some transforms use the default value and/or missing value of the data types.
|
||||
/// The table below shows the default value definition for each of the data types.
|
||||
///
|
||||
/// | Type | Default Value | IsDefault Indicator |
|
||||
/// | -- | -- | -- |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory<char>` | <xref:System.ReadOnlyMemory`1.IsEmpty*> |
|
||||
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` |
|
||||
/// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value |
|
||||
///
|
||||
/// The table below shows the missing value definition for each of the data types.
|
||||
///
|
||||
/// | Type | Missing Value | IsMissing Indicator |
|
||||
/// | -- | -- | -- |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` |
|
||||
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.Single> | <xref:System.Single.NaN> | <xref:System.Single.IsNaN(System.Single)> |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.Double> | <xref:System.Double.NaN> | <xref:System.Double.IsNaN(System.Double)> |
|
||||
/// | All other types | Not defined | Always `false` |
|
||||
///
|
||||
/// ]]>
|
||||
/// </format>
|
||||
/// </remarks>
|
||||
// Data type specifiers mainly used in creating text loader and type converter.
|
||||
public enum DataKind : byte
|
||||
{
|
||||
/// <summary>1-byte integer, type of <see cref="System.SByte"/>.</summary>
|
||||
SByte = 1,
|
||||
/// <summary>1-byte unsigned integer, type of <see cref="System.Byte"/>.</summary>
|
||||
Byte = 2,
|
||||
/// <summary>2-byte integer, type of <see cref="System.Int16"/>.</summary>
|
||||
Int16 = 3,
|
||||
/// <summary>2-byte unsigned integer, type of <see cref="System.UInt16"/>.</summary>
|
||||
UInt16 = 4,
|
||||
/// <summary>4-byte integer, type of <see cref="System.Int32"/>.</summary>
|
||||
Int32 = 5,
|
||||
/// <summary>4-byte unsigned integer, type of <see cref="System.UInt32"/>.</summary>
|
||||
UInt32 = 6,
|
||||
/// <summary>8-byte integer, type of <see cref="System.Int64"/>.</summary>
|
||||
Int64 = 7,
|
||||
/// <summary>8-byte unsigned integer, type of <see cref="System.UInt64"/>.</summary>
|
||||
UInt64 = 8,
|
||||
/// <summary>4-byte floating-point number, type of <see cref="System.Single"/>.</summary>
|
||||
Single = 9,
|
||||
/// <summary>8-byte floating-point number, type of <see cref="System.Double"/>.</summary>
|
||||
Double = 10,
|
||||
/// <summary>
|
||||
/// Specifies a simple data type.
|
||||
/// string, type of <see cref="System.ReadOnlyMemory{T}"/>, where T is <see cref="char"/>.
|
||||
/// Also compatible with <see cref="System.String"/>.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <format type="text/markdown"><![CDATA[
|
||||
/// Some transforms use the default value and/or missing value of the data types.
|
||||
/// The table below shows the default value definition for each of the data types.
|
||||
///
|
||||
/// | Type | Default Value | IsDefault Indicator |
|
||||
/// | -- | -- | -- |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory<char>` | <xref:System.ReadOnlyMemory`1.IsEmpty*> |
|
||||
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` |
|
||||
/// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value |
|
||||
///
|
||||
/// The table below shows the missing value definition for each of the data types.
|
||||
///
|
||||
/// | Type | Missing Value | IsMissing Indicator |
|
||||
/// | -- | -- | -- |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` |
|
||||
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.Single> | <xref:System.Single.NaN> | <xref:System.Single.IsNaN(System.Single)> |
|
||||
/// | <xref:Microsoft.ML.Data.DataKind.Double> | <xref:System.Double.NaN> | <xref:System.Double.IsNaN(System.Double)> |
|
||||
/// | All other types | Not defined | Always `false` |
|
||||
///
|
||||
/// ]]>
|
||||
/// </format>
|
||||
/// </remarks>
|
||||
// Data type specifiers mainly used in creating text loader and type converter.
|
||||
public enum DataKind : byte
|
||||
{
|
||||
/// <summary>1-byte integer, type of <see cref="System.SByte"/>.</summary>
|
||||
SByte = 1,
|
||||
/// <summary>1-byte unsigned integer, type of <see cref="System.Byte"/>.</summary>
|
||||
Byte = 2,
|
||||
/// <summary>2-byte integer, type of <see cref="System.Int16"/>.</summary>
|
||||
Int16 = 3,
|
||||
/// <summary>2-byte unsigned integer, type of <see cref="System.UInt16"/>.</summary>
|
||||
UInt16 = 4,
|
||||
/// <summary>4-byte integer, type of <see cref="System.Int32"/>.</summary>
|
||||
Int32 = 5,
|
||||
/// <summary>4-byte unsigned integer, type of <see cref="System.UInt32"/>.</summary>
|
||||
UInt32 = 6,
|
||||
/// <summary>8-byte integer, type of <see cref="System.Int64"/>.</summary>
|
||||
Int64 = 7,
|
||||
/// <summary>8-byte unsigned integer, type of <see cref="System.UInt64"/>.</summary>
|
||||
UInt64 = 8,
|
||||
/// <summary>4-byte floating-point number, type of <see cref="System.Single"/>.</summary>
|
||||
Single = 9,
|
||||
/// <summary>8-byte floating-point number, type of <see cref="System.Double"/>.</summary>
|
||||
Double = 10,
|
||||
/// <summary>
|
||||
/// string, type of <see cref="System.ReadOnlyMemory{T}"/>, where T is <see cref="char"/>.
|
||||
/// Also compatible with <see cref="System.String"/>.
|
||||
/// </summary>
|
||||
String = 11,
|
||||
/// <summary>boolean variable type, type of <see cref="System.Boolean"/>.</summary>
|
||||
Boolean = 12,
|
||||
/// <summary>type of <see cref="System.TimeSpan"/>.</summary>
|
||||
TimeSpan = 13,
|
||||
/// <summary>type of <see cref="System.DateTime"/>.</summary>
|
||||
DateTime = 14,
|
||||
/// <summary>type of <see cref="System.DateTimeOffset"/>.</summary>
|
||||
DateTimeOffset = 15,
|
||||
}
|
||||
String = 11,
|
||||
/// <summary>boolean variable type, type of <see cref="System.Boolean"/>.</summary>
|
||||
Boolean = 12,
|
||||
/// <summary>type of <see cref="System.TimeSpan"/>.</summary>
|
||||
TimeSpan = 13,
|
||||
/// <summary>type of <see cref="System.DateTime"/>.</summary>
|
||||
DateTime = 14,
|
||||
/// <summary>type of <see cref="System.DateTimeOffset"/>.</summary>
|
||||
DateTimeOffset = 15,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Data type specifier used in command line. <see cref="InternalDataKind"/> is the underlying version of <see cref="DataKind"/>
|
||||
/// used for command line and entry point BC.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal enum InternalDataKind : byte
|
||||
{
|
||||
// Notes:
|
||||
// * These values are serialized, so changing them breaks binary formats.
|
||||
// * We intentionally skip zero.
|
||||
// * Some code depends on sizeof(DataKind) == sizeof(byte).
|
||||
/// <summary>
|
||||
/// Data type specifier used in command line. <see cref="InternalDataKind"/> is the underlying version of <see cref="DataKind"/>
|
||||
/// used for command line and entry point BC.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal enum InternalDataKind : byte
|
||||
{
|
||||
// Notes:
|
||||
// * These values are serialized, so changing them breaks binary formats.
|
||||
// * We intentionally skip zero.
|
||||
// * Some code depends on sizeof(DataKind) == sizeof(byte).
|
||||
|
||||
I1 = DataKind.SByte,
|
||||
U1 = DataKind.Byte,
|
||||
I2 = DataKind.Int16,
|
||||
U2 = DataKind.UInt16,
|
||||
I4 = DataKind.Int32,
|
||||
U4 = DataKind.UInt32,
|
||||
I8 = DataKind.Int64,
|
||||
U8 = DataKind.UInt64,
|
||||
R4 = DataKind.Single,
|
||||
R8 = DataKind.Double,
|
||||
Num = R4,
|
||||
I1 = DataKind.SByte,
|
||||
U1 = DataKind.Byte,
|
||||
I2 = DataKind.Int16,
|
||||
U2 = DataKind.UInt16,
|
||||
I4 = DataKind.Int32,
|
||||
U4 = DataKind.UInt32,
|
||||
I8 = DataKind.Int64,
|
||||
U8 = DataKind.UInt64,
|
||||
R4 = DataKind.Single,
|
||||
R8 = DataKind.Double,
|
||||
Num = R4,
|
||||
|
||||
TX = DataKind.String,
|
||||
TX = DataKind.String,
|
||||
#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independent of C# naming conventions.
|
||||
TXT = TX,
|
||||
Text = TX,
|
||||
TXT = TX,
|
||||
Text = TX,
|
||||
|
||||
BL = DataKind.Boolean,
|
||||
Bool = BL,
|
||||
BL = DataKind.Boolean,
|
||||
Bool = BL,
|
||||
|
||||
TS = DataKind.TimeSpan,
|
||||
TimeSpan = TS,
|
||||
DT = DataKind.DateTime,
|
||||
DateTime = DT,
|
||||
DZ = DataKind.DateTimeOffset,
|
||||
DateTimeZone = DZ,
|
||||
TS = DataKind.TimeSpan,
|
||||
TimeSpan = TS,
|
||||
DT = DataKind.DateTime,
|
||||
DateTime = DT,
|
||||
DZ = DataKind.DateTimeOffset,
|
||||
DateTimeZone = DZ,
|
||||
|
||||
UG = 16, // Unsigned 16-byte integer.
|
||||
U16 = UG,
|
||||
UG = 16, // Unsigned 16-byte integer.
|
||||
U16 = UG,
|
||||
#pragma warning restore MSML_GeneralName
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods related to the DataKind enum.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class InternalDataKindExtensions
|
||||
{
|
||||
public const InternalDataKind KindMin = InternalDataKind.I1;
|
||||
public const InternalDataKind KindLim = InternalDataKind.U16 + 1;
|
||||
public const int KindCount = KindLim - KindMin;
|
||||
|
||||
/// <summary>
|
||||
/// Maps a DataKind to a value suitable for indexing into an array of size KindCount.
|
||||
/// </summary>
|
||||
public static int ToIndex(this InternalDataKind kind)
|
||||
{
|
||||
return kind - KindMin;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods related to the DataKind enum.
|
||||
/// Maps from an index into an array of size KindCount to the corresponding DataKind
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class InternalDataKindExtensions
|
||||
public static InternalDataKind FromIndex(int index)
|
||||
{
|
||||
public const InternalDataKind KindMin = InternalDataKind.I1;
|
||||
public const InternalDataKind KindLim = InternalDataKind.U16 + 1;
|
||||
public const int KindCount = KindLim - KindMin;
|
||||
Contracts.Check(0 <= index && index < KindCount);
|
||||
return (InternalDataKind)(index + (int)KindMin);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Maps a DataKind to a value suitable for indexing into an array of size KindCount.
|
||||
/// </summary>
|
||||
public static int ToIndex(this InternalDataKind kind)
|
||||
/// <summary>
|
||||
/// This function converts <paramref name="dataKind"/> to <see cref="InternalDataKind"/>.
|
||||
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, the conversion is straightforward.
|
||||
/// </summary>
|
||||
public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind;
|
||||
|
||||
/// <summary>
|
||||
/// This function converts <paramref name="kind"/> to <see cref="DataKind"/>.
|
||||
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, we should check if <paramref name="kind"/>
|
||||
/// can be found in <see cref="DataKind"/>.
|
||||
/// </summary>
|
||||
public static DataKind ToDataKind(this InternalDataKind kind)
|
||||
{
|
||||
Contracts.Check(kind != InternalDataKind.UG);
|
||||
return (DataKind)kind;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// For integer DataKinds, this returns the maximum legal value. For un-supported kinds,
|
||||
/// it returns zero.
|
||||
/// </summary>
|
||||
public static ulong ToMaxInt(this InternalDataKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
return kind - KindMin;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Maps from an index into an array of size KindCount to the corresponding DataKind
|
||||
/// </summary>
|
||||
public static InternalDataKind FromIndex(int index)
|
||||
{
|
||||
Contracts.Check(0 <= index && index < KindCount);
|
||||
return (InternalDataKind)(index + (int)KindMin);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This function converts <paramref name="dataKind"/> to <see cref="InternalDataKind"/>.
|
||||
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, the conversion is straightforward.
|
||||
/// </summary>
|
||||
public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind;
|
||||
|
||||
/// <summary>
|
||||
/// This function converts <paramref name="kind"/> to <see cref="DataKind"/>.
|
||||
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, we should check if <paramref name="kind"/>
|
||||
/// can be found in <see cref="DataKind"/>.
|
||||
/// </summary>
|
||||
public static DataKind ToDataKind(this InternalDataKind kind)
|
||||
{
|
||||
Contracts.Check(kind != InternalDataKind.UG);
|
||||
return (DataKind)kind;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// For integer DataKinds, this returns the maximum legal value. For un-supported kinds,
|
||||
/// it returns zero.
|
||||
/// </summary>
|
||||
public static ulong ToMaxInt(this InternalDataKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case InternalDataKind.I1:
|
||||
return (ulong)sbyte.MaxValue;
|
||||
case InternalDataKind.U1:
|
||||
return byte.MaxValue;
|
||||
case InternalDataKind.I2:
|
||||
return (ulong)short.MaxValue;
|
||||
case InternalDataKind.U2:
|
||||
return ushort.MaxValue;
|
||||
case InternalDataKind.I4:
|
||||
return int.MaxValue;
|
||||
case InternalDataKind.U4:
|
||||
return uint.MaxValue;
|
||||
case InternalDataKind.I8:
|
||||
return long.MaxValue;
|
||||
case InternalDataKind.U8:
|
||||
return ulong.MaxValue;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// For integer Types, this returns the maximum legal value. For un-supported Types,
|
||||
/// it returns zero.
|
||||
/// </summary>
|
||||
public static ulong ToMaxInt(this Type type)
|
||||
{
|
||||
if (type == typeof(sbyte))
|
||||
case InternalDataKind.I1:
|
||||
return (ulong)sbyte.MaxValue;
|
||||
else if (type == typeof(byte))
|
||||
case InternalDataKind.U1:
|
||||
return byte.MaxValue;
|
||||
else if (type == typeof(short))
|
||||
case InternalDataKind.I2:
|
||||
return (ulong)short.MaxValue;
|
||||
else if (type == typeof(ushort))
|
||||
case InternalDataKind.U2:
|
||||
return ushort.MaxValue;
|
||||
else if (type == typeof(int))
|
||||
case InternalDataKind.I4:
|
||||
return int.MaxValue;
|
||||
else if (type == typeof(uint))
|
||||
case InternalDataKind.U4:
|
||||
return uint.MaxValue;
|
||||
else if (type == typeof(long))
|
||||
case InternalDataKind.I8:
|
||||
return long.MaxValue;
|
||||
else if (type == typeof(ulong))
|
||||
case InternalDataKind.U8:
|
||||
return ulong.MaxValue;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
|
||||
/// it returns one.
|
||||
/// </summary>
|
||||
public static long ToMinInt(this InternalDataKind kind)
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// For integer Types, this returns the maximum legal value. For un-supported Types,
|
||||
/// it returns zero.
|
||||
/// </summary>
|
||||
public static ulong ToMaxInt(this Type type)
|
||||
{
|
||||
if (type == typeof(sbyte))
|
||||
return (ulong)sbyte.MaxValue;
|
||||
else if (type == typeof(byte))
|
||||
return byte.MaxValue;
|
||||
else if (type == typeof(short))
|
||||
return (ulong)short.MaxValue;
|
||||
else if (type == typeof(ushort))
|
||||
return ushort.MaxValue;
|
||||
else if (type == typeof(int))
|
||||
return int.MaxValue;
|
||||
else if (type == typeof(uint))
|
||||
return uint.MaxValue;
|
||||
else if (type == typeof(long))
|
||||
return long.MaxValue;
|
||||
else if (type == typeof(ulong))
|
||||
return ulong.MaxValue;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
|
||||
/// it returns one.
|
||||
/// </summary>
|
||||
public static long ToMinInt(this InternalDataKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case InternalDataKind.I1:
|
||||
return sbyte.MinValue;
|
||||
case InternalDataKind.U1:
|
||||
return byte.MinValue;
|
||||
case InternalDataKind.I2:
|
||||
return short.MinValue;
|
||||
case InternalDataKind.U2:
|
||||
return ushort.MinValue;
|
||||
case InternalDataKind.I4:
|
||||
return int.MinValue;
|
||||
case InternalDataKind.U4:
|
||||
return uint.MinValue;
|
||||
case InternalDataKind.I8:
|
||||
return long.MinValue;
|
||||
case InternalDataKind.U8:
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
case InternalDataKind.I1:
|
||||
return sbyte.MinValue;
|
||||
case InternalDataKind.U1:
|
||||
return byte.MinValue;
|
||||
case InternalDataKind.I2:
|
||||
return short.MinValue;
|
||||
case InternalDataKind.U2:
|
||||
return ushort.MinValue;
|
||||
case InternalDataKind.I4:
|
||||
return int.MinValue;
|
||||
case InternalDataKind.U4:
|
||||
return uint.MinValue;
|
||||
case InternalDataKind.I8:
|
||||
return long.MinValue;
|
||||
case InternalDataKind.U8:
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Maps a DataKind to the associated .Net representation type.
|
||||
/// </summary>
|
||||
public static Type ToType(this InternalDataKind kind)
|
||||
return 1;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Maps a DataKind to the associated .Net representation type.
|
||||
/// </summary>
|
||||
public static Type ToType(this InternalDataKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case InternalDataKind.I1:
|
||||
return typeof(sbyte);
|
||||
case InternalDataKind.U1:
|
||||
return typeof(byte);
|
||||
case InternalDataKind.I2:
|
||||
return typeof(short);
|
||||
case InternalDataKind.U2:
|
||||
return typeof(ushort);
|
||||
case InternalDataKind.I4:
|
||||
return typeof(int);
|
||||
case InternalDataKind.U4:
|
||||
return typeof(uint);
|
||||
case InternalDataKind.I8:
|
||||
return typeof(long);
|
||||
case InternalDataKind.U8:
|
||||
return typeof(ulong);
|
||||
case InternalDataKind.R4:
|
||||
return typeof(Single);
|
||||
case InternalDataKind.R8:
|
||||
return typeof(Double);
|
||||
case InternalDataKind.TX:
|
||||
return typeof(ReadOnlyMemory<char>);
|
||||
case InternalDataKind.BL:
|
||||
return typeof(bool);
|
||||
case InternalDataKind.TS:
|
||||
return typeof(TimeSpan);
|
||||
case InternalDataKind.DT:
|
||||
return typeof(DateTime);
|
||||
case InternalDataKind.DZ:
|
||||
return typeof(DateTimeOffset);
|
||||
case InternalDataKind.UG:
|
||||
return typeof(DataViewRowId);
|
||||
}
|
||||
|
||||
return null;
|
||||
case InternalDataKind.I1:
|
||||
return typeof(sbyte);
|
||||
case InternalDataKind.U1:
|
||||
return typeof(byte);
|
||||
case InternalDataKind.I2:
|
||||
return typeof(short);
|
||||
case InternalDataKind.U2:
|
||||
return typeof(ushort);
|
||||
case InternalDataKind.I4:
|
||||
return typeof(int);
|
||||
case InternalDataKind.U4:
|
||||
return typeof(uint);
|
||||
case InternalDataKind.I8:
|
||||
return typeof(long);
|
||||
case InternalDataKind.U8:
|
||||
return typeof(ulong);
|
||||
case InternalDataKind.R4:
|
||||
return typeof(Single);
|
||||
case InternalDataKind.R8:
|
||||
return typeof(Double);
|
||||
case InternalDataKind.TX:
|
||||
return typeof(ReadOnlyMemory<char>);
|
||||
case InternalDataKind.BL:
|
||||
return typeof(bool);
|
||||
case InternalDataKind.TS:
|
||||
return typeof(TimeSpan);
|
||||
case InternalDataKind.DT:
|
||||
return typeof(DateTime);
|
||||
case InternalDataKind.DZ:
|
||||
return typeof(DateTimeOffset);
|
||||
case InternalDataKind.UG:
|
||||
return typeof(DataViewRowId);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to map a System.Type to a corresponding DataKind value.
|
||||
/// </summary>
|
||||
public static bool TryGetDataKind(this Type type, out InternalDataKind kind)
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to map a System.Type to a corresponding DataKind value.
|
||||
/// </summary>
|
||||
public static bool TryGetDataKind(this Type type, out InternalDataKind kind)
|
||||
{
|
||||
Contracts.CheckValueOrNull(type);
|
||||
|
||||
// REVIEW: Make this more efficient. Should we have a global dictionary?
|
||||
if (type == typeof(sbyte))
|
||||
kind = InternalDataKind.I1;
|
||||
else if (type == typeof(byte))
|
||||
kind = InternalDataKind.U1;
|
||||
else if (type == typeof(short))
|
||||
kind = InternalDataKind.I2;
|
||||
else if (type == typeof(ushort))
|
||||
kind = InternalDataKind.U2;
|
||||
else if (type == typeof(int))
|
||||
kind = InternalDataKind.I4;
|
||||
else if (type == typeof(uint))
|
||||
kind = InternalDataKind.U4;
|
||||
else if (type == typeof(long))
|
||||
kind = InternalDataKind.I8;
|
||||
else if (type == typeof(ulong))
|
||||
kind = InternalDataKind.U8;
|
||||
else if (type == typeof(Single))
|
||||
kind = InternalDataKind.R4;
|
||||
else if (type == typeof(Double))
|
||||
kind = InternalDataKind.R8;
|
||||
else if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
|
||||
kind = InternalDataKind.TX;
|
||||
else if (type == typeof(bool))
|
||||
kind = InternalDataKind.BL;
|
||||
else if (type == typeof(TimeSpan))
|
||||
kind = InternalDataKind.TS;
|
||||
else if (type == typeof(DateTime))
|
||||
kind = InternalDataKind.DT;
|
||||
else if (type == typeof(DateTimeOffset))
|
||||
kind = InternalDataKind.DZ;
|
||||
else if (type == typeof(DataViewRowId))
|
||||
kind = InternalDataKind.UG;
|
||||
else
|
||||
{
|
||||
Contracts.CheckValueOrNull(type);
|
||||
|
||||
// REVIEW: Make this more efficient. Should we have a global dictionary?
|
||||
if (type == typeof(sbyte))
|
||||
kind = InternalDataKind.I1;
|
||||
else if (type == typeof(byte))
|
||||
kind = InternalDataKind.U1;
|
||||
else if (type == typeof(short))
|
||||
kind = InternalDataKind.I2;
|
||||
else if (type == typeof(ushort))
|
||||
kind = InternalDataKind.U2;
|
||||
else if (type == typeof(int))
|
||||
kind = InternalDataKind.I4;
|
||||
else if (type == typeof(uint))
|
||||
kind = InternalDataKind.U4;
|
||||
else if (type == typeof(long))
|
||||
kind = InternalDataKind.I8;
|
||||
else if (type == typeof(ulong))
|
||||
kind = InternalDataKind.U8;
|
||||
else if (type == typeof(Single))
|
||||
kind = InternalDataKind.R4;
|
||||
else if (type == typeof(Double))
|
||||
kind = InternalDataKind.R8;
|
||||
else if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
|
||||
kind = InternalDataKind.TX;
|
||||
else if (type == typeof(bool))
|
||||
kind = InternalDataKind.BL;
|
||||
else if (type == typeof(TimeSpan))
|
||||
kind = InternalDataKind.TS;
|
||||
else if (type == typeof(DateTime))
|
||||
kind = InternalDataKind.DT;
|
||||
else if (type == typeof(DateTimeOffset))
|
||||
kind = InternalDataKind.DZ;
|
||||
else if (type == typeof(DataViewRowId))
|
||||
kind = InternalDataKind.UG;
|
||||
else
|
||||
{
|
||||
kind = default(InternalDataKind);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
kind = default(InternalDataKind);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable
|
||||
/// and is also slow, so use this instead.
|
||||
/// </summary>
|
||||
public static string GetString(this InternalDataKind kind)
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable
|
||||
/// and is also slow, so use this instead.
|
||||
/// </summary>
|
||||
public static string GetString(this InternalDataKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case InternalDataKind.I1:
|
||||
return "I1";
|
||||
case InternalDataKind.I2:
|
||||
return "I2";
|
||||
case InternalDataKind.I4:
|
||||
return "I4";
|
||||
case InternalDataKind.I8:
|
||||
return "I8";
|
||||
case InternalDataKind.U1:
|
||||
return "U1";
|
||||
case InternalDataKind.U2:
|
||||
return "U2";
|
||||
case InternalDataKind.U4:
|
||||
return "U4";
|
||||
case InternalDataKind.U8:
|
||||
return "U8";
|
||||
case InternalDataKind.R4:
|
||||
return "R4";
|
||||
case InternalDataKind.R8:
|
||||
return "R8";
|
||||
case InternalDataKind.BL:
|
||||
return "BL";
|
||||
case InternalDataKind.TX:
|
||||
return "TX";
|
||||
case InternalDataKind.TS:
|
||||
return "TS";
|
||||
case InternalDataKind.DT:
|
||||
return "DT";
|
||||
case InternalDataKind.DZ:
|
||||
return "DZ";
|
||||
case InternalDataKind.UG:
|
||||
return "UG";
|
||||
}
|
||||
return "";
|
||||
case InternalDataKind.I1:
|
||||
return "I1";
|
||||
case InternalDataKind.I2:
|
||||
return "I2";
|
||||
case InternalDataKind.I4:
|
||||
return "I4";
|
||||
case InternalDataKind.I8:
|
||||
return "I8";
|
||||
case InternalDataKind.U1:
|
||||
return "U1";
|
||||
case InternalDataKind.U2:
|
||||
return "U2";
|
||||
case InternalDataKind.U4:
|
||||
return "U4";
|
||||
case InternalDataKind.U8:
|
||||
return "U8";
|
||||
case InternalDataKind.R4:
|
||||
return "R4";
|
||||
case InternalDataKind.R8:
|
||||
return "R8";
|
||||
case InternalDataKind.BL:
|
||||
return "BL";
|
||||
case InternalDataKind.TX:
|
||||
return "TX";
|
||||
case InternalDataKind.TS:
|
||||
return "TS";
|
||||
case InternalDataKind.DT:
|
||||
return "DT";
|
||||
case InternalDataKind.DZ:
|
||||
return "DZ";
|
||||
case InternalDataKind.UG:
|
||||
return "UG";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,17 +2,16 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Command
|
||||
{
|
||||
/// <summary>
|
||||
/// The signature for commands.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void SignatureCommand();
|
||||
namespace Microsoft.ML.Command;
|
||||
|
||||
[BestFriend]
|
||||
internal interface ICommand
|
||||
{
|
||||
void Run();
|
||||
}
|
||||
/// <summary>
|
||||
/// The signature for commands.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void SignatureCommand();
|
||||
|
||||
[BestFriend]
|
||||
internal interface ICommand
|
||||
{
|
||||
void Run();
|
||||
}
|
||||
|
|
|
@ -8,315 +8,314 @@ using System.Linq;
|
|||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML
|
||||
namespace Microsoft.ML;
|
||||
|
||||
/// <summary>
|
||||
/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
|
||||
/// This is more relaxed than the proper <see cref="DataViewSchema"/>, since it's only a subset of the columns,
|
||||
/// and also since it doesn't specify exact <see cref="DataViewType"/>'s for vectors and keys.
|
||||
/// </summary>
|
||||
public sealed class SchemaShape : IReadOnlyList<SchemaShape.Column>
|
||||
{
|
||||
/// <summary>
|
||||
/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
|
||||
/// This is more relaxed than the proper <see cref="DataViewSchema"/>, since it's only a subset of the columns,
|
||||
/// and also since it doesn't specify exact <see cref="DataViewType"/>'s for vectors and keys.
|
||||
/// </summary>
|
||||
public sealed class SchemaShape : IReadOnlyList<SchemaShape.Column>
|
||||
private readonly Column[] _columns;
|
||||
|
||||
private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty<Column>());
|
||||
|
||||
public int Count => _columns.Count();
|
||||
|
||||
public Column this[int index] => _columns[index];
|
||||
|
||||
public struct Column
|
||||
{
|
||||
private readonly Column[] _columns;
|
||||
|
||||
private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty<Column>());
|
||||
|
||||
public int Count => _columns.Count();
|
||||
|
||||
public Column this[int index] => _columns[index];
|
||||
|
||||
public struct Column
|
||||
public enum VectorKind
|
||||
{
|
||||
public enum VectorKind
|
||||
{
|
||||
Scalar,
|
||||
Vector,
|
||||
VariableVector
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The column name.
|
||||
/// </summary>
|
||||
public readonly string Name;
|
||||
|
||||
/// <summary>
|
||||
/// The type of the column: scalar, fixed vector or variable vector.
|
||||
/// </summary>
|
||||
public readonly VectorKind Kind;
|
||||
|
||||
/// <summary>
|
||||
/// The 'raw' type of column item: must be a primitive type or a structured type.
|
||||
/// </summary>
|
||||
public readonly DataViewType ItemType;
|
||||
/// <summary>
|
||||
/// The flag whether the column is actually a key. If yes, <see cref="ItemType"/> is representing
|
||||
/// the underlying primitive type.
|
||||
/// </summary>
|
||||
public readonly bool IsKey;
|
||||
/// <summary>
|
||||
/// The annotations that are present for this column.
|
||||
/// </summary>
|
||||
public readonly SchemaShape Annotations;
|
||||
|
||||
[BestFriend]
|
||||
internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
|
||||
{
|
||||
Contracts.CheckNonEmpty(name, nameof(name));
|
||||
Contracts.CheckValueOrNull(annotations);
|
||||
Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
|
||||
Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
|
||||
Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
|
||||
|
||||
Name = name;
|
||||
Kind = vecKind;
|
||||
ItemType = itemType;
|
||||
IsKey = isKey;
|
||||
Annotations = annotations ?? _empty;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether <paramref name="source"/> is a valid input, if this object represents a
|
||||
/// requirement.
|
||||
///
|
||||
/// Namely, it returns true iff:
|
||||
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
|
||||
/// - The columns of <see cref="Annotations"/> of <paramref name="source"/> is a superset of our <see cref="Annotations"/> columns.
|
||||
/// - Each such annotation column is itself compatible with the input annotation column.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal bool IsCompatibleWith(Column source)
|
||||
{
|
||||
Contracts.Check(source.IsValid, nameof(source));
|
||||
if (Name != source.Name)
|
||||
return false;
|
||||
if (Kind != source.Kind)
|
||||
return false;
|
||||
if (!ItemType.Equals(source.ItemType))
|
||||
return false;
|
||||
if (IsKey != source.IsKey)
|
||||
return false;
|
||||
foreach (var annotationCol in Annotations)
|
||||
{
|
||||
if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol))
|
||||
return false;
|
||||
if (!annotationCol.IsCompatibleWith(inputAnnotationCol))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal string GetTypeString()
|
||||
{
|
||||
string result = ItemType.ToString();
|
||||
if (IsKey)
|
||||
result = $"Key<{result}>";
|
||||
if (Kind == VectorKind.Vector)
|
||||
result = $"Vector<{result}>";
|
||||
else if (Kind == VectorKind.VariableVector)
|
||||
result = $"VarVector<{result}>";
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return if this structure is not identical to the default value of <see cref="Column"/>. If true,
|
||||
/// it means this structure is initialized properly and therefore considered as valid.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal bool IsValid => Name != null;
|
||||
}
|
||||
|
||||
public SchemaShape(IEnumerable<Column> columns)
|
||||
{
|
||||
Contracts.CheckValue(columns, nameof(columns));
|
||||
_columns = columns.ToArray();
|
||||
Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
|
||||
Scalar,
|
||||
Vector,
|
||||
VariableVector
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Given a <paramref name="type"/>, extract the type parameters that describe this type
|
||||
/// as a <see cref="SchemaShape"/>'s column type.
|
||||
/// The column name.
|
||||
/// </summary>
|
||||
/// <param name="type">The actual column type to process.</param>
|
||||
/// <param name="vecKind">The vector kind of <paramref name="type"/>.</param>
|
||||
/// <param name="itemType">The item type of <paramref name="type"/>.</param>
|
||||
/// <param name="isKey">Whether <paramref name="type"/> (or its item type) is a key.</param>
|
||||
[BestFriend]
|
||||
internal static void GetColumnTypeShape(DataViewType type,
|
||||
out Column.VectorKind vecKind,
|
||||
out DataViewType itemType,
|
||||
out bool isKey)
|
||||
{
|
||||
if (type is VectorDataViewType vectorType)
|
||||
{
|
||||
if (vectorType.IsKnownSize)
|
||||
{
|
||||
vecKind = Column.VectorKind.Vector;
|
||||
}
|
||||
else
|
||||
{
|
||||
vecKind = Column.VectorKind.VariableVector;
|
||||
}
|
||||
public readonly string Name;
|
||||
|
||||
itemType = vectorType.ItemType;
|
||||
/// <summary>
|
||||
/// The type of the column: scalar, fixed vector or variable vector.
|
||||
/// </summary>
|
||||
public readonly VectorKind Kind;
|
||||
|
||||
/// <summary>
|
||||
/// The 'raw' type of column item: must be a primitive type or a structured type.
|
||||
/// </summary>
|
||||
public readonly DataViewType ItemType;
|
||||
/// <summary>
|
||||
/// The flag whether the column is actually a key. If yes, <see cref="ItemType"/> is representing
|
||||
/// the underlying primitive type.
|
||||
/// </summary>
|
||||
public readonly bool IsKey;
|
||||
/// <summary>
|
||||
/// The annotations that are present for this column.
|
||||
/// </summary>
|
||||
public readonly SchemaShape Annotations;
|
||||
|
||||
[BestFriend]
|
||||
internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
|
||||
{
|
||||
Contracts.CheckNonEmpty(name, nameof(name));
|
||||
Contracts.CheckValueOrNull(annotations);
|
||||
Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
|
||||
Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
|
||||
Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
|
||||
|
||||
Name = name;
|
||||
Kind = vecKind;
|
||||
ItemType = itemType;
|
||||
IsKey = isKey;
|
||||
Annotations = annotations ?? _empty;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether <paramref name="source"/> is a valid input, if this object represents a
|
||||
/// requirement.
|
||||
///
|
||||
/// Namely, it returns true iff:
|
||||
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
|
||||
/// - The columns of <see cref="Annotations"/> of <paramref name="source"/> is a superset of our <see cref="Annotations"/> columns.
|
||||
/// - Each such annotation column is itself compatible with the input annotation column.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal bool IsCompatibleWith(Column source)
|
||||
{
|
||||
Contracts.Check(source.IsValid, nameof(source));
|
||||
if (Name != source.Name)
|
||||
return false;
|
||||
if (Kind != source.Kind)
|
||||
return false;
|
||||
if (!ItemType.Equals(source.ItemType))
|
||||
return false;
|
||||
if (IsKey != source.IsKey)
|
||||
return false;
|
||||
foreach (var annotationCol in Annotations)
|
||||
{
|
||||
if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol))
|
||||
return false;
|
||||
if (!annotationCol.IsCompatibleWith(inputAnnotationCol))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal string GetTypeString()
|
||||
{
|
||||
string result = ItemType.ToString();
|
||||
if (IsKey)
|
||||
result = $"Key<{result}>";
|
||||
if (Kind == VectorKind.Vector)
|
||||
result = $"Vector<{result}>";
|
||||
else if (Kind == VectorKind.VariableVector)
|
||||
result = $"VarVector<{result}>";
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return if this structure is not identical to the default value of <see cref="Column"/>. If true,
|
||||
/// it means this structure is initialized properly and therefore considered as valid.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal bool IsValid => Name != null;
|
||||
}
|
||||
|
||||
public SchemaShape(IEnumerable<Column> columns)
|
||||
{
|
||||
Contracts.CheckValue(columns, nameof(columns));
|
||||
_columns = columns.ToArray();
|
||||
Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Given a <paramref name="type"/>, extract the type parameters that describe this type
|
||||
/// as a <see cref="SchemaShape"/>'s column type.
|
||||
/// </summary>
|
||||
/// <param name="type">The actual column type to process.</param>
|
||||
/// <param name="vecKind">The vector kind of <paramref name="type"/>.</param>
|
||||
/// <param name="itemType">The item type of <paramref name="type"/>.</param>
|
||||
/// <param name="isKey">Whether <paramref name="type"/> (or its item type) is a key.</param>
|
||||
[BestFriend]
|
||||
internal static void GetColumnTypeShape(DataViewType type,
|
||||
out Column.VectorKind vecKind,
|
||||
out DataViewType itemType,
|
||||
out bool isKey)
|
||||
{
|
||||
if (type is VectorDataViewType vectorType)
|
||||
{
|
||||
if (vectorType.IsKnownSize)
|
||||
{
|
||||
vecKind = Column.VectorKind.Vector;
|
||||
}
|
||||
else
|
||||
{
|
||||
vecKind = Column.VectorKind.Scalar;
|
||||
itemType = type;
|
||||
vecKind = Column.VectorKind.VariableVector;
|
||||
}
|
||||
|
||||
isKey = itemType is KeyDataViewType;
|
||||
if (isKey)
|
||||
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
|
||||
itemType = vectorType.ItemType;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a schema shape out of the fully defined schema.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static SchemaShape Create(DataViewSchema schema)
|
||||
else
|
||||
{
|
||||
Contracts.CheckValue(schema, nameof(schema));
|
||||
var cols = new List<Column>();
|
||||
|
||||
for (int iCol = 0; iCol < schema.Count; iCol++)
|
||||
{
|
||||
if (!schema[iCol].IsHidden)
|
||||
{
|
||||
// First create the annotations.
|
||||
var mCols = new List<Column>();
|
||||
foreach (var annotationColumn in schema[iCol].Annotations.Schema)
|
||||
{
|
||||
GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
|
||||
mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey));
|
||||
}
|
||||
var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
|
||||
// Next create the single column.
|
||||
GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
|
||||
cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations));
|
||||
}
|
||||
}
|
||||
return new SchemaShape(cols);
|
||||
vecKind = Column.VectorKind.Scalar;
|
||||
itemType = type;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns if there is a column with a specified <paramref name="name"/> and if so stores it in <paramref name="column"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal bool TryFindColumn(string name, out Column column)
|
||||
{
|
||||
Contracts.CheckValue(name, nameof(name));
|
||||
column = _columns.FirstOrDefault(x => x.Name == name);
|
||||
return column.IsValid;
|
||||
}
|
||||
|
||||
public IEnumerator<Column> GetEnumerator() => ((IEnumerable<Column>)_columns).GetEnumerator();
|
||||
|
||||
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
|
||||
|
||||
// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
|
||||
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
|
||||
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
|
||||
isKey = itemType is KeyDataViewType;
|
||||
if (isKey)
|
||||
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
|
||||
/// Create a schema shape out of the fully defined schema.
|
||||
/// </summary>
|
||||
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
|
||||
public interface IDataLoader<in TSource> : ICanSaveModel
|
||||
{
|
||||
/// <summary>
|
||||
/// Produce the data view from the specified input.
|
||||
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
|
||||
/// </summary>
|
||||
IDataView Load(TSource input);
|
||||
|
||||
/// <summary>
|
||||
/// The output schema of the loader.
|
||||
/// </summary>
|
||||
DataViewSchema GetOutputSchema();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sometimes we need to 'fit' an <see cref="IDataLoader{TIn}"/>.
|
||||
/// A DataLoader estimator is the object that does it.
|
||||
/// </summary>
|
||||
public interface IDataLoaderEstimator<in TSource, out TLoader>
|
||||
where TLoader : IDataLoader<TSource>
|
||||
{
|
||||
// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
|
||||
// yet, so why complicate matters?
|
||||
|
||||
/// <summary>
|
||||
/// Train and return a data loader.
|
||||
/// </summary>
|
||||
TLoader Fit(TSource input);
|
||||
|
||||
/// <summary>
|
||||
/// The 'promise' of the output schema.
|
||||
/// It will be used for schema propagation.
|
||||
/// </summary>
|
||||
SchemaShape GetOutputSchema();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The transformer is a component that transforms data.
|
||||
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
|
||||
/// </summary>
|
||||
public interface ITransformer : ICanSaveModel
|
||||
{
|
||||
/// <summary>
|
||||
/// Schema propagation for transformers.
|
||||
/// Returns the output schema of the data, if the input schema is like the one provided.
|
||||
/// </summary>
|
||||
DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
|
||||
|
||||
/// <summary>
|
||||
/// Take the data in, make transformations, output the data.
|
||||
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
|
||||
/// </summary>
|
||||
IDataView Transform(IDataView input);
|
||||
|
||||
/// <summary>
|
||||
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
|
||||
/// appropriate schema.
|
||||
/// </summary>
|
||||
bool IsRowToRowMapper { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
|
||||
/// is <c>false</c>, then an exception should be thrown. If the input schema is in any way
|
||||
/// unsuitable for constructing the mapper, an exception should likewise be thrown.
|
||||
/// </summary>
|
||||
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
|
||||
/// <returns>The row to row mapper.</returns>
|
||||
IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema);
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer
|
||||
internal static SchemaShape Create(DataViewSchema schema)
|
||||
{
|
||||
IDataView TransformForTrainingPipeline(IDataView input);
|
||||
Contracts.CheckValue(schema, nameof(schema));
|
||||
var cols = new List<Column>();
|
||||
|
||||
for (int iCol = 0; iCol < schema.Count; iCol++)
|
||||
{
|
||||
if (!schema[iCol].IsHidden)
|
||||
{
|
||||
// First create the annotations.
|
||||
var mCols = new List<Column>();
|
||||
foreach (var annotationColumn in schema[iCol].Annotations.Schema)
|
||||
{
|
||||
GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
|
||||
mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey));
|
||||
}
|
||||
var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
|
||||
// Next create the single column.
|
||||
GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
|
||||
cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations));
|
||||
}
|
||||
}
|
||||
return new SchemaShape(cols);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
|
||||
/// a transformer.
|
||||
/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="DataViewSchema"/>.
|
||||
/// Returns if there is a column with a specified <paramref name="name"/> and if so stores it in <paramref name="column"/>.
|
||||
/// </summary>
|
||||
public interface IEstimator<out TTransformer>
|
||||
where TTransformer : ITransformer
|
||||
[BestFriend]
|
||||
internal bool TryFindColumn(string name, out Column column)
|
||||
{
|
||||
/// <summary>
|
||||
/// Train and return a transformer.
|
||||
/// </summary>
|
||||
TTransformer Fit(IDataView input);
|
||||
|
||||
/// <summary>
|
||||
/// Schema propagation for estimators.
|
||||
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
|
||||
/// </summary>
|
||||
SchemaShape GetOutputSchema(SchemaShape inputSchema);
|
||||
Contracts.CheckValue(name, nameof(name));
|
||||
column = _columns.FirstOrDefault(x => x.Name == name);
|
||||
return column.IsValid;
|
||||
}
|
||||
|
||||
public IEnumerator<Column> GetEnumerator() => ((IEnumerable<Column>)_columns).GetEnumerator();
|
||||
|
||||
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
|
||||
|
||||
// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
|
||||
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
|
||||
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
|
||||
/// </summary>
|
||||
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
|
||||
public interface IDataLoader<in TSource> : ICanSaveModel
|
||||
{
|
||||
/// <summary>
|
||||
/// Produce the data view from the specified input.
|
||||
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
|
||||
/// </summary>
|
||||
IDataView Load(TSource input);
|
||||
|
||||
/// <summary>
|
||||
/// The output schema of the loader.
|
||||
/// </summary>
|
||||
DataViewSchema GetOutputSchema();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sometimes we need to 'fit' an <see cref="IDataLoader{TIn}"/>.
|
||||
/// A DataLoader estimator is the object that does it.
|
||||
/// </summary>
|
||||
public interface IDataLoaderEstimator<in TSource, out TLoader>
|
||||
where TLoader : IDataLoader<TSource>
|
||||
{
|
||||
// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
|
||||
// yet, so why complicate matters?
|
||||
|
||||
/// <summary>
|
||||
/// Train and return a data loader.
|
||||
/// </summary>
|
||||
TLoader Fit(TSource input);
|
||||
|
||||
/// <summary>
|
||||
/// The 'promise' of the output schema.
|
||||
/// It will be used for schema propagation.
|
||||
/// </summary>
|
||||
SchemaShape GetOutputSchema();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The transformer is a component that transforms data.
|
||||
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
|
||||
/// </summary>
|
||||
public interface ITransformer : ICanSaveModel
|
||||
{
|
||||
/// <summary>
|
||||
/// Schema propagation for transformers.
|
||||
/// Returns the output schema of the data, if the input schema is like the one provided.
|
||||
/// </summary>
|
||||
DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
|
||||
|
||||
/// <summary>
|
||||
/// Take the data in, make transformations, output the data.
|
||||
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
|
||||
/// </summary>
|
||||
IDataView Transform(IDataView input);
|
||||
|
||||
/// <summary>
|
||||
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
|
||||
/// appropriate schema.
|
||||
/// </summary>
|
||||
bool IsRowToRowMapper { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
|
||||
/// is <c>false</c>, then an exception should be thrown. If the input schema is in any way
|
||||
/// unsuitable for constructing the mapper, an exception should likewise be thrown.
|
||||
/// </summary>
|
||||
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
|
||||
/// <returns>The row to row mapper.</returns>
|
||||
IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema);
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer
|
||||
{
|
||||
IDataView TransformForTrainingPipeline(IDataView input);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
|
||||
/// a transformer.
|
||||
/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="DataViewSchema"/>.
|
||||
/// </summary>
|
||||
public interface IEstimator<out TTransformer>
|
||||
where TTransformer : ITransformer
|
||||
{
|
||||
/// <summary>
|
||||
/// Train and return a transformer.
|
||||
/// </summary>
|
||||
TTransformer Fit(IDataView input);
|
||||
|
||||
/// <summary>
|
||||
/// Schema propagation for estimators.
|
||||
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
|
||||
/// </summary>
|
||||
SchemaShape GetOutputSchema(SchemaShape inputSchema);
|
||||
}
|
||||
|
|
|
@ -8,191 +8,190 @@ using System.IO;
|
|||
using Microsoft.ML.Internal.Utilities;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// A file handle.
|
||||
/// </summary>
|
||||
public interface IFileHandle : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// A file handle.
|
||||
/// Returns whether CreateWriteStream is expected to succeed. Typically, once
|
||||
/// CreateWriteStream has been called once, this will forever more return false.
|
||||
/// </summary>
|
||||
public interface IFileHandle : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// Returns whether CreateWriteStream is expected to succeed. Typically, once
|
||||
/// CreateWriteStream has been called once, this will forever more return false.
|
||||
/// </summary>
|
||||
bool CanWrite { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether OpenReadStream is expected to succeed.
|
||||
/// </summary>
|
||||
bool CanRead { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Create a writable stream for this file handle.
|
||||
/// </summary>
|
||||
Stream CreateWriteStream();
|
||||
|
||||
/// <summary>
|
||||
/// Open a readable stream for this file handle.
|
||||
/// </summary>
|
||||
Stream OpenReadStream();
|
||||
}
|
||||
bool CanWrite { get; }
|
||||
|
||||
/// <summary>
|
||||
/// A simple disk-based file handle.
|
||||
/// Returns whether OpenReadStream is expected to succeed.
|
||||
/// </summary>
|
||||
public sealed class SimpleFileHandle : IFileHandle
|
||||
bool CanRead { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Create a writable stream for this file handle.
|
||||
/// </summary>
|
||||
Stream CreateWriteStream();
|
||||
|
||||
/// <summary>
|
||||
/// Open a readable stream for this file handle.
|
||||
/// </summary>
|
||||
Stream OpenReadStream();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A simple disk-based file handle.
|
||||
/// </summary>
|
||||
public sealed class SimpleFileHandle : IFileHandle
|
||||
{
|
||||
private readonly string _fullPath;
|
||||
|
||||
// Exception context.
|
||||
private readonly IExceptionContext _ectx;
|
||||
|
||||
private readonly object _lock;
|
||||
|
||||
// Whether to delete the file when this is disposed.
|
||||
private readonly bool _autoDelete;
|
||||
|
||||
// Whether this file has contents. This is false if the file needs CreateWriteStream to be
|
||||
// called (before OpenReadStream can be called).
|
||||
private bool _wrote;
|
||||
// If non-null, the active write stream. This should be disposed before the first OpenReadStream call.
|
||||
private Stream _streamWrite;
|
||||
|
||||
// This contains the potentially active read streams. This is set to null once this file
|
||||
// handle has been disposed.
|
||||
private List<Stream> _streams;
|
||||
|
||||
private bool IsDisposed => _streams == null;
|
||||
|
||||
public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
|
||||
{
|
||||
private readonly string _fullPath;
|
||||
Contracts.CheckValue(ectx, nameof(ectx));
|
||||
ectx.CheckNonEmpty(path, nameof(path));
|
||||
|
||||
// Exception context.
|
||||
private readonly IExceptionContext _ectx;
|
||||
_ectx = ectx;
|
||||
_fullPath = Path.GetFullPath(path);
|
||||
|
||||
private readonly object _lock;
|
||||
_autoDelete = autoDelete;
|
||||
|
||||
// Whether to delete the file when this is disposed.
|
||||
private readonly bool _autoDelete;
|
||||
// The file has already been written to iff needsWrite is false.
|
||||
_wrote = !needsWrite;
|
||||
|
||||
// Whether this file has contents. This is false if the file needs CreateWriteStream to be
|
||||
// called (before OpenReadStream can be called).
|
||||
private bool _wrote;
|
||||
// If non-null, the active write stream. This should be disposed before the first OpenReadStream call.
|
||||
private Stream _streamWrite;
|
||||
// REVIEW: Should this do some basic validation? Eg, for output files, ensure that
|
||||
// the directory exists (and perhaps even create an empty file); for input files, ensure
|
||||
// that the file exists (and perhaps even attempt to open it).
|
||||
|
||||
// This contains the potentially active read streams. This is set to null once this file
|
||||
// handle has been disposed.
|
||||
private List<Stream> _streams;
|
||||
_lock = new object();
|
||||
_streams = new List<Stream>();
|
||||
}
|
||||
|
||||
private bool IsDisposed => _streams == null;
|
||||
public bool CanWrite => !_wrote && !IsDisposed;
|
||||
|
||||
public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
|
||||
{
|
||||
Contracts.CheckValue(ectx, nameof(ectx));
|
||||
ectx.CheckNonEmpty(path, nameof(path));
|
||||
public bool CanRead => _wrote && !IsDisposed;
|
||||
|
||||
_ectx = ectx;
|
||||
_fullPath = Path.GetFullPath(path);
|
||||
|
||||
_autoDelete = autoDelete;
|
||||
|
||||
// The file has already been written to iff needsWrite is false.
|
||||
_wrote = !needsWrite;
|
||||
|
||||
// REVIEW: Should this do some basic validation? Eg, for output files, ensure that
|
||||
// the directory exists (and perhaps even create an empty file); for input files, ensure
|
||||
// that the file exists (and perhaps even attempt to open it).
|
||||
|
||||
_lock = new object();
|
||||
_streams = new List<Stream>();
|
||||
}
|
||||
|
||||
public bool CanWrite => !_wrote && !IsDisposed;
|
||||
|
||||
public bool CanRead => _wrote && !IsDisposed;
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
if (IsDisposed)
|
||||
return;
|
||||
|
||||
Contracts.Assert(_streams != null);
|
||||
|
||||
// REVIEW: Is it safe to dispose these streams? What if they are
|
||||
// being used on other threads? Does that matter?
|
||||
if (_streamWrite != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
_streamWrite.CloseEx();
|
||||
_streamWrite.Dispose();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// REVIEW: What should we do here?
|
||||
Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!");
|
||||
}
|
||||
_streamWrite = null;
|
||||
}
|
||||
|
||||
foreach (var stream in _streams)
|
||||
{
|
||||
try
|
||||
{
|
||||
stream.CloseEx();
|
||||
stream.Dispose();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// REVIEW: What should we do here?
|
||||
Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!");
|
||||
}
|
||||
}
|
||||
|
||||
_streams = null;
|
||||
Contracts.Assert(IsDisposed);
|
||||
|
||||
if (_autoDelete)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Finally, delete the file.
|
||||
File.Delete(_fullPath);
|
||||
}
|
||||
catch
|
||||
{
|
||||
// REVIEW: What should we do here?
|
||||
Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void CheckNotDisposed()
|
||||
public void Dispose()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
if (IsDisposed)
|
||||
throw _ectx.Except("SimpleFileHandle has already been disposed");
|
||||
}
|
||||
return;
|
||||
|
||||
public Stream CreateWriteStream()
|
||||
{
|
||||
lock (_lock)
|
||||
Contracts.Assert(_streams != null);
|
||||
|
||||
// REVIEW: Is it safe to dispose these streams? What if they are
|
||||
// being used on other threads? Does that matter?
|
||||
if (_streamWrite != null)
|
||||
{
|
||||
CheckNotDisposed();
|
||||
|
||||
if (_wrote)
|
||||
throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle");
|
||||
|
||||
Contracts.Assert(_streamWrite == null);
|
||||
_streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write);
|
||||
_wrote = true;
|
||||
return _streamWrite;
|
||||
}
|
||||
}
|
||||
|
||||
public Stream OpenReadStream()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
CheckNotDisposed();
|
||||
|
||||
if (!_wrote)
|
||||
throw _ectx.Except("SimpleFileHandle hasn't been written yet");
|
||||
|
||||
if (_streamWrite != null)
|
||||
try
|
||||
{
|
||||
if (_streamWrite.CanWrite)
|
||||
throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed");
|
||||
_streamWrite = null;
|
||||
_streamWrite.CloseEx();
|
||||
_streamWrite.Dispose();
|
||||
}
|
||||
|
||||
// Drop read streams that have already been disposed.
|
||||
_streams.RemoveAll(s => !s.CanRead);
|
||||
|
||||
var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read);
|
||||
_streams.Add(stream);
|
||||
return stream;
|
||||
catch
|
||||
{
|
||||
// REVIEW: What should we do here?
|
||||
Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!");
|
||||
}
|
||||
_streamWrite = null;
|
||||
}
|
||||
|
||||
foreach (var stream in _streams)
|
||||
{
|
||||
try
|
||||
{
|
||||
stream.CloseEx();
|
||||
stream.Dispose();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// REVIEW: What should we do here?
|
||||
Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!");
|
||||
}
|
||||
}
|
||||
|
||||
_streams = null;
|
||||
Contracts.Assert(IsDisposed);
|
||||
|
||||
if (_autoDelete)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Finally, delete the file.
|
||||
File.Delete(_fullPath);
|
||||
}
|
||||
catch
|
||||
{
|
||||
// REVIEW: What should we do here?
|
||||
Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void CheckNotDisposed()
|
||||
{
|
||||
if (IsDisposed)
|
||||
throw _ectx.Except("SimpleFileHandle has already been disposed");
|
||||
}
|
||||
|
||||
public Stream CreateWriteStream()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
CheckNotDisposed();
|
||||
|
||||
if (_wrote)
|
||||
throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle");
|
||||
|
||||
Contracts.Assert(_streamWrite == null);
|
||||
_streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write);
|
||||
_wrote = true;
|
||||
return _streamWrite;
|
||||
}
|
||||
}
|
||||
|
||||
public Stream OpenReadStream()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
CheckNotDisposed();
|
||||
|
||||
if (!_wrote)
|
||||
throw _ectx.Except("SimpleFileHandle hasn't been written yet");
|
||||
|
||||
if (_streamWrite != null)
|
||||
{
|
||||
if (_streamWrite.CanWrite)
|
||||
throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed");
|
||||
_streamWrite = null;
|
||||
}
|
||||
|
||||
// Drop read streams that have already been disposed.
|
||||
_streams.RemoveAll(s => !s.CanRead);
|
||||
|
||||
var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read);
|
||||
_streams.Add(stream);
|
||||
return stream;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,319 +5,318 @@
|
|||
using System;
|
||||
using Microsoft.ML.Data;
|
||||
|
||||
namespace Microsoft.ML.Runtime
|
||||
namespace Microsoft.ML.Runtime;
|
||||
|
||||
/// <summary>
|
||||
/// A channel provider can create new channels and generic information pipes.
|
||||
/// </summary>
|
||||
public interface IChannelProvider : IExceptionContext
|
||||
{
|
||||
/// <summary>
|
||||
/// A channel provider can create new channels and generic information pipes.
|
||||
/// Start a standard message channel.
|
||||
/// </summary>
|
||||
public interface IChannelProvider : IExceptionContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Start a standard message channel.
|
||||
/// </summary>
|
||||
IChannel Start(string name);
|
||||
IChannel Start(string name);
|
||||
|
||||
/// <summary>
|
||||
/// Start a generic information pipe.
|
||||
/// </summary>
|
||||
IPipe<TMessage> StartPipe<TMessage>(string name);
|
||||
/// <summary>
|
||||
/// Start a generic information pipe.
|
||||
/// </summary>
|
||||
IPipe<TMessage> StartPipe<TMessage>(string name);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Utility class for IHostEnvironment
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class HostEnvironmentExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Return a file handle for an input "file".
|
||||
/// </summary>
|
||||
public static IFileHandle OpenInputFile(this IHostEnvironment env, string path)
|
||||
{
|
||||
Contracts.AssertValue(env);
|
||||
Contracts.CheckNonWhiteSpace(path, nameof(path));
|
||||
return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Utility class for IHostEnvironment
|
||||
/// Create an output "file" and return a handle to it.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class HostEnvironmentExtensions
|
||||
public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path)
|
||||
{
|
||||
/// <summary>
|
||||
/// Return a file handle for an input "file".
|
||||
/// </summary>
|
||||
public static IFileHandle OpenInputFile(this IHostEnvironment env, string path)
|
||||
{
|
||||
Contracts.AssertValue(env);
|
||||
Contracts.CheckNonWhiteSpace(path, nameof(path));
|
||||
return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create an output "file" and return a handle to it.
|
||||
/// </summary>
|
||||
public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path)
|
||||
{
|
||||
Contracts.AssertValue(env);
|
||||
Contracts.CheckNonWhiteSpace(path, nameof(path));
|
||||
return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The host environment interface creates hosts for components. Note that the methods of
|
||||
/// this interface should be called from the main thread for the environment. To get an environment
|
||||
/// to service another thread, call Fork and pass the return result to that thread.
|
||||
/// </summary>
|
||||
public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Create a host with the given registration name.
|
||||
/// </summary>
|
||||
IHost Register(string name, int? seed = null, bool? verbose = null);
|
||||
|
||||
/// <summary>
|
||||
/// The catalog of loadable components (<see cref="LoadableClassAttribute"/>) that are available in this host.
|
||||
/// </summary>
|
||||
ComponentCatalog ComponentCatalog { get; }
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal interface ICancelable
|
||||
{
|
||||
/// <summary>
|
||||
/// Signal to stop execution in all the hosts.
|
||||
/// </summary>
|
||||
void CancelExecution();
|
||||
|
||||
/// <summary>
|
||||
/// Flag which indicates host execution has been stopped.
|
||||
/// </summary>
|
||||
bool IsCanceled { get; }
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal interface IHostEnvironmentInternal : IHostEnvironment
|
||||
{
|
||||
/// <summary>
|
||||
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
|
||||
/// </summary>
|
||||
int? Seed { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The location for the temp files created by ML.NET
|
||||
/// </summary>
|
||||
string TempFilePath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Allow falling back to run on CPU if couldn't run on GPU.
|
||||
/// </summary>
|
||||
bool FallbackToCpu { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// GPU device ID to run execution on, <see langword="null" /> to run on CPU.
|
||||
/// </summary>
|
||||
int? GpuDeviceId { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A host is coupled to a component and provides random number generation and concurrency guidance.
|
||||
/// Note that the random number generation, like the host environment methods, should be accessed only
|
||||
/// from the main thread for the component.
|
||||
/// </summary>
|
||||
public interface IHost : IHostEnvironment
|
||||
{
|
||||
/// <summary>
|
||||
/// The random number generator issued to this component. Note that random number
|
||||
/// generators are NOT thread safe.
|
||||
/// </summary>
|
||||
Random Rand { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A generic information pipe. Note that pipes are disposable. Generally, Done should
|
||||
/// be called before disposing to signal a normal shut-down of the pipe, as opposed
|
||||
/// to an aborted completion.
|
||||
/// </summary>
|
||||
public interface IPipe<TMessage> : IExceptionContext, IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// The caller relinquishes ownership of the <paramref name="msg"/> object.
|
||||
/// </summary>
|
||||
void Send(TMessage msg);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The kinds of standard channel messages.
|
||||
/// Note: These values should never be changed. We can add new kinds, but don't change these values.
|
||||
/// Other code bases, including native code for other projects depends on these values.
|
||||
/// </summary>
|
||||
public enum ChannelMessageKind
|
||||
{
|
||||
Trace = 0,
|
||||
Info = 1,
|
||||
Warning = 2,
|
||||
Error = 3
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A flag that can be attached to a message or exception to indicate that
|
||||
/// it has a certain class of sensitive data. By default, messages should be
|
||||
/// specified as being of unknown sensitivity, which is to say, every
|
||||
/// sensitivity flag is turned on, corresponding to <see cref="Unknown"/>.
|
||||
/// Messages that are totally safe should be marked as <see cref="None"/>.
|
||||
/// However, if, say, one prints out data from a file (for example, this might
|
||||
/// be done when expressing parse errors), it should be flagged in that case
|
||||
/// with <see cref="UserData"/>.
|
||||
/// </summary>
|
||||
[Flags]
|
||||
public enum MessageSensitivity
|
||||
{
|
||||
/// <summary>
|
||||
/// For non-sensitive data.
|
||||
/// </summary>
|
||||
None = 0,
|
||||
|
||||
/// <summary>
|
||||
/// For messages that may contain user-data from data files.
|
||||
/// </summary>
|
||||
UserData = 0x1,
|
||||
|
||||
/// <summary>
|
||||
/// For messages that contain information like column names from datasets.
|
||||
/// Note that, despite being part of the schema, annotations should be treated
|
||||
/// as user data, since it is often derived from user data. Note also that
|
||||
/// types, despite being part of the schema, are not considered "sensitive"
|
||||
/// as such, in the same way that column names might be.
|
||||
/// </summary>
|
||||
Schema = 0x2,
|
||||
|
||||
// REVIEW: Other potentially sensitive things might include
|
||||
// stack traces in certain environments.
|
||||
|
||||
/// <summary>
|
||||
/// The default value, unknown, is treated as if everything is sensitive.
|
||||
/// </summary>
|
||||
Unknown = ~None,
|
||||
|
||||
/// <summary>
|
||||
/// An alias for <see cref="Unknown"/>, so it is functionally the same, except
|
||||
/// semantically it communicates the idea that we want all bits set.
|
||||
/// </summary>
|
||||
All = Unknown,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A channel message.
|
||||
/// </summary>
|
||||
public readonly struct ChannelMessage
|
||||
{
|
||||
public readonly ChannelMessageKind Kind;
|
||||
public readonly MessageSensitivity Sensitivity;
|
||||
private readonly string _message;
|
||||
private readonly object[] _args;
|
||||
|
||||
/// <summary>
|
||||
/// Line endings may not be normalized.
|
||||
/// </summary>
|
||||
public string Message => _args != null ? string.Format(_message, _args) : _message;
|
||||
|
||||
[BestFriend]
|
||||
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
|
||||
{
|
||||
Contracts.CheckNonEmpty(message, nameof(message));
|
||||
Kind = kind;
|
||||
Sensitivity = sensitivity;
|
||||
_message = message;
|
||||
_args = null;
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
|
||||
{
|
||||
Contracts.CheckNonEmpty(fmt, nameof(fmt));
|
||||
Contracts.CheckNonEmpty(args, nameof(args));
|
||||
Kind = kind;
|
||||
Sensitivity = sensitivity;
|
||||
_message = fmt;
|
||||
_args = args;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A standard communication channel.
|
||||
/// </summary>
|
||||
public interface IChannel : IPipe<ChannelMessage>
|
||||
{
|
||||
void Trace(MessageSensitivity sensitivity, string fmt);
|
||||
void Trace(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
void Error(MessageSensitivity sensitivity, string fmt);
|
||||
void Error(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
void Warning(MessageSensitivity sensitivity, string fmt);
|
||||
void Warning(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
void Info(MessageSensitivity sensitivity, string fmt);
|
||||
void Info(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// General utility extension methods for objects in the "host" universe, i.e.,
|
||||
/// <see cref="IHostEnvironment"/>, <see cref="IHost"/>, and <see cref="IChannel"/>
|
||||
/// that do not belong in more specific areas, for example, <see cref="Contracts"/> or
|
||||
/// component creation.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class HostExtensions
|
||||
{
|
||||
public static T Apply<T>(this IHost host, string channelName, Func<IChannel, T> func)
|
||||
{
|
||||
T t;
|
||||
using (var ch = host.Start(channelName))
|
||||
{
|
||||
t = func(ch);
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Trace(this IChannel ch, string fmt)
|
||||
=> ch.Trace(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Trace(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Trace(MessageSensitivity.Unknown, fmt, args);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Error(this IChannel ch, string fmt)
|
||||
=> ch.Error(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Error(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Error(MessageSensitivity.Unknown, fmt, args);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Warning(this IChannel ch, string fmt)
|
||||
=> ch.Warning(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Warning(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Warning(MessageSensitivity.Unknown, fmt, args);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Info(this IChannel ch, string fmt)
|
||||
=> ch.Info(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Info(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Info(MessageSensitivity.Unknown, fmt, args);
|
||||
Contracts.AssertValue(env);
|
||||
Contracts.CheckNonWhiteSpace(path, nameof(path));
|
||||
return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The host environment interface creates hosts for components. Note that the methods of
|
||||
/// this interface should be called from the main thread for the environment. To get an environment
|
||||
/// to service another thread, call Fork and pass the return result to that thread.
|
||||
/// </summary>
|
||||
public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Create a host with the given registration name.
|
||||
/// </summary>
|
||||
IHost Register(string name, int? seed = null, bool? verbose = null);
|
||||
|
||||
/// <summary>
|
||||
/// The catalog of loadable components (<see cref="LoadableClassAttribute"/>) that are available in this host.
|
||||
/// </summary>
|
||||
ComponentCatalog ComponentCatalog { get; }
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal interface ICancelable
|
||||
{
|
||||
/// <summary>
|
||||
/// Signal to stop execution in all the hosts.
|
||||
/// </summary>
|
||||
void CancelExecution();
|
||||
|
||||
/// <summary>
|
||||
/// Flag which indicates host execution has been stopped.
|
||||
/// </summary>
|
||||
bool IsCanceled { get; }
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal interface IHostEnvironmentInternal : IHostEnvironment
|
||||
{
|
||||
/// <summary>
|
||||
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
|
||||
/// </summary>
|
||||
int? Seed { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The location for the temp files created by ML.NET
|
||||
/// </summary>
|
||||
string TempFilePath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Allow falling back to run on CPU if couldn't run on GPU.
|
||||
/// </summary>
|
||||
bool FallbackToCpu { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// GPU device ID to run execution on, <see langword="null" /> to run on CPU.
|
||||
/// </summary>
|
||||
int? GpuDeviceId { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A host is coupled to a component and provides random number generation and concurrency guidance.
|
||||
/// Note that the random number generation, like the host environment methods, should be accessed only
|
||||
/// from the main thread for the component.
|
||||
/// </summary>
|
||||
public interface IHost : IHostEnvironment
|
||||
{
|
||||
/// <summary>
|
||||
/// The random number generator issued to this component. Note that random number
|
||||
/// generators are NOT thread safe.
|
||||
/// </summary>
|
||||
Random Rand { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A generic information pipe. Note that pipes are disposable. Generally, Done should
|
||||
/// be called before disposing to signal a normal shut-down of the pipe, as opposed
|
||||
/// to an aborted completion.
|
||||
/// </summary>
|
||||
public interface IPipe<TMessage> : IExceptionContext, IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// The caller relinquishes ownership of the <paramref name="msg"/> object.
|
||||
/// </summary>
|
||||
void Send(TMessage msg);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The kinds of standard channel messages.
|
||||
/// Note: These values should never be changed. We can add new kinds, but don't change these values.
|
||||
/// Other code bases, including native code for other projects depends on these values.
|
||||
/// </summary>
|
||||
public enum ChannelMessageKind
|
||||
{
|
||||
Trace = 0,
|
||||
Info = 1,
|
||||
Warning = 2,
|
||||
Error = 3
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A flag that can be attached to a message or exception to indicate that
|
||||
/// it has a certain class of sensitive data. By default, messages should be
|
||||
/// specified as being of unknown sensitivity, which is to say, every
|
||||
/// sensitivity flag is turned on, corresponding to <see cref="Unknown"/>.
|
||||
/// Messages that are totally safe should be marked as <see cref="None"/>.
|
||||
/// However, if, say, one prints out data from a file (for example, this might
|
||||
/// be done when expressing parse errors), it should be flagged in that case
|
||||
/// with <see cref="UserData"/>.
|
||||
/// </summary>
|
||||
[Flags]
|
||||
public enum MessageSensitivity
|
||||
{
|
||||
/// <summary>
|
||||
/// For non-sensitive data.
|
||||
/// </summary>
|
||||
None = 0,
|
||||
|
||||
/// <summary>
|
||||
/// For messages that may contain user-data from data files.
|
||||
/// </summary>
|
||||
UserData = 0x1,
|
||||
|
||||
/// <summary>
|
||||
/// For messages that contain information like column names from datasets.
|
||||
/// Note that, despite being part of the schema, annotations should be treated
|
||||
/// as user data, since it is often derived from user data. Note also that
|
||||
/// types, despite being part of the schema, are not considered "sensitive"
|
||||
/// as such, in the same way that column names might be.
|
||||
/// </summary>
|
||||
Schema = 0x2,
|
||||
|
||||
// REVIEW: Other potentially sensitive things might include
|
||||
// stack traces in certain environments.
|
||||
|
||||
/// <summary>
|
||||
/// The default value, unknown, is treated as if everything is sensitive.
|
||||
/// </summary>
|
||||
Unknown = ~None,
|
||||
|
||||
/// <summary>
|
||||
/// An alias for <see cref="Unknown"/>, so it is functionally the same, except
|
||||
/// semantically it communicates the idea that we want all bits set.
|
||||
/// </summary>
|
||||
All = Unknown,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A channel message.
|
||||
/// </summary>
|
||||
public readonly struct ChannelMessage
|
||||
{
|
||||
public readonly ChannelMessageKind Kind;
|
||||
public readonly MessageSensitivity Sensitivity;
|
||||
private readonly string _message;
|
||||
private readonly object[] _args;
|
||||
|
||||
/// <summary>
|
||||
/// Line endings may not be normalized.
|
||||
/// </summary>
|
||||
public string Message => _args != null ? string.Format(_message, _args) : _message;
|
||||
|
||||
[BestFriend]
|
||||
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
|
||||
{
|
||||
Contracts.CheckNonEmpty(message, nameof(message));
|
||||
Kind = kind;
|
||||
Sensitivity = sensitivity;
|
||||
_message = message;
|
||||
_args = null;
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
|
||||
{
|
||||
Contracts.CheckNonEmpty(fmt, nameof(fmt));
|
||||
Contracts.CheckNonEmpty(args, nameof(args));
|
||||
Kind = kind;
|
||||
Sensitivity = sensitivity;
|
||||
_message = fmt;
|
||||
_args = args;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A standard communication channel.
|
||||
/// </summary>
|
||||
public interface IChannel : IPipe<ChannelMessage>
|
||||
{
|
||||
void Trace(MessageSensitivity sensitivity, string fmt);
|
||||
void Trace(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
void Error(MessageSensitivity sensitivity, string fmt);
|
||||
void Error(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
void Warning(MessageSensitivity sensitivity, string fmt);
|
||||
void Warning(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
void Info(MessageSensitivity sensitivity, string fmt);
|
||||
void Info(MessageSensitivity sensitivity, string fmt, params object[] args);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// General utility extension methods for objects in the "host" universe, i.e.,
|
||||
/// <see cref="IHostEnvironment"/>, <see cref="IHost"/>, and <see cref="IChannel"/>
|
||||
/// that do not belong in more specific areas, for example, <see cref="Contracts"/> or
|
||||
/// component creation.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class HostExtensions
|
||||
{
|
||||
public static T Apply<T>(this IHost host, string channelName, Func<IChannel, T> func)
|
||||
{
|
||||
T t;
|
||||
using (var ch = host.Start(channelName))
|
||||
{
|
||||
t = func(ch);
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Trace(this IChannel ch, string fmt)
|
||||
=> ch.Trace(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Trace(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Trace(MessageSensitivity.Unknown, fmt, args);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Error(this IChannel ch, string fmt)
|
||||
=> ch.Error(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Error(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Error(MessageSensitivity.Unknown, fmt, args);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Warning(this IChannel ch, string fmt)
|
||||
=> ch.Warning(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Warning(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Warning(MessageSensitivity.Unknown, fmt, args);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string)"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Info(this IChannel ch, string fmt)
|
||||
=> ch.Info(MessageSensitivity.Unknown, fmt);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string, object[])"/>
|
||||
/// setting <see cref="MessageSensitivity.Unknown"/>.
|
||||
/// </summary>
|
||||
public static void Info(this IChannel ch, string fmt, params object[] args)
|
||||
=> ch.Info(MessageSensitivity.Unknown, fmt, args);
|
||||
}
|
||||
|
|
|
@ -5,140 +5,139 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.ML.Runtime
|
||||
namespace Microsoft.ML.Runtime;
|
||||
|
||||
/// <summary>
|
||||
/// This is a factory interface for <see cref="IProgressChannel"/>.
|
||||
/// Both <see cref="IHostEnvironment"/> and <see cref="IProgressChannel"/> implement this interface,
|
||||
/// to allow for nested progress reporters.
|
||||
///
|
||||
/// REVIEW: make <see cref="IChannelProvider"/> implement this, instead of the environment?
|
||||
/// </summary>
|
||||
public interface IProgressChannelProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// This is a factory interface for <see cref="IProgressChannel"/>.
|
||||
/// Both <see cref="IHostEnvironment"/> and <see cref="IProgressChannel"/> implement this interface,
|
||||
/// to allow for nested progress reporters.
|
||||
/// Create a progress channel for a computation named <paramref name="name"/>.
|
||||
/// </summary>
|
||||
IProgressChannel StartProgressChannel(string name);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A common interface for progress reporting.
|
||||
/// It is expected that the progress channel interface is used from only one thread.
|
||||
///
|
||||
/// Supported workflow:
|
||||
/// 1) Create the channel via <see cref="IProgressChannelProvider.StartProgressChannel"/>.
|
||||
/// 2) Call <see cref="SetHeader"/> as many times as desired (including 0).
|
||||
/// Each call to <see cref="SetHeader"/> supersedes the previous one.
|
||||
/// 3) Report checkpoints (0 or more) by calling <see cref="Checkpoint"/>.
|
||||
/// 4) Repeat steps 2-3 as often as necessary.
|
||||
/// 5) Dispose the channel.
|
||||
/// </summary>
|
||||
public interface IProgressChannel : IProgressChannelProvider, IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// Set up the reporting structure:
|
||||
/// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported.
|
||||
/// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress.
|
||||
///
|
||||
/// REVIEW: make <see cref="IChannelProvider"/> implement this, instead of the environment?
|
||||
/// It is acceptable to call <see cref="SetHeader"/> multiple times (or none), regardless of whether the calculation is running
|
||||
/// or not. Because of synchronization, the computation should not deny calls to the 'old' <paramref name="fillAction"/>
|
||||
/// delegates even after a new one is provided.
|
||||
/// </summary>
|
||||
public interface IProgressChannelProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Create a progress channel for a computation named <paramref name="name"/>.
|
||||
/// </summary>
|
||||
IProgressChannel StartProgressChannel(string name);
|
||||
}
|
||||
/// <param name="header">The header object.</param>
|
||||
/// <param name="fillAction">The delegate to provide actual progress. The <see cref="IProgressEntry"/> parameter of
|
||||
/// the delegate will correspond to the provided <paramref name="header"/>.</param>
|
||||
void SetHeader(ProgressHeader header, Action<IProgressEntry> fillAction);
|
||||
|
||||
/// <summary>
|
||||
/// A common interface for progress reporting.
|
||||
/// It is expected that the progress channel interface is used from only one thread.
|
||||
/// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener,
|
||||
/// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated
|
||||
/// at certain moments ('checkpoints') of the computation.
|
||||
///
|
||||
/// Supported workflow:
|
||||
/// 1) Create the channel via <see cref="IProgressChannelProvider.StartProgressChannel"/>.
|
||||
/// 2) Call <see cref="SetHeader"/> as many times as desired (including 0).
|
||||
/// Each call to <see cref="SetHeader"/> supersedes the previous one.
|
||||
/// 3) Report checkpoints (0 or more) by calling <see cref="Checkpoint"/>.
|
||||
/// 4) Repeat steps 2-3 as often as necessary.
|
||||
/// 5) Dispose the channel.
|
||||
/// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint
|
||||
/// every iteration.
|
||||
///
|
||||
/// The only parameter, <paramref name="values"/>, is interpreted in the following fashion:
|
||||
/// * First MetricNames.Length items, if present, are metrics.
|
||||
/// * Subsequent ProgressNames.Length items, if present, are progress units.
|
||||
/// * Subsequent ProgressNames.Length items, if present, are progress limits.
|
||||
/// * If any more values remain, an exception is thrown.
|
||||
/// </summary>
|
||||
public interface IProgressChannel : IProgressChannelProvider, IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// Set up the reporting structure:
|
||||
/// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported.
|
||||
/// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress.
|
||||
///
|
||||
/// It is acceptable to call <see cref="SetHeader"/> multiple times (or none), regardless of whether the calculation is running
|
||||
/// or not. Because of synchronization, the computation should not deny calls to the 'old' <paramref name="fillAction"/>
|
||||
/// delegates even after a new one is provided.
|
||||
/// </summary>
|
||||
/// <param name="header">The header object.</param>
|
||||
/// <param name="fillAction">The delegate to provide actual progress. The <see cref="IProgressEntry"/> parameter of
|
||||
/// the delegate will correspond to the provided <paramref name="header"/>.</param>
|
||||
void SetHeader(ProgressHeader header, Action<IProgressEntry> fillAction);
|
||||
/// <param name="values">The metrics, progress units and progress limits.</param>
|
||||
void Checkpoint(params Double?[] values);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener,
|
||||
/// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated
|
||||
/// at certain moments ('checkpoints') of the computation.
|
||||
///
|
||||
/// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint
|
||||
/// every iteration.
|
||||
///
|
||||
/// The only parameter, <paramref name="values"/>, is interpreted in the following fashion:
|
||||
/// * First MetricNames.Length items, if present, are metrics.
|
||||
/// * Subsequent ProgressNames.Length items, if present, are progress units.
|
||||
/// * Subsequent ProgressNames.Length items, if present, are progress limits.
|
||||
/// * If any more values remain, an exception is thrown.
|
||||
/// </summary>
|
||||
/// <param name="values">The metrics, progress units and progress limits.</param>
|
||||
void Checkpoint(params Double?[] values);
|
||||
/// <summary>
|
||||
/// This is the 'header' of the progress report.
|
||||
/// </summary>
|
||||
public sealed class ProgressHeader
|
||||
{
|
||||
/// <summary>
|
||||
/// These are the names of the progress 'units', from the least granular to the most granular.
|
||||
/// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}.
|
||||
/// Will never be null, but can be empty.
|
||||
/// </summary>
|
||||
public readonly IReadOnlyList<string> UnitNames;
|
||||
|
||||
/// <summary>
|
||||
/// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc.
|
||||
/// Will never be null, but can be empty.
|
||||
/// </summary>
|
||||
public readonly IReadOnlyList<string> MetricNames;
|
||||
|
||||
/// <summary>
|
||||
/// Initialize the header. This will take ownership of the arrays.
|
||||
/// Both arrays can be null, even simultaneously. This 'empty' header indicated that the calculation doesn't report
|
||||
/// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any
|
||||
/// progress or metrics to report, it is always better to report them.
|
||||
/// </summary>
|
||||
/// <param name="metricNames">The metrics that the calculation reports. These are completely independent, and there
|
||||
/// is no contract on whether the metric values should increase or not. As naming convention, <paramref name="metricNames"/>
|
||||
/// can have multiple words with spaces, and should be title-cased.</param>
|
||||
/// <param name="unitNames">The names of the progress units, listed from least granular to most granular.
|
||||
/// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.).
|
||||
/// As naming convention, <paramref name="unitNames"/> should be lower-cased and typically plural
|
||||
/// (for example, iterations, clusters, examples). </param>
|
||||
public ProgressHeader(string[] metricNames, string[] unitNames)
|
||||
{
|
||||
Contracts.CheckValueOrNull(unitNames);
|
||||
Contracts.CheckValueOrNull(metricNames);
|
||||
|
||||
UnitNames = unitNames ?? new string[0];
|
||||
MetricNames = metricNames ?? new string[0];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This is the 'header' of the progress report.
|
||||
/// A constructor for no metrics, just progress units. As naming convention, <paramref name="unitNames"/> should be lower-cased
|
||||
/// and typically plural (for example, iterations, clusters, examples).
|
||||
/// </summary>
|
||||
public sealed class ProgressHeader
|
||||
public ProgressHeader(params string[] unitNames)
|
||||
: this(null, unitNames)
|
||||
{
|
||||
/// <summary>
|
||||
/// These are the names of the progress 'units', from the least granular to the most granular.
|
||||
/// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}.
|
||||
/// Will never be null, but can be empty.
|
||||
/// </summary>
|
||||
public readonly IReadOnlyList<string> UnitNames;
|
||||
|
||||
/// <summary>
|
||||
/// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc.
|
||||
/// Will never be null, but can be empty.
|
||||
/// </summary>
|
||||
public readonly IReadOnlyList<string> MetricNames;
|
||||
|
||||
/// <summary>
|
||||
/// Initialize the header. This will take ownership of the arrays.
|
||||
/// Both arrays can be null, even simultaneously. This 'empty' header indicated that the calculation doesn't report
|
||||
/// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any
|
||||
/// progress or metrics to report, it is always better to report them.
|
||||
/// </summary>
|
||||
/// <param name="metricNames">The metrics that the calculation reports. These are completely independent, and there
|
||||
/// is no contract on whether the metric values should increase or not. As naming convention, <paramref name="metricNames"/>
|
||||
/// can have multiple words with spaces, and should be title-cased.</param>
|
||||
/// <param name="unitNames">The names of the progress units, listed from least granular to most granular.
|
||||
/// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.).
|
||||
/// As naming convention, <paramref name="unitNames"/> should be lower-cased and typically plural
|
||||
/// (for example, iterations, clusters, examples). </param>
|
||||
public ProgressHeader(string[] metricNames, string[] unitNames)
|
||||
{
|
||||
Contracts.CheckValueOrNull(unitNames);
|
||||
Contracts.CheckValueOrNull(metricNames);
|
||||
|
||||
UnitNames = unitNames ?? new string[0];
|
||||
MetricNames = metricNames ?? new string[0];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A constructor for no metrics, just progress units. As naming convention, <paramref name="unitNames"/> should be lower-cased
|
||||
/// and typically plural (for example, iterations, clusters, examples).
|
||||
/// </summary>
|
||||
public ProgressHeader(params string[] unitNames)
|
||||
: this(null, unitNames)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A metric/progress holder item.
|
||||
/// </summary>
|
||||
public interface IProgressEntry
|
||||
{
|
||||
/// <summary>
|
||||
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
|
||||
/// and the limit value for the progress becomes 'unknown'.
|
||||
/// </summary>
|
||||
void SetProgress(int index, Double value);
|
||||
|
||||
/// <summary>
|
||||
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
|
||||
/// and the limit value to <paramref name="lim"/>. If <paramref name="lim"/> is a NAN, it is set to null instead.
|
||||
/// </summary>
|
||||
void SetProgress(int index, Double value, Double lim);
|
||||
|
||||
/// <summary>
|
||||
/// Sets the metric with index <paramref name="index"/> to <paramref name="value"/>.
|
||||
/// </summary>
|
||||
void SetMetric(int index, Double value);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A metric/progress holder item.
|
||||
/// </summary>
|
||||
public interface IProgressEntry
|
||||
{
|
||||
/// <summary>
|
||||
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
|
||||
/// and the limit value for the progress becomes 'unknown'.
|
||||
/// </summary>
|
||||
void SetProgress(int index, Double value);
|
||||
|
||||
/// <summary>
|
||||
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
|
||||
/// and the limit value to <paramref name="lim"/>. If <paramref name="lim"/> is a NAN, it is set to null instead.
|
||||
/// </summary>
|
||||
void SetProgress(int index, Double value, Double lim);
|
||||
|
||||
/// <summary>
|
||||
/// Sets the metric with index <paramref name="index"/> to <paramref name="value"/>.
|
||||
/// </summary>
|
||||
void SetMetric(int index, Double value);
|
||||
|
||||
}
|
||||
|
|
|
@ -5,47 +5,46 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// This interface maps an input <see cref="DataViewRow"/> to an output <see cref="DataViewRow"/>. Typically, the output contains
|
||||
/// both the input columns and new columns added by the implementing class, although some implementations may
|
||||
/// return a subset of the input columns.
|
||||
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
|
||||
/// so to rebind, the same input column names must be used.
|
||||
/// Implementations of this interface are typically created over defined input <see cref="DataViewSchema"/>.
|
||||
/// </summary>
|
||||
public interface IRowToRowMapper
|
||||
{
|
||||
/// <summary>
|
||||
/// This interface maps an input <see cref="DataViewRow"/> to an output <see cref="DataViewRow"/>. Typically, the output contains
|
||||
/// both the input columns and new columns added by the implementing class, although some implementations may
|
||||
/// return a subset of the input columns.
|
||||
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
|
||||
/// so to rebind, the same input column names must be used.
|
||||
/// Implementations of this interface are typically created over defined input <see cref="DataViewSchema"/>.
|
||||
/// Mappers are defined as accepting inputs with this very specific schema.
|
||||
/// </summary>
|
||||
public interface IRowToRowMapper
|
||||
{
|
||||
/// <summary>
|
||||
/// Mappers are defined as accepting inputs with this very specific schema.
|
||||
/// </summary>
|
||||
DataViewSchema InputSchema { get; }
|
||||
DataViewSchema InputSchema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets an instance of <see cref="DataViewSchema"/> which describes the columns' names and types in the output generated by this mapper.
|
||||
/// </summary>
|
||||
DataViewSchema OutputSchema { get; }
|
||||
/// <summary>
|
||||
/// Gets an instance of <see cref="DataViewSchema"/> which describes the columns' names and types in the output generated by this mapper.
|
||||
/// </summary>
|
||||
DataViewSchema OutputSchema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Given a set of columns, return the input columns that are needed to generate those output columns.
|
||||
/// </summary>
|
||||
IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns);
|
||||
/// <summary>
|
||||
/// Given a set of columns, return the input columns that are needed to generate those output columns.
|
||||
/// </summary>
|
||||
IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns);
|
||||
|
||||
/// <summary>
|
||||
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
|
||||
/// Getting values on inactive columns of the returned row will throw.
|
||||
///
|
||||
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
|
||||
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
|
||||
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
|
||||
///
|
||||
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
|
||||
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
|
||||
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
|
||||
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
|
||||
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
|
||||
/// </summary>
|
||||
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
|
||||
}
|
||||
/// <summary>
|
||||
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
|
||||
/// Getting values on inactive columns of the returned row will throw.
|
||||
///
|
||||
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
|
||||
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
|
||||
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
|
||||
///
|
||||
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
|
||||
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
|
||||
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
|
||||
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
|
||||
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
|
||||
/// </summary>
|
||||
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
|
||||
}
|
||||
|
|
|
@ -6,85 +6,84 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// A mapper that can be bound to a <see cref="RoleMappedSchema"/> (which encapsulates a <see cref="DataViewSchema"/> and has mappings from column kinds
|
||||
/// to columns of that schema). Binding an <see cref="ISchemaBindableMapper"/> to a <see cref="RoleMappedSchema"/> produces an
|
||||
/// <see cref="ISchemaBoundMapper"/>, which is an interface that has methods to return the names and indices of the input columns
|
||||
/// needed by the mapper to compute its output. The <see cref="ISchemaBoundRowMapper"/> is an extention to this interface, that
|
||||
/// can also produce an output <see cref="DataViewRow"/> given an input <see cref="DataViewRow"/>. The <see cref="DataViewRow"/> produced generally contains only the output columns of the mapper, and not
|
||||
/// the input columns (but there is nothing preventing an <see cref="ISchemaBoundRowMapper"/> from mapping input columns directly to outputs).
|
||||
/// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
|
||||
/// features column. New predictors can implement <see cref="ISchemaBindableMapper"/> directly. Implementing <see cref="ISchemaBindableMapper"/>
|
||||
/// includes implementing a corresponding <see cref="ISchemaBoundMapper"/> (or <see cref="ISchemaBoundRowMapper"/>) and a corresponding ISchema
|
||||
/// for the output schema of the <see cref="ISchemaBoundMapper"/>. In case the <see cref="ISchemaBoundRowMapper"/> interface is implemented,
|
||||
/// the SimpleRow class can be used in the <see cref="IRowToRowMapper.GetRow"/> method.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface ISchemaBindableMapper
|
||||
{
|
||||
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This interface is used to map a schema from input columns to output columns. The <see cref="ISchemaBoundMapper"/> should keep track
|
||||
/// of the input columns that are needed for the mapping.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface ISchemaBoundMapper
|
||||
{
|
||||
/// <summary>
|
||||
/// A mapper that can be bound to a <see cref="RoleMappedSchema"/> (which encapsulates a <see cref="DataViewSchema"/> and has mappings from column kinds
|
||||
/// to columns of that schema). Binding an <see cref="ISchemaBindableMapper"/> to a <see cref="RoleMappedSchema"/> produces an
|
||||
/// <see cref="ISchemaBoundMapper"/>, which is an interface that has methods to return the names and indices of the input columns
|
||||
/// needed by the mapper to compute its output. The <see cref="ISchemaBoundRowMapper"/> is an extention to this interface, that
|
||||
/// can also produce an output <see cref="DataViewRow"/> given an input <see cref="DataViewRow"/>. The <see cref="DataViewRow"/> produced generally contains only the output columns of the mapper, and not
|
||||
/// the input columns (but there is nothing preventing an <see cref="ISchemaBoundRowMapper"/> from mapping input columns directly to outputs).
|
||||
/// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
|
||||
/// features column. New predictors can implement <see cref="ISchemaBindableMapper"/> directly. Implementing <see cref="ISchemaBindableMapper"/>
|
||||
/// includes implementing a corresponding <see cref="ISchemaBoundMapper"/> (or <see cref="ISchemaBoundRowMapper"/>) and a corresponding ISchema
|
||||
/// for the output schema of the <see cref="ISchemaBoundMapper"/>. In case the <see cref="ISchemaBoundRowMapper"/> interface is implemented,
|
||||
/// the SimpleRow class can be used in the <see cref="IRowToRowMapper.GetRow"/> method.
|
||||
/// The <see cref="RoleMappedSchema"/> that was passed to the <see cref="ISchemaBoundMapper"/> in the binding process.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface ISchemaBindableMapper
|
||||
{
|
||||
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
|
||||
}
|
||||
RoleMappedSchema InputRoleMappedSchema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// This interface is used to map a schema from input columns to output columns. The <see cref="ISchemaBoundMapper"/> should keep track
|
||||
/// of the input columns that are needed for the mapping.
|
||||
/// Gets schema of this mapper's output.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface ISchemaBoundMapper
|
||||
{
|
||||
/// <summary>
|
||||
/// The <see cref="RoleMappedSchema"/> that was passed to the <see cref="ISchemaBoundMapper"/> in the binding process.
|
||||
/// </summary>
|
||||
RoleMappedSchema InputRoleMappedSchema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets schema of this mapper's output.
|
||||
/// </summary>
|
||||
DataViewSchema OutputSchema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// A property to get back the <see cref="ISchemaBindableMapper"/> that produced this <see cref="ISchemaBoundMapper"/>.
|
||||
/// </summary>
|
||||
ISchemaBindableMapper Bindable { get; }
|
||||
|
||||
/// <summary>
|
||||
/// This method returns the binding information: which input columns are used and in what roles.
|
||||
/// </summary>
|
||||
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles();
|
||||
}
|
||||
DataViewSchema OutputSchema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// This interface extends <see cref="ISchemaBoundMapper"/>.
|
||||
/// A property to get back the <see cref="ISchemaBindableMapper"/> that produced this <see cref="ISchemaBoundMapper"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
|
||||
{
|
||||
/// <summary>
|
||||
/// Input schema accepted.
|
||||
/// </summary>
|
||||
DataViewSchema InputSchema { get; }
|
||||
ISchemaBindableMapper Bindable { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns.
|
||||
/// </summary>
|
||||
IEnumerable<DataViewSchema.Column> GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns);
|
||||
|
||||
/// <summary>
|
||||
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
|
||||
/// Getting values on inactive columns of the returned row will throw.
|
||||
///
|
||||
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
|
||||
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
|
||||
/// the returned value must have the same schema as <see cref="ISchemaBoundMapper.OutputSchema"/>.
|
||||
///
|
||||
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
|
||||
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
|
||||
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
|
||||
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
|
||||
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
|
||||
/// </summary>
|
||||
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
|
||||
}
|
||||
/// <summary>
|
||||
/// This method returns the binding information: which input columns are used and in what roles.
|
||||
/// </summary>
|
||||
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This interface extends <see cref="ISchemaBoundMapper"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
|
||||
{
|
||||
/// <summary>
|
||||
/// Input schema accepted.
|
||||
/// </summary>
|
||||
DataViewSchema InputSchema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns.
|
||||
/// </summary>
|
||||
IEnumerable<DataViewSchema.Column> GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns);
|
||||
|
||||
/// <summary>
|
||||
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
|
||||
/// Getting values on inactive columns of the returned row will throw.
|
||||
///
|
||||
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
|
||||
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
|
||||
/// the returned value must have the same schema as <see cref="ISchemaBoundMapper.OutputSchema"/>.
|
||||
///
|
||||
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
|
||||
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
|
||||
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
|
||||
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
|
||||
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
|
||||
/// </summary>
|
||||
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
|
||||
}
|
||||
|
|
|
@ -2,57 +2,56 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Delegate type to map/convert a value.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);
|
||||
|
||||
/// <summary>
|
||||
/// Delegate type to map/convert among three values, for example, one input with two
|
||||
/// outputs, or two inputs with one output.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void ValueMapper<TVal1, TVal2, TVal3>(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
|
||||
|
||||
/// <summary>
|
||||
/// Interface for mapping a single input value (of an indicated ColumnType) to
|
||||
/// an output value (of an indicated ColumnType). This interface is commonly implemented
|
||||
/// by predictors. Note that the input and output ColumnTypes determine the proper
|
||||
/// type arguments for GetMapper, but typically contain additional information like
|
||||
/// vector lengths.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface IValueMapper
|
||||
{
|
||||
/// <summary>
|
||||
/// Delegate type to map/convert a value.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);
|
||||
DataViewType InputType { get; }
|
||||
DataViewType OutputType { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Delegate type to map/convert among three values, for example, one input with two
|
||||
/// outputs, or two inputs with one output.
|
||||
/// Get a delegate used for mapping from input to output values. Note that the delegate
|
||||
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void ValueMapper<TVal1, TVal2, TVal3>(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
|
||||
|
||||
/// <summary>
|
||||
/// Interface for mapping a single input value (of an indicated ColumnType) to
|
||||
/// an output value (of an indicated ColumnType). This interface is commonly implemented
|
||||
/// by predictors. Note that the input and output ColumnTypes determine the proper
|
||||
/// type arguments for GetMapper, but typically contain additional information like
|
||||
/// vector lengths.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface IValueMapper
|
||||
{
|
||||
DataViewType InputType { get; }
|
||||
DataViewType OutputType { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Get a delegate used for mapping from input to output values. Note that the delegate
|
||||
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
|
||||
/// </summary>
|
||||
ValueMapper<TSrc, TDst> GetMapper<TSrc, TDst>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Interface for mapping a single input value (of an indicated ColumnType) to an output value
|
||||
/// plus distribution value (of indicated ColumnTypes). This interface is commonly implemented
|
||||
/// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper
|
||||
/// type arguments for GetMapper, but typically contain additional information like
|
||||
/// vector lengths.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface IValueMapperDist : IValueMapper
|
||||
{
|
||||
DataViewType DistType { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Get a delegate used for mapping from input to output values. Note that the delegate
|
||||
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
|
||||
/// </summary>
|
||||
ValueMapper<TSrc, TDst, TDist> GetMapper<TSrc, TDst, TDist>();
|
||||
}
|
||||
ValueMapper<TSrc, TDst> GetMapper<TSrc, TDst>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Interface for mapping a single input value (of an indicated ColumnType) to an output value
|
||||
/// plus distribution value (of indicated ColumnTypes). This interface is commonly implemented
|
||||
/// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper
|
||||
/// type arguments for GetMapper, but typically contain additional information like
|
||||
/// vector lengths.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal interface IValueMapperDist : IValueMapper
|
||||
{
|
||||
DataViewType DistType { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Get a delegate used for mapping from input to output values. Note that the delegate
|
||||
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
|
||||
/// </summary>
|
||||
ValueMapper<TSrc, TDst, TDist> GetMapper<TSrc, TDst, TDist>();
|
||||
}
|
||||
|
|
|
@ -2,8 +2,7 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
{
|
||||
[BestFriend]
|
||||
internal delegate bool InPredicate<T>(in T value);
|
||||
}
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
[BestFriend]
|
||||
internal delegate bool InPredicate<T>(in T value);
|
||||
|
|
|
@ -4,21 +4,20 @@
|
|||
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods related to the <see cref="KeyDataViewType"/> class.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class KeyTypeExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Extension methods related to the <see cref="KeyDataViewType"/> class.
|
||||
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static class KeyTypeExtensions
|
||||
public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null)
|
||||
{
|
||||
/// <summary>
|
||||
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
|
||||
/// </summary>
|
||||
public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null)
|
||||
{
|
||||
ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
|
||||
return (int)key.Count;
|
||||
}
|
||||
ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
|
||||
return (int)key.Count;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,50 +4,49 @@
|
|||
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Base class for a cursor has an input cursor, but still needs to do work on <see cref="DataViewRowCursor.MoveNext"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class LinkedRootCursorBase : RootCursorBase
|
||||
{
|
||||
|
||||
/// <summary>Gets the input cursor.</summary>
|
||||
protected DataViewRowCursor Input { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Base class for a cursor has an input cursor, but still needs to do work on <see cref="DataViewRowCursor.MoveNext"/>.
|
||||
/// Returns the root cursor of the input. It should be used to perform <see cref="DataViewRowCursor.MoveNext"/>
|
||||
/// operations, but with the distinction, as compared to <see cref="SynchronizedCursorBase"/>, that this is not
|
||||
/// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is
|
||||
/// on filter cursor implementations, where how that input cursor is consumed is very implementation specific.
|
||||
/// That is why this is <see langword="protected"/>, not <see langword="private"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class LinkedRootCursorBase : RootCursorBase
|
||||
protected DataViewRowCursor Root { get; }
|
||||
|
||||
private bool _disposed;
|
||||
|
||||
protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input)
|
||||
: base(provider)
|
||||
{
|
||||
Ch.AssertValue(input, nameof(input));
|
||||
|
||||
/// <summary>Gets the input cursor.</summary>
|
||||
protected DataViewRowCursor Input { get; }
|
||||
Input = input;
|
||||
Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the root cursor of the input. It should be used to perform <see cref="DataViewRowCursor.MoveNext"/>
|
||||
/// operations, but with the distinction, as compared to <see cref="SynchronizedCursorBase"/>, that this is not
|
||||
/// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is
|
||||
/// on filter cursor implementations, where how that input cursor is consumed is very implementation specific.
|
||||
/// That is why this is <see langword="protected"/>, not <see langword="private"/>.
|
||||
/// </summary>
|
||||
protected DataViewRowCursor Root { get; }
|
||||
|
||||
private bool _disposed;
|
||||
|
||||
protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input)
|
||||
: base(provider)
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
if (disposing)
|
||||
{
|
||||
Ch.AssertValue(input, nameof(input));
|
||||
Input.Dispose();
|
||||
// The base class should set the state to done under these circumstances.
|
||||
|
||||
Input = input;
|
||||
Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input;
|
||||
}
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
if (disposing)
|
||||
{
|
||||
Input.Dispose();
|
||||
// The base class should set the state to done under these circumstances.
|
||||
|
||||
}
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,40 +4,39 @@
|
|||
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Base class for creating a cursor of rows that filters out some input rows.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase
|
||||
{
|
||||
/// <summary>
|
||||
/// Base class for creating a cursor of rows that filters out some input rows.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase
|
||||
public override long Batch => Input.Batch;
|
||||
|
||||
protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
|
||||
: base(provider, input, schema, active)
|
||||
{
|
||||
public override long Batch => Input.Batch;
|
||||
|
||||
protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
|
||||
: base(provider, input, schema, active)
|
||||
{
|
||||
}
|
||||
|
||||
public override ValueGetter<DataViewRowId> GetIdGetter()
|
||||
{
|
||||
return Input.GetIdGetter();
|
||||
}
|
||||
|
||||
protected override bool MoveNextCore()
|
||||
{
|
||||
while (Root.MoveNext())
|
||||
{
|
||||
if (Accept())
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return whether the current input row should be returned by this cursor.
|
||||
/// </summary>
|
||||
protected abstract bool Accept();
|
||||
}
|
||||
|
||||
public override ValueGetter<DataViewRowId> GetIdGetter()
|
||||
{
|
||||
return Input.GetIdGetter();
|
||||
}
|
||||
|
||||
protected override bool MoveNextCore()
|
||||
{
|
||||
while (Root.MoveNext())
|
||||
{
|
||||
if (Accept())
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return whether the current input row should be returned by this cursor.
|
||||
/// </summary>
|
||||
protected abstract bool Accept();
|
||||
}
|
||||
|
|
|
@ -4,50 +4,49 @@
|
|||
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// A base class for a <see cref="DataViewRowCursor"/> that has an input cursor, but still needs to do work on
|
||||
/// <see cref="DataViewRowCursor.MoveNext"/>. Note that the default
|
||||
/// <see cref="LinkedRowRootCursorBase.GetGetter{TValue}(DataViewSchema.Column)"/> assumes that each input column is exposed as an
|
||||
/// output column with the same column index.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase
|
||||
{
|
||||
/// <summary>
|
||||
/// A base class for a <see cref="DataViewRowCursor"/> that has an input cursor, but still needs to do work on
|
||||
/// <see cref="DataViewRowCursor.MoveNext"/>. Note that the default
|
||||
/// <see cref="LinkedRowRootCursorBase.GetGetter{TValue}(DataViewSchema.Column)"/> assumes that each input column is exposed as an
|
||||
/// output column with the same column index.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase
|
||||
private readonly bool[] _active;
|
||||
|
||||
/// <summary>Gets row's schema.</summary>
|
||||
public sealed override DataViewSchema Schema { get; }
|
||||
|
||||
protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
|
||||
: base(provider, input)
|
||||
{
|
||||
private readonly bool[] _active;
|
||||
Ch.CheckValue(schema, nameof(schema));
|
||||
Ch.Check(active == null || active.Length == schema.Count);
|
||||
_active = active;
|
||||
Schema = schema;
|
||||
}
|
||||
|
||||
/// <summary>Gets row's schema.</summary>
|
||||
public sealed override DataViewSchema Schema { get; }
|
||||
/// <summary>
|
||||
/// Returns whether the given column is active in this row.
|
||||
/// </summary>
|
||||
public sealed override bool IsColumnActive(DataViewSchema.Column column)
|
||||
{
|
||||
Ch.Check(column.Index < Schema.Count);
|
||||
return _active == null || _active[column.Index];
|
||||
}
|
||||
|
||||
protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
|
||||
: base(provider, input)
|
||||
{
|
||||
Ch.CheckValue(schema, nameof(schema));
|
||||
Ch.Check(active == null || active.Length == schema.Count);
|
||||
_active = active;
|
||||
Schema = schema;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether the given column is active in this row.
|
||||
/// </summary>
|
||||
public sealed override bool IsColumnActive(DataViewSchema.Column column)
|
||||
{
|
||||
Ch.Check(column.Index < Schema.Count);
|
||||
return _active == null || _active[column.Index];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
|
||||
/// This throws if the column is not active in this row, or if the type
|
||||
/// <typeparamref name="TValue"/> differs from this column's type.
|
||||
/// </summary>
|
||||
/// <typeparam name="TValue"> is the column's content type.</typeparam>
|
||||
/// <param name="column"> is the output column whose getter should be returned.</param>
|
||||
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
|
||||
{
|
||||
return Input.GetGetter<TValue>(column);
|
||||
}
|
||||
/// <summary>
|
||||
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
|
||||
/// This throws if the column is not active in this row, or if the type
|
||||
/// <typeparamref name="TValue"/> differs from this column's type.
|
||||
/// </summary>
|
||||
/// <typeparam name="TValue"> is the column's content type.</typeparam>
|
||||
/// <param name="column"> is the output column whose getter should be returned.</param>
|
||||
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
|
||||
{
|
||||
return Input.GetGetter<TValue>(column);
|
||||
}
|
||||
}
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -8,183 +8,182 @@ using System.Text;
|
|||
using Microsoft.ML.Internal.Utilities;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML
|
||||
namespace Microsoft.ML;
|
||||
|
||||
/// <summary>
|
||||
/// This is a convenience context object for loading models from a repository, for
|
||||
/// implementors of ICanSaveModel. It is not mandated but designed to reduce the
|
||||
/// amount of boiler plate code. It can also be used when loading from a single stream,
|
||||
/// for implementors of ICanSaveInBinaryFormat.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal sealed partial class ModelLoadContext : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// This is a convenience context object for loading models from a repository, for
|
||||
/// implementors of ICanSaveModel. It is not mandated but designed to reduce the
|
||||
/// amount of boiler plate code. It can also be used when loading from a single stream,
|
||||
/// for implementors of ICanSaveInBinaryFormat.
|
||||
/// When in repository mode, this is the repository we're reading from. It is null when
|
||||
/// in single-stream mode.
|
||||
/// </summary>
|
||||
public readonly RepositoryReader Repository;
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this is the directory we're reading from. Null means the root
|
||||
/// of the repository. It is always null in single-stream mode.
|
||||
/// </summary>
|
||||
public readonly string Directory;
|
||||
|
||||
/// <summary>
|
||||
/// The main stream reader.
|
||||
/// </summary>
|
||||
public readonly BinaryReader Reader;
|
||||
|
||||
/// <summary>
|
||||
/// The strings loaded from the main stream's string table.
|
||||
/// </summary>
|
||||
public readonly string[] Strings;
|
||||
|
||||
/// <summary>
|
||||
/// The name of the assembly that the loader lives in.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This may be null or empty if one was never written to the model, or is an older model version.
|
||||
/// </remarks>
|
||||
public readonly string LoaderAssemblyName;
|
||||
|
||||
/// <summary>
|
||||
/// The main stream's model header.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal sealed partial class ModelLoadContext : IDisposable
|
||||
internal ModelHeader Header;
|
||||
|
||||
/// <summary>
|
||||
/// The min file position of the main stream.
|
||||
/// </summary>
|
||||
public readonly long FpMin;
|
||||
|
||||
/// <summary>
|
||||
/// Exception context provided by Repository (can be null).
|
||||
/// </summary>
|
||||
private readonly IExceptionContext _ectx;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
|
||||
/// </summary>
|
||||
public bool InRepository { get { return Repository != null; } }
|
||||
|
||||
/// <summary>
|
||||
/// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel.
|
||||
/// </summary>
|
||||
internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
|
||||
{
|
||||
/// <summary>
|
||||
/// When in repository mode, this is the repository we're reading from. It is null when
|
||||
/// in single-stream mode.
|
||||
/// </summary>
|
||||
public readonly RepositoryReader Repository;
|
||||
Contracts.CheckValue(rep, nameof(rep));
|
||||
Repository = rep;
|
||||
_ectx = rep.ExceptionContext;
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this is the directory we're reading from. Null means the root
|
||||
/// of the repository. It is always null in single-stream mode.
|
||||
/// </summary>
|
||||
public readonly string Directory;
|
||||
_ectx.CheckValue(ent, nameof(ent));
|
||||
_ectx.CheckValueOrNull(dir);
|
||||
|
||||
/// <summary>
|
||||
/// The main stream reader.
|
||||
/// </summary>
|
||||
public readonly BinaryReader Reader;
|
||||
Directory = dir;
|
||||
|
||||
/// <summary>
|
||||
/// The strings loaded from the main stream's string table.
|
||||
/// </summary>
|
||||
public readonly string[] Strings;
|
||||
|
||||
/// <summary>
|
||||
/// The name of the assembly that the loader lives in.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This may be null or empty if one was never written to the model, or is an older model version.
|
||||
/// </remarks>
|
||||
public readonly string LoaderAssemblyName;
|
||||
|
||||
/// <summary>
|
||||
/// The main stream's model header.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal ModelHeader Header;
|
||||
|
||||
/// <summary>
|
||||
/// The min file position of the main stream.
|
||||
/// </summary>
|
||||
public readonly long FpMin;
|
||||
|
||||
/// <summary>
|
||||
/// Exception context provided by Repository (can be null).
|
||||
/// </summary>
|
||||
private readonly IExceptionContext _ectx;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
|
||||
/// </summary>
|
||||
public bool InRepository { get { return Repository != null; } }
|
||||
|
||||
/// <summary>
|
||||
/// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel.
|
||||
/// </summary>
|
||||
internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
|
||||
Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
|
||||
try
|
||||
{
|
||||
Contracts.CheckValue(rep, nameof(rep));
|
||||
Repository = rep;
|
||||
_ectx = rep.ExceptionContext;
|
||||
|
||||
_ectx.CheckValue(ent, nameof(ent));
|
||||
_ectx.CheckValueOrNull(dir);
|
||||
|
||||
Directory = dir;
|
||||
|
||||
Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
|
||||
try
|
||||
{
|
||||
ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
|
||||
}
|
||||
catch
|
||||
{
|
||||
Reader.Dispose();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat.
|
||||
/// </summary>
|
||||
internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
|
||||
{
|
||||
Contracts.AssertValueOrNull(ectx);
|
||||
_ectx = ectx;
|
||||
_ectx.CheckValue(reader, nameof(reader));
|
||||
|
||||
Repository = null;
|
||||
Directory = null;
|
||||
Reader = reader;
|
||||
ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
|
||||
}
|
||||
|
||||
public void CheckAtModel()
|
||||
catch
|
||||
{
|
||||
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
}
|
||||
|
||||
public void CheckAtModel(VersionInfo ver)
|
||||
{
|
||||
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
ModelHeader.CheckVersionInfo(ref Header, ver);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Performs version checks.
|
||||
/// </summary>
|
||||
public void CheckVersionInfo(VersionInfo ver)
|
||||
{
|
||||
ModelHeader.CheckVersionInfo(ref Header, ver);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads an integer from the load context's reader, and returns the associated string,
|
||||
/// or null (encoded as -1).
|
||||
/// </summary>
|
||||
public string LoadStringOrNull()
|
||||
{
|
||||
int id = Reader.ReadInt32();
|
||||
// Note that -1 means null. Empty strings are in the string table.
|
||||
_ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings));
|
||||
if (id >= 0)
|
||||
return Strings[id];
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads an integer from the load context's reader, and returns the associated string.
|
||||
/// </summary>
|
||||
public string LoadString()
|
||||
{
|
||||
int id = Reader.ReadInt32();
|
||||
Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings));
|
||||
return Strings[id];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads an integer from the load context's reader, and returns the associated string.
|
||||
/// Throws if the string is empty or null.
|
||||
/// </summary>
|
||||
public string LoadNonEmptyString()
|
||||
{
|
||||
int id = Reader.ReadInt32();
|
||||
_ectx.CheckDecode(0 <= id && id < Utils.Size(Strings));
|
||||
var str = Strings[id];
|
||||
_ectx.CheckDecode(str.Length > 0);
|
||||
return str;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Commit the load operation. This completes reading of the main stream. When in repository
|
||||
/// mode, it disposes the Reader (but not the repository).
|
||||
/// </summary>
|
||||
public void Done()
|
||||
{
|
||||
ModelHeader.EndRead(FpMin, ref Header, Reader);
|
||||
Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this disposes the Reader (but no the repository).
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
// When in single-stream mode, we don't own the Reader.
|
||||
if (InRepository)
|
||||
Reader.Dispose();
|
||||
Reader.Dispose();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat.
|
||||
/// </summary>
|
||||
internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
|
||||
{
|
||||
Contracts.AssertValueOrNull(ectx);
|
||||
_ectx = ectx;
|
||||
_ectx.CheckValue(reader, nameof(reader));
|
||||
|
||||
Repository = null;
|
||||
Directory = null;
|
||||
Reader = reader;
|
||||
ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
|
||||
}
|
||||
|
||||
public void CheckAtModel()
|
||||
{
|
||||
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
}
|
||||
|
||||
public void CheckAtModel(VersionInfo ver)
|
||||
{
|
||||
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
ModelHeader.CheckVersionInfo(ref Header, ver);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Performs version checks.
|
||||
/// </summary>
|
||||
public void CheckVersionInfo(VersionInfo ver)
|
||||
{
|
||||
ModelHeader.CheckVersionInfo(ref Header, ver);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads an integer from the load context's reader, and returns the associated string,
|
||||
/// or null (encoded as -1).
|
||||
/// </summary>
|
||||
public string LoadStringOrNull()
|
||||
{
|
||||
int id = Reader.ReadInt32();
|
||||
// Note that -1 means null. Empty strings are in the string table.
|
||||
_ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings));
|
||||
if (id >= 0)
|
||||
return Strings[id];
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads an integer from the load context's reader, and returns the associated string.
|
||||
/// </summary>
|
||||
public string LoadString()
|
||||
{
|
||||
int id = Reader.ReadInt32();
|
||||
Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings));
|
||||
return Strings[id];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads an integer from the load context's reader, and returns the associated string.
|
||||
/// Throws if the string is empty or null.
|
||||
/// </summary>
|
||||
public string LoadNonEmptyString()
|
||||
{
|
||||
int id = Reader.ReadInt32();
|
||||
_ectx.CheckDecode(0 <= id && id < Utils.Size(Strings));
|
||||
var str = Strings[id];
|
||||
_ectx.CheckDecode(str.Length > 0);
|
||||
return str;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Commit the load operation. This completes reading of the main stream. When in repository
|
||||
/// mode, it disposes the Reader (but not the repository).
|
||||
/// </summary>
|
||||
public void Done()
|
||||
{
|
||||
ModelHeader.EndRead(FpMin, ref Header, Reader);
|
||||
Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this disposes the Reader (but no the repository).
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
// When in single-stream mode, we don't own the Reader.
|
||||
if (InRepository)
|
||||
Reader.Dispose();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,357 +9,356 @@ using System.Text;
|
|||
using Microsoft.ML.Internal.Utilities;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML
|
||||
namespace Microsoft.ML;
|
||||
|
||||
/// <summary>
|
||||
/// Signature for a repository based model loader. This is the dual of <see cref="ICanSaveModel"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void SignatureLoadModel(ModelLoadContext ctx);
|
||||
|
||||
internal sealed partial class ModelLoadContext : IDisposable
|
||||
{
|
||||
public const string ModelStreamName = "Model.key";
|
||||
internal const string NameBinary = "Model.bin";
|
||||
|
||||
/// <summary>
|
||||
/// Signature for a repository based model loader. This is the dual of <see cref="ICanSaveModel"/>.
|
||||
/// Returns the new assembly name to maintain backward compatibility.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal delegate void SignatureLoadModel(ModelLoadContext ctx);
|
||||
|
||||
internal sealed partial class ModelLoadContext : IDisposable
|
||||
private string ForwardedLoaderAssemblyName
|
||||
{
|
||||
public const string ModelStreamName = "Model.key";
|
||||
internal const string NameBinary = "Model.bin";
|
||||
|
||||
/// <summary>
|
||||
/// Returns the new assembly name to maintain backward compatibility.
|
||||
/// </summary>
|
||||
private string ForwardedLoaderAssemblyName
|
||||
get
|
||||
{
|
||||
get
|
||||
string[] nameDetails = LoaderAssemblyName.Split(',');
|
||||
switch (nameDetails[0])
|
||||
{
|
||||
string[] nameDetails = LoaderAssemblyName.Split(',');
|
||||
switch (nameDetails[0])
|
||||
{
|
||||
case "Microsoft.ML.HalLearners":
|
||||
nameDetails[0] = "Microsoft.ML.Mkl.Components";
|
||||
break;
|
||||
case "Microsoft.ML.StandardLearners":
|
||||
nameDetails[0] = "Microsoft.ML.StandardTrainers";
|
||||
break;
|
||||
default:
|
||||
return LoaderAssemblyName;
|
||||
}
|
||||
|
||||
return string.Join(",", nameDetails);
|
||||
case "Microsoft.ML.HalLearners":
|
||||
nameDetails[0] = "Microsoft.ML.Mkl.Components";
|
||||
break;
|
||||
case "Microsoft.ML.StandardLearners":
|
||||
nameDetails[0] = "Microsoft.ML.StandardTrainers";
|
||||
break;
|
||||
default:
|
||||
return LoaderAssemblyName;
|
||||
}
|
||||
|
||||
return string.Join(",", nameDetails);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return whether this context contains a directory and stream for a sub-model with
|
||||
/// the indicated name. This does not attempt to load the sub-model.
|
||||
/// </summary>
|
||||
public bool ContainsModel(string name)
|
||||
{
|
||||
if (!InRepository)
|
||||
return false;
|
||||
if (string.IsNullOrEmpty(name))
|
||||
return false;
|
||||
|
||||
var dir = Path.Combine(Directory ?? "", name);
|
||||
var ent = Repository.OpenEntryOrNull(dir, ModelStreamName);
|
||||
if (ent != null)
|
||||
{
|
||||
ent.Dispose();
|
||||
return true;
|
||||
}
|
||||
|
||||
if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null)
|
||||
{
|
||||
ent.Dispose();
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return whether this context contains a directory and stream for a sub-model with
|
||||
/// the indicated name. This does not attempt to load the sub-model.
|
||||
/// </summary>
|
||||
public bool ContainsModel(string name)
|
||||
{
|
||||
if (!InRepository)
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load an optional object from the repository directory.
|
||||
/// Returns false iff no stream was found for the object, iff result is set to null.
|
||||
/// Throws if loading fails for any other reason.
|
||||
/// </summary>
|
||||
public static bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
var ent = rep.OpenEntryOrNull(dir, ModelStreamName);
|
||||
if (ent != null)
|
||||
{
|
||||
using (ent)
|
||||
{
|
||||
// Provide the repository, entry, and directory name to the loadable class ctor.
|
||||
env.Assert(ent.Stream.Position == 0);
|
||||
LoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null)
|
||||
{
|
||||
using (ent)
|
||||
{
|
||||
env.Assert(ent.Stream.Position == 0);
|
||||
LoadModel<TRes, TSig>(env, out result, ent.Stream, extra);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
result = null;
|
||||
if (string.IsNullOrEmpty(name))
|
||||
return false;
|
||||
|
||||
var dir = Path.Combine(Directory ?? "", name);
|
||||
var ent = Repository.OpenEntryOrNull(dir, ModelStreamName);
|
||||
if (ent != null)
|
||||
{
|
||||
ent.Dispose();
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load an object from the repository directory.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
if (!LoadModelOrNull<TRes, TSig>(env, out result, rep, dir, extra))
|
||||
throw env.ExceptDecode("Corrupt model file");
|
||||
env.AssertValue(result);
|
||||
ent.Dispose();
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load a sub model from the given sub directory if it exists. This requires InRepository to be true.
|
||||
/// Returns false iff no stream was found for the object, iff result is set to null.
|
||||
/// Throws if loading fails for any other reason.
|
||||
/// </summary>
|
||||
public bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
|
||||
where TRes : class
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load an optional object from the repository directory.
|
||||
/// Returns false iff no stream was found for the object, iff result is set to null.
|
||||
/// Throws if loading fails for any other reason.
|
||||
/// </summary>
|
||||
public static bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
var ent = rep.OpenEntryOrNull(dir, ModelStreamName);
|
||||
if (ent != null)
|
||||
{
|
||||
_ectx.CheckValue(env, nameof(env));
|
||||
_ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream");
|
||||
return LoadModelOrNull<TRes, TSig>(env, out result, Repository, Path.Combine(Directory ?? "", name), extra);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load a sub model from the given sub directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
_ectx.CheckValue(env, nameof(env));
|
||||
if (!LoadModelOrNull<TRes, TSig>(env, out result, name, extra))
|
||||
throw _ectx.ExceptDecode("Corrupt model file");
|
||||
_ectx.AssertValue(result);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load from the given repository entry using the default loader(s) specified in the header.
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
private static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
long fp = ent.Stream.Position;
|
||||
using (var ctx = new ModelLoadContext(rep, ent, dir))
|
||||
{
|
||||
env.Assert(fp == ctx.FpMin);
|
||||
if (ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra))
|
||||
return true;
|
||||
}
|
||||
|
||||
// TryLoadModelCore should rewind on failure.
|
||||
Contracts.Assert(fp == ent.Stream.Position);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load from the given repository entry using the default loader(s) specified in the header.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
if (!TryLoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra))
|
||||
throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load from the given stream (non-Repository).
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true))
|
||||
return TryLoadModel<TRes, TSig>(env, out result, reader, extra);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load from the given stream (non-Repository) using the default loader(s) specified in the header.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
if (!TryLoadModel<TRes, TSig>(env, out result, stream, extra))
|
||||
throw Contracts.ExceptDecode("Couldn't load model");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load from the given reader (non-Repository).
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
long fp = reader.BaseStream.Position;
|
||||
using (var ctx = new ModelLoadContext(reader))
|
||||
{
|
||||
Contracts.Assert(fp == ctx.FpMin);
|
||||
return ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load from the given reader (non-Repository) using the default loader(s) specified in the header.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
if (!TryLoadModel<TRes, TSig>(env, out result, reader, extra))
|
||||
throw Contracts.ExceptDecode("Couldn't load model");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to load.
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
private bool TryLoadModelCore<TRes, TSig>(IHostEnvironment env, out TRes result, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
_ectx.AssertValue(env, "env");
|
||||
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
|
||||
var args = ConcatArgsRev(extra, this);
|
||||
|
||||
EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog);
|
||||
|
||||
object tmp;
|
||||
string sig = ModelHeader.GetLoaderSig(ref Header);
|
||||
if (!string.IsNullOrWhiteSpace(sig) &&
|
||||
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sig, "", args))
|
||||
{
|
||||
result = tmp as TRes;
|
||||
if (result != null)
|
||||
{
|
||||
Done();
|
||||
return true;
|
||||
}
|
||||
// REVIEW: Should this fall through?
|
||||
}
|
||||
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
|
||||
string sigAlt = ModelHeader.GetLoaderSigAlt(ref Header);
|
||||
if (!string.IsNullOrWhiteSpace(sigAlt) &&
|
||||
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sigAlt, "", args))
|
||||
{
|
||||
result = tmp as TRes;
|
||||
if (result != null)
|
||||
{
|
||||
Done();
|
||||
return true;
|
||||
}
|
||||
// REVIEW: Should this fall through?
|
||||
}
|
||||
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
|
||||
Reader.BaseStream.Position = FpMin;
|
||||
result = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(LoaderAssemblyName))
|
||||
{
|
||||
var assembly = Assembly.Load(ForwardedLoaderAssemblyName);
|
||||
catalog.RegisterAssembly(assembly);
|
||||
}
|
||||
}
|
||||
|
||||
private static object[] ConcatArgsRev(object[] args2, params object[] args1)
|
||||
{
|
||||
Contracts.AssertNonEmpty(args1);
|
||||
return Utils.Concat(args1, args2);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load a sub model from the given sub directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public bool TryProcessSubModel(string dir, Action<ModelLoadContext> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
|
||||
_ectx.CheckNonEmpty(dir, nameof(dir));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
string path = Path.Combine(Directory, dir);
|
||||
var ent = Repository.OpenEntryOrNull(path, ModelStreamName);
|
||||
if (ent == null)
|
||||
return false;
|
||||
|
||||
using (ent)
|
||||
{
|
||||
// Provide the repository, entry, and directory name to the loadable class ctor.
|
||||
_ectx.Assert(ent.Stream.Position == 0);
|
||||
using (var ctx = new ModelLoadContext(Repository, ent, path))
|
||||
action(ctx);
|
||||
env.Assert(ent.Stream.Position == 0);
|
||||
LoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra);
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load a binary stream from the current directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public bool TryLoadBinaryStream(string name, Action<BinaryReader> action)
|
||||
if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
var ent = Repository.OpenEntryOrNull(Directory, name);
|
||||
if (ent == null)
|
||||
return false;
|
||||
|
||||
using (ent)
|
||||
using (var reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true))
|
||||
{
|
||||
action(reader);
|
||||
env.Assert(ent.Stream.Position == 0);
|
||||
LoadModel<TRes, TSig>(env, out result, ent.Stream, extra);
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load a text stream from the current directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public bool TryLoadTextStream(string name, Action<TextReader> action)
|
||||
result = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load an object from the repository directory.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
if (!LoadModelOrNull<TRes, TSig>(env, out result, rep, dir, extra))
|
||||
throw env.ExceptDecode("Corrupt model file");
|
||||
env.AssertValue(result);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load a sub model from the given sub directory if it exists. This requires InRepository to be true.
|
||||
/// Returns false iff no stream was found for the object, iff result is set to null.
|
||||
/// Throws if loading fails for any other reason.
|
||||
/// </summary>
|
||||
public bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
_ectx.CheckValue(env, nameof(env));
|
||||
_ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream");
|
||||
return LoadModelOrNull<TRes, TSig>(env, out result, Repository, Path.Combine(Directory ?? "", name), extra);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load a sub model from the given sub directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
_ectx.CheckValue(env, nameof(env));
|
||||
if (!LoadModelOrNull<TRes, TSig>(env, out result, name, extra))
|
||||
throw _ectx.ExceptDecode("Corrupt model file");
|
||||
_ectx.AssertValue(result);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load from the given repository entry using the default loader(s) specified in the header.
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
private static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
long fp = ent.Stream.Position;
|
||||
using (var ctx = new ModelLoadContext(rep, ent, dir))
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
env.Assert(fp == ctx.FpMin);
|
||||
if (ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra))
|
||||
return true;
|
||||
}
|
||||
|
||||
var ent = Repository.OpenEntryOrNull(Directory, name);
|
||||
if (ent == null)
|
||||
return false;
|
||||
// TryLoadModelCore should rewind on failure.
|
||||
Contracts.Assert(fp == ent.Stream.Position);
|
||||
|
||||
using (ent)
|
||||
using (var reader = new StreamReader(ent.Stream))
|
||||
{
|
||||
action(reader);
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load from the given repository entry using the default loader(s) specified in the header.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(rep, nameof(rep));
|
||||
if (!TryLoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra))
|
||||
throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load from the given stream (non-Repository).
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true))
|
||||
return TryLoadModel<TRes, TSig>(env, out result, reader, extra);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load from the given stream (non-Repository) using the default loader(s) specified in the header.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
if (!TryLoadModel<TRes, TSig>(env, out result, stream, extra))
|
||||
throw Contracts.ExceptDecode("Couldn't load model");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load from the given reader (non-Repository).
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
long fp = reader.BaseStream.Position;
|
||||
using (var ctx = new ModelLoadContext(reader))
|
||||
{
|
||||
Contracts.Assert(fp == ctx.FpMin);
|
||||
return ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load from the given reader (non-Repository) using the default loader(s) specified in the header.
|
||||
/// </summary>
|
||||
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
if (!TryLoadModel<TRes, TSig>(env, out result, reader, extra))
|
||||
throw Contracts.ExceptDecode("Couldn't load model");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to load.
|
||||
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
|
||||
/// </summary>
|
||||
private bool TryLoadModelCore<TRes, TSig>(IHostEnvironment env, out TRes result, params object[] extra)
|
||||
where TRes : class
|
||||
{
|
||||
_ectx.AssertValue(env, "env");
|
||||
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
|
||||
var args = ConcatArgsRev(extra, this);
|
||||
|
||||
EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog);
|
||||
|
||||
object tmp;
|
||||
string sig = ModelHeader.GetLoaderSig(ref Header);
|
||||
if (!string.IsNullOrWhiteSpace(sig) &&
|
||||
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sig, "", args))
|
||||
{
|
||||
result = tmp as TRes;
|
||||
if (result != null)
|
||||
{
|
||||
Done();
|
||||
return true;
|
||||
}
|
||||
// REVIEW: Should this fall through?
|
||||
}
|
||||
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
|
||||
string sigAlt = ModelHeader.GetLoaderSigAlt(ref Header);
|
||||
if (!string.IsNullOrWhiteSpace(sigAlt) &&
|
||||
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sigAlt, "", args))
|
||||
{
|
||||
result = tmp as TRes;
|
||||
if (result != null)
|
||||
{
|
||||
Done();
|
||||
return true;
|
||||
}
|
||||
// REVIEW: Should this fall through?
|
||||
}
|
||||
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
|
||||
|
||||
Reader.BaseStream.Position = FpMin;
|
||||
result = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(LoaderAssemblyName))
|
||||
{
|
||||
var assembly = Assembly.Load(ForwardedLoaderAssemblyName);
|
||||
catalog.RegisterAssembly(assembly);
|
||||
}
|
||||
}
|
||||
|
||||
private static object[] ConcatArgsRev(object[] args2, params object[] args1)
|
||||
{
|
||||
Contracts.AssertNonEmpty(args1);
|
||||
return Utils.Concat(args1, args2);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load a sub model from the given sub directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public bool TryProcessSubModel(string dir, Action<ModelLoadContext> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
|
||||
_ectx.CheckNonEmpty(dir, nameof(dir));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
string path = Path.Combine(Directory, dir);
|
||||
var ent = Repository.OpenEntryOrNull(path, ModelStreamName);
|
||||
if (ent == null)
|
||||
return false;
|
||||
|
||||
using (ent)
|
||||
{
|
||||
// Provide the repository, entry, and directory name to the loadable class ctor.
|
||||
_ectx.Assert(ent.Stream.Position == 0);
|
||||
using (var ctx = new ModelLoadContext(Repository, ent, path))
|
||||
action(ctx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load a binary stream from the current directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public bool TryLoadBinaryStream(string name, Action<BinaryReader> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
var ent = Repository.OpenEntryOrNull(Directory, name);
|
||||
if (ent == null)
|
||||
return false;
|
||||
|
||||
using (ent)
|
||||
using (var reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true))
|
||||
{
|
||||
action(reader);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Try to load a text stream from the current directory. This requires InRepository to be true.
|
||||
/// </summary>
|
||||
public bool TryLoadTextStream(string name, Action<TextReader> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
var ent = Repository.OpenEntryOrNull(Directory, name);
|
||||
if (ent == null)
|
||||
return false;
|
||||
|
||||
using (ent)
|
||||
using (var reader = new StreamReader(ent.Stream))
|
||||
{
|
||||
action(reader);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,253 +8,252 @@ using System.Text;
|
|||
using Microsoft.ML.Internal.Utilities;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML
|
||||
namespace Microsoft.ML;
|
||||
|
||||
/// <summary>
|
||||
/// Convenience context object for saving models to a repository, for
|
||||
/// implementors of <see cref="ICanSaveModel"/>.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This class reduces the amount of boiler plate code needed to implement <see cref="ICanSaveModel"/>.
|
||||
/// It can also be used when saving to a single stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
|
||||
/// </remarks>
|
||||
public sealed partial class ModelSaveContext : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// Convenience context object for saving models to a repository, for
|
||||
/// implementors of <see cref="ICanSaveModel"/>.
|
||||
/// When in repository mode, this is the repository we're writing to. It is null when
|
||||
/// in single-stream mode.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This class reduces the amount of boiler plate code needed to implement <see cref="ICanSaveModel"/>.
|
||||
/// It can also be used when saving to a single stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
|
||||
/// </remarks>
|
||||
public sealed partial class ModelSaveContext : IDisposable
|
||||
[BestFriend]
|
||||
internal readonly RepositoryWriter Repository;
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this is the directory we're reading from. Null means the root
|
||||
/// of the repository. It is always null in single-stream mode.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly string Directory;
|
||||
|
||||
/// <summary>
|
||||
/// The main stream writer.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly BinaryWriter Writer;
|
||||
|
||||
/// <summary>
|
||||
/// The strings that will be saved in the main stream's string table.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly NormStr.Pool Strings;
|
||||
|
||||
/// <summary>
|
||||
/// The main stream's model header.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal ModelHeader Header;
|
||||
|
||||
/// <summary>
|
||||
/// The min file position of the main stream.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly long FpMin;
|
||||
|
||||
/// <summary>
|
||||
/// The wrapped entry.
|
||||
/// </summary>
|
||||
private readonly Repository.Entry _ent;
|
||||
|
||||
/// <summary>
|
||||
/// Exception context provided by Repository (can be null).
|
||||
/// </summary>
|
||||
private readonly IExceptionContext _ectx;
|
||||
|
||||
/// <summary>
|
||||
/// The assembly name where the loader resides.
|
||||
/// </summary>
|
||||
private string _loaderAssemblyName;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal bool InRepository => Repository != null;
|
||||
|
||||
/// <summary>
|
||||
/// Create a <see cref="ModelSaveContext"/> supporting saving to a repository, for implementors of <see cref="ICanSaveModel"/>.
|
||||
/// </summary>
|
||||
internal ModelSaveContext(RepositoryWriter rep, string dir, string name)
|
||||
{
|
||||
/// <summary>
|
||||
/// When in repository mode, this is the repository we're writing to. It is null when
|
||||
/// in single-stream mode.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly RepositoryWriter Repository;
|
||||
Contracts.CheckValue(rep, nameof(rep));
|
||||
Repository = rep;
|
||||
_ectx = rep.ExceptionContext;
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this is the directory we're reading from. Null means the root
|
||||
/// of the repository. It is always null in single-stream mode.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly string Directory;
|
||||
_ectx.CheckValueOrNull(dir);
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
|
||||
/// <summary>
|
||||
/// The main stream writer.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly BinaryWriter Writer;
|
||||
Directory = dir;
|
||||
Strings = new NormStr.Pool();
|
||||
|
||||
/// <summary>
|
||||
/// The strings that will be saved in the main stream's string table.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly NormStr.Pool Strings;
|
||||
|
||||
/// <summary>
|
||||
/// The main stream's model header.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal ModelHeader Header;
|
||||
|
||||
/// <summary>
|
||||
/// The min file position of the main stream.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal readonly long FpMin;
|
||||
|
||||
/// <summary>
|
||||
/// The wrapped entry.
|
||||
/// </summary>
|
||||
private readonly Repository.Entry _ent;
|
||||
|
||||
/// <summary>
|
||||
/// Exception context provided by Repository (can be null).
|
||||
/// </summary>
|
||||
private readonly IExceptionContext _ectx;
|
||||
|
||||
/// <summary>
|
||||
/// The assembly name where the loader resides.
|
||||
/// </summary>
|
||||
private string _loaderAssemblyName;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal bool InRepository => Repository != null;
|
||||
|
||||
/// <summary>
|
||||
/// Create a <see cref="ModelSaveContext"/> supporting saving to a repository, for implementors of <see cref="ICanSaveModel"/>.
|
||||
/// </summary>
|
||||
internal ModelSaveContext(RepositoryWriter rep, string dir, string name)
|
||||
_ent = rep.CreateEntry(dir, name);
|
||||
try
|
||||
{
|
||||
Contracts.CheckValue(rep, nameof(rep));
|
||||
Repository = rep;
|
||||
_ectx = rep.ExceptionContext;
|
||||
|
||||
_ectx.CheckValueOrNull(dir);
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
|
||||
Directory = dir;
|
||||
Strings = new NormStr.Pool();
|
||||
|
||||
_ent = rep.CreateEntry(dir, name);
|
||||
Writer = new BinaryWriter(_ent.Stream, Encoding.UTF8, leaveOpen: true);
|
||||
try
|
||||
{
|
||||
Writer = new BinaryWriter(_ent.Stream, Encoding.UTF8, leaveOpen: true);
|
||||
try
|
||||
{
|
||||
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
|
||||
}
|
||||
catch
|
||||
{
|
||||
Writer.Dispose();
|
||||
throw;
|
||||
}
|
||||
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
|
||||
}
|
||||
catch
|
||||
{
|
||||
_ent.Dispose();
|
||||
Writer.Dispose();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a <see cref="ModelSaveContext"/> supporting saving to a single-stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
|
||||
/// </summary>
|
||||
internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null)
|
||||
catch
|
||||
{
|
||||
Contracts.AssertValueOrNull(ectx);
|
||||
_ectx = ectx;
|
||||
_ectx.CheckValue(writer, nameof(writer));
|
||||
|
||||
Repository = null;
|
||||
Directory = null;
|
||||
_ent = null;
|
||||
|
||||
Strings = new NormStr.Pool();
|
||||
Writer = writer;
|
||||
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
|
||||
_ent.Dispose();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void CheckAtModel()
|
||||
/// <summary>
|
||||
/// Create a <see cref="ModelSaveContext"/> supporting saving to a single-stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
|
||||
/// </summary>
|
||||
internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null)
|
||||
{
|
||||
Contracts.AssertValueOrNull(ectx);
|
||||
_ectx = ectx;
|
||||
_ectx.CheckValue(writer, nameof(writer));
|
||||
|
||||
Repository = null;
|
||||
Directory = null;
|
||||
_ent = null;
|
||||
|
||||
Strings = new NormStr.Pool();
|
||||
Writer = writer;
|
||||
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void CheckAtModel()
|
||||
{
|
||||
_ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Set the version information in the main stream's header. This should be called before
|
||||
/// <see cref="Done"/> is called.
|
||||
/// </summary>
|
||||
/// <param name="ver"></param>
|
||||
[BestFriend]
|
||||
internal void SetVersionInfo(VersionInfo ver)
|
||||
{
|
||||
ModelHeader.SetVersionInfo(ref Header, ver);
|
||||
_loaderAssemblyName = ver.LoaderAssemblyName;
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void SaveTextStream(string name, Action<TextWriter> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
|
||||
// that to set leaveOpen to true, we have to specify the buffer size....
|
||||
using (var ent = Repository.CreateEntry(Directory, name))
|
||||
using (var writer = Utils.OpenWriter(ent.Stream))
|
||||
{
|
||||
_ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel);
|
||||
action(writer);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Set the version information in the main stream's header. This should be called before
|
||||
/// <see cref="Done"/> is called.
|
||||
/// </summary>
|
||||
/// <param name="ver"></param>
|
||||
[BestFriend]
|
||||
internal void SetVersionInfo(VersionInfo ver)
|
||||
[BestFriend]
|
||||
internal void SaveBinaryStream(string name, Action<BinaryWriter> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
|
||||
// that to set leaveOpen to true, we have to specify the buffer size....
|
||||
using (var ent = Repository.CreateEntry(Directory, name))
|
||||
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
|
||||
{
|
||||
ModelHeader.SetVersionInfo(ref Header, ver);
|
||||
_loaderAssemblyName = ver.LoaderAssemblyName;
|
||||
action(writer);
|
||||
}
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void SaveTextStream(string name, Action<TextWriter> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
|
||||
// that to set leaveOpen to true, we have to specify the buffer size....
|
||||
using (var ent = Repository.CreateEntry(Directory, name))
|
||||
using (var writer = Utils.OpenWriter(ent.Stream))
|
||||
{
|
||||
action(writer);
|
||||
}
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void SaveBinaryStream(string name, Action<BinaryWriter> action)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
|
||||
_ectx.CheckNonEmpty(name, nameof(name));
|
||||
_ectx.CheckValue(action, nameof(action));
|
||||
|
||||
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
|
||||
// that to set leaveOpen to true, we have to specify the buffer size....
|
||||
using (var ent = Repository.CreateEntry(Directory, name))
|
||||
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
|
||||
{
|
||||
action(writer);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Puts a string into the context pool, and writes the integer code of the string ID
|
||||
/// to the write stream. If str is null, this writes -1 and doesn't add it to the pool.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveStringOrNull(string str)
|
||||
{
|
||||
if (str == null)
|
||||
Writer.Write(-1);
|
||||
else
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Puts a string into the context pool, and writes the integer code of the string ID
|
||||
/// to the write stream. Checks that str is not null.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveString(string str)
|
||||
{
|
||||
_ectx.CheckValue(str, nameof(str));
|
||||
/// <summary>
|
||||
/// Puts a string into the context pool, and writes the integer code of the string ID
|
||||
/// to the write stream. If str is null, this writes -1 and doesn't add it to the pool.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveStringOrNull(string str)
|
||||
{
|
||||
if (str == null)
|
||||
Writer.Write(-1);
|
||||
else
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void SaveString(ReadOnlyMemory<char> str)
|
||||
/// <summary>
|
||||
/// Puts a string into the context pool, and writes the integer code of the string ID
|
||||
/// to the write stream. Checks that str is not null.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveString(string str)
|
||||
{
|
||||
_ectx.CheckValue(str, nameof(str));
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void SaveString(ReadOnlyMemory<char> str)
|
||||
{
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Puts a string into the context pool, and writes the integer code of the string ID
|
||||
/// to the write stream.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveNonEmptyString(string str)
|
||||
{
|
||||
_ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str));
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void SaveNonEmptyString(ReadOnlyMemory<char> str)
|
||||
{
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Commit the save operation. This completes writing of the main stream. When in repository
|
||||
/// mode, it disposes <see cref="Writer"/> (but not <see cref="Repository"/>).
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void Done()
|
||||
{
|
||||
_ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
|
||||
ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName);
|
||||
Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this disposes the Writer (but not the repository).
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
_ectx.Assert((_ent == null) == !InRepository);
|
||||
|
||||
// When in single stream mode, we don't own the Writer.
|
||||
if (InRepository)
|
||||
{
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Puts a string into the context pool, and writes the integer code of the string ID
|
||||
/// to the write stream.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveNonEmptyString(string str)
|
||||
{
|
||||
_ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str));
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal void SaveNonEmptyString(ReadOnlyMemory<char> str)
|
||||
{
|
||||
Writer.Write(Strings.Add(str).Id);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Commit the save operation. This completes writing of the main stream. When in repository
|
||||
/// mode, it disposes <see cref="Writer"/> (but not <see cref="Repository"/>).
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void Done()
|
||||
{
|
||||
_ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
|
||||
ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName);
|
||||
Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// When in repository mode, this disposes the Writer (but not the repository).
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
_ectx.Assert((_ent == null) == !InRepository);
|
||||
|
||||
// When in single stream mode, we don't own the Writer.
|
||||
if (InRepository)
|
||||
{
|
||||
Writer.Dispose();
|
||||
_ent.Dispose();
|
||||
}
|
||||
Writer.Dispose();
|
||||
_ent.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,85 +7,84 @@ using System.IO;
|
|||
using System.Text;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML
|
||||
namespace Microsoft.ML;
|
||||
|
||||
public sealed partial class ModelSaveContext : IDisposable
|
||||
{
|
||||
public sealed partial class ModelSaveContext : IDisposable
|
||||
/// <summary>
|
||||
/// Save a sub model to the given sub directory. This requires <see cref="InRepository"/> to be <see langword="true"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveModel<T>(T value, string name)
|
||||
where T : class
|
||||
{
|
||||
/// <summary>
|
||||
/// Save a sub model to the given sub directory. This requires <see cref="InRepository"/> to be <see langword="true"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveModel<T>(T value, string name)
|
||||
where T : class
|
||||
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
|
||||
SaveModel(Repository, value, Path.Combine(Directory ?? "", name));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save the object by calling TrySaveModel then falling back to .net serialization.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static void SaveModel<T>(RepositoryWriter rep, T value, string path)
|
||||
where T : class
|
||||
{
|
||||
if (value == null)
|
||||
return;
|
||||
|
||||
var sm = value as ICanSaveModel;
|
||||
if (sm != null)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
|
||||
SaveModel(Repository, value, Path.Combine(Directory ?? "", name));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save the object by calling TrySaveModel then falling back to .net serialization.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static void SaveModel<T>(RepositoryWriter rep, T value, string path)
|
||||
where T : class
|
||||
{
|
||||
if (value == null)
|
||||
return;
|
||||
|
||||
var sm = value as ICanSaveModel;
|
||||
if (sm != null)
|
||||
using (var ctx = new ModelSaveContext(rep, path, ModelLoadContext.ModelStreamName))
|
||||
{
|
||||
using (var ctx = new ModelSaveContext(rep, path, ModelLoadContext.ModelStreamName))
|
||||
{
|
||||
sm.Save(ctx);
|
||||
ctx.Done();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
var sb = value as ICanSaveInBinaryFormat;
|
||||
if (sb != null)
|
||||
{
|
||||
using (var ent = rep.CreateEntry(path, ModelLoadContext.NameBinary))
|
||||
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
|
||||
{
|
||||
sb.SaveAsBinary(writer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save to a single-stream by invoking the given action.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
|
||||
{
|
||||
Contracts.CheckValue(writer, nameof(writer));
|
||||
Contracts.CheckValue(fn, nameof(fn));
|
||||
|
||||
using (var ctx = new ModelSaveContext(writer))
|
||||
{
|
||||
fn(ctx);
|
||||
sm.Save(ctx);
|
||||
ctx.Done();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save to the given sub directory by invoking the given action. This requires
|
||||
/// <see cref="InRepository"/> to be <see langword="true"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveSubModel(string dir, Action<ModelSaveContext> fn)
|
||||
var sb = value as ICanSaveInBinaryFormat;
|
||||
if (sb != null)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
|
||||
_ectx.CheckNonEmpty(dir, nameof(dir));
|
||||
_ectx.CheckValue(fn, nameof(fn));
|
||||
|
||||
using (var ctx = new ModelSaveContext(Repository, Path.Combine(Directory ?? "", dir), ModelLoadContext.ModelStreamName))
|
||||
using (var ent = rep.CreateEntry(path, ModelLoadContext.NameBinary))
|
||||
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
|
||||
{
|
||||
fn(ctx);
|
||||
ctx.Done();
|
||||
sb.SaveAsBinary(writer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save to a single-stream by invoking the given action.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
|
||||
{
|
||||
Contracts.CheckValue(writer, nameof(writer));
|
||||
Contracts.CheckValue(fn, nameof(fn));
|
||||
|
||||
using (var ctx = new ModelSaveContext(writer))
|
||||
{
|
||||
fn(ctx);
|
||||
ctx.Done();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save to the given sub directory by invoking the given action. This requires
|
||||
/// <see cref="InRepository"/> to be <see langword="true"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal void SaveSubModel(string dir, Action<ModelSaveContext> fn)
|
||||
{
|
||||
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
|
||||
_ectx.CheckNonEmpty(dir, nameof(dir));
|
||||
_ectx.CheckValue(fn, nameof(fn));
|
||||
|
||||
using (var ctx = new ModelSaveContext(Repository, Path.Combine(Directory ?? "", dir), ModelLoadContext.ModelStreamName))
|
||||
{
|
||||
fn(ctx);
|
||||
ctx.Done();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -9,264 +9,263 @@ using System.Text;
|
|||
using Microsoft.ML.Internal.Utilities;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
[BestFriend]
|
||||
internal static class ReadOnlyMemoryUtils
|
||||
{
|
||||
[BestFriend]
|
||||
internal static class ReadOnlyMemoryUtils
|
||||
|
||||
/// <summary>
|
||||
/// Compare equality with the given system string value.
|
||||
/// </summary>
|
||||
public static bool EqualsStr(string s, ReadOnlyMemory<char> memory)
|
||||
{
|
||||
Contracts.CheckValueOrNull(s);
|
||||
|
||||
/// <summary>
|
||||
/// Compare equality with the given system string value.
|
||||
/// </summary>
|
||||
public static bool EqualsStr(string s, ReadOnlyMemory<char> memory)
|
||||
if (s == null)
|
||||
return memory.Length == 0;
|
||||
|
||||
if (s.Length != memory.Length)
|
||||
return false;
|
||||
|
||||
return memory.Span.SequenceEqual(s.AsSpan());
|
||||
}
|
||||
|
||||
public static IEnumerable<ReadOnlyMemory<char>> Split(ReadOnlyMemory<char> memory, char[] separators)
|
||||
{
|
||||
Contracts.CheckValueOrNull(separators);
|
||||
|
||||
if (memory.IsEmpty)
|
||||
yield break;
|
||||
|
||||
if (separators == null || separators.Length == 0)
|
||||
{
|
||||
Contracts.CheckValueOrNull(s);
|
||||
|
||||
if (s == null)
|
||||
return memory.Length == 0;
|
||||
|
||||
if (s.Length != memory.Length)
|
||||
return false;
|
||||
|
||||
return memory.Span.SequenceEqual(s.AsSpan());
|
||||
yield return memory;
|
||||
yield break;
|
||||
}
|
||||
|
||||
public static IEnumerable<ReadOnlyMemory<char>> Split(ReadOnlyMemory<char> memory, char[] separators)
|
||||
var span = memory.Span;
|
||||
if (separators.Length == 1)
|
||||
{
|
||||
Contracts.CheckValueOrNull(separators);
|
||||
|
||||
if (memory.IsEmpty)
|
||||
yield break;
|
||||
|
||||
if (separators == null || separators.Length == 0)
|
||||
char chSep = separators[0];
|
||||
for (int ichCur = 0; ;)
|
||||
{
|
||||
yield return memory;
|
||||
yield break;
|
||||
}
|
||||
|
||||
var span = memory.Span;
|
||||
if (separators.Length == 1)
|
||||
{
|
||||
char chSep = separators[0];
|
||||
for (int ichCur = 0; ;)
|
||||
int nextSep = span.IndexOf(chSep);
|
||||
if (nextSep == -1)
|
||||
{
|
||||
int nextSep = span.IndexOf(chSep);
|
||||
if (nextSep == -1)
|
||||
{
|
||||
yield return memory.Slice(ichCur);
|
||||
yield break;
|
||||
}
|
||||
|
||||
yield return memory.Slice(ichCur, nextSep);
|
||||
|
||||
// Skip the separator.
|
||||
ichCur += nextSep + 1;
|
||||
span = memory.Slice(ichCur).Span;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int ichCur = 0; ;)
|
||||
{
|
||||
int nextSep = span.IndexOfAny(separators);
|
||||
if (nextSep == -1)
|
||||
{
|
||||
yield return memory.Slice(ichCur);
|
||||
yield break;
|
||||
}
|
||||
|
||||
yield return memory.Slice(ichCur, nextSep);
|
||||
|
||||
// Skip the separator.
|
||||
ichCur += nextSep + 1;
|
||||
span = memory.Slice(ichCur).Span;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Splits <paramref name="memory"/> on the left-most occurrence of separator and produces the left
|
||||
/// and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain the separator character,
|
||||
/// this returns false and sets <paramref name="left"/> to this instance and <paramref name="right"/>
|
||||
/// to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
|
||||
/// </summary>
|
||||
public static bool SplitOne(ReadOnlyMemory<char> memory, char separator, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
int index = memory.Span.IndexOf(separator);
|
||||
if (index == -1)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
left = memory.Slice(0, index);
|
||||
right = memory.Slice(index + 1, memory.Length - index - 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Splits <paramref name="memory"/> on the left-most occurrence of an element of separators character array and
|
||||
/// produces the left and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain any of the
|
||||
/// characters in separators, this return false and initializes <paramref name="left"/> to this instance
|
||||
/// and <paramref name="right"/> to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
|
||||
/// </summary>
|
||||
public static bool SplitOne(ReadOnlyMemory<char> memory, char[] separators, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
|
||||
{
|
||||
Contracts.CheckValueOrNull(separators);
|
||||
|
||||
if (memory.IsEmpty || separators == null || separators.Length == 0)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
int index;
|
||||
if (separators.Length == 1)
|
||||
index = memory.Span.IndexOf(separators[0]);
|
||||
else
|
||||
index = memory.Span.IndexOfAny(separators);
|
||||
|
||||
if (index == -1)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
left = memory.Slice(0, index);
|
||||
right = memory.Slice(index + 1, memory.Length - index - 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing spaces trimmed. Note that this
|
||||
/// will remove only spaces, not any form of whitespace.
|
||||
/// </summary>
|
||||
public static ReadOnlyMemory<char> TrimSpaces(ReadOnlyMemory<char> memory)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
return memory;
|
||||
|
||||
int ichLim = memory.Length;
|
||||
int ichMin = 0;
|
||||
var span = memory.Span;
|
||||
if (span[ichMin] != ' ' && span[ichLim - 1] != ' ')
|
||||
return memory;
|
||||
|
||||
while (ichMin < ichLim && span[ichMin] == ' ')
|
||||
ichMin++;
|
||||
while (ichMin < ichLim && span[ichLim - 1] == ' ')
|
||||
ichLim--;
|
||||
return memory.Slice(ichMin, ichLim - ichMin);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing whitespace trimmed.
|
||||
/// </summary>
|
||||
public static ReadOnlyMemory<char> TrimWhiteSpace(ReadOnlyMemory<char> memory)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
return memory;
|
||||
|
||||
int ichMin = 0;
|
||||
int ichLim = memory.Length;
|
||||
var span = memory.Span;
|
||||
if (!char.IsWhiteSpace(span[ichMin]) && !char.IsWhiteSpace(span[ichLim - 1]))
|
||||
return memory;
|
||||
|
||||
while (ichMin < ichLim && char.IsWhiteSpace(span[ichMin]))
|
||||
ichMin++;
|
||||
while (ichMin < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
|
||||
ichLim--;
|
||||
|
||||
return memory.Slice(ichMin, ichLim - ichMin);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with trailing whitespace trimmed.
|
||||
/// </summary>
|
||||
public static ReadOnlyMemory<char> TrimEndWhiteSpace(ReadOnlyMemory<char> memory)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
return memory;
|
||||
|
||||
int ichLim = memory.Length;
|
||||
var span = memory.Span;
|
||||
if (!char.IsWhiteSpace(span[ichLim - 1]))
|
||||
return memory;
|
||||
|
||||
while (0 < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
|
||||
ichLim--;
|
||||
|
||||
return memory.Slice(0, ichLim);
|
||||
}
|
||||
|
||||
public static void AddLowerCaseToStringBuilder(ReadOnlySpan<char> span, StringBuilder sb)
|
||||
{
|
||||
Contracts.CheckValue(sb, nameof(sb));
|
||||
|
||||
if (!span.IsEmpty)
|
||||
{
|
||||
int min = 0;
|
||||
int j;
|
||||
for (j = min; j < span.Length; j++)
|
||||
{
|
||||
char ch = CharUtils.ToLowerInvariant(span[j]);
|
||||
if (ch != span[j])
|
||||
{
|
||||
sb.AppendSpan(span.Slice(min, j - min)).Append(ch);
|
||||
min = j + 1;
|
||||
}
|
||||
yield return memory.Slice(ichCur);
|
||||
yield break;
|
||||
}
|
||||
|
||||
Contracts.Assert(j == span.Length);
|
||||
if (min != j)
|
||||
sb.AppendSpan(span.Slice(min, j - min));
|
||||
yield return memory.Slice(ichCur, nextSep);
|
||||
|
||||
// Skip the separator.
|
||||
ichCur += nextSep + 1;
|
||||
span = memory.Slice(ichCur).Span;
|
||||
}
|
||||
}
|
||||
|
||||
public static StringBuilder AppendMemory(this StringBuilder sb, ReadOnlyMemory<char> memory)
|
||||
else
|
||||
{
|
||||
Contracts.CheckValue(sb, nameof(sb));
|
||||
if (!memory.IsEmpty)
|
||||
sb.AppendSpan(memory.Span);
|
||||
|
||||
return sb;
|
||||
}
|
||||
|
||||
public static StringBuilder AppendSpan(this StringBuilder sb, ReadOnlySpan<char> span)
|
||||
{
|
||||
unsafe
|
||||
for (int ichCur = 0; ;)
|
||||
{
|
||||
fixed (char* valueChars = &MemoryMarshal.GetReference(span))
|
||||
int nextSep = span.IndexOfAny(separators);
|
||||
if (nextSep == -1)
|
||||
{
|
||||
sb.Append(valueChars, span.Length);
|
||||
yield return memory.Slice(ichCur);
|
||||
yield break;
|
||||
}
|
||||
}
|
||||
|
||||
return sb;
|
||||
}
|
||||
yield return memory.Slice(ichCur, nextSep);
|
||||
|
||||
public sealed class ReadonlyMemoryCharComparer : IEqualityComparer<ReadOnlyMemory<char>>
|
||||
{
|
||||
public bool Equals(ReadOnlyMemory<char> x, ReadOnlyMemory<char> y)
|
||||
{
|
||||
return x.Span.SequenceEqual(y.Span);
|
||||
}
|
||||
|
||||
public int GetHashCode(ReadOnlyMemory<char> obj)
|
||||
{
|
||||
return (int)Hashing.HashString(obj.Span);
|
||||
// Skip the separator.
|
||||
ichCur += nextSep + 1;
|
||||
span = memory.Slice(ichCur).Span;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Splits <paramref name="memory"/> on the left-most occurrence of separator and produces the left
|
||||
/// and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain the separator character,
|
||||
/// this returns false and sets <paramref name="left"/> to this instance and <paramref name="right"/>
|
||||
/// to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
|
||||
/// </summary>
|
||||
public static bool SplitOne(ReadOnlyMemory<char> memory, char separator, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
int index = memory.Span.IndexOf(separator);
|
||||
if (index == -1)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
left = memory.Slice(0, index);
|
||||
right = memory.Slice(index + 1, memory.Length - index - 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Splits <paramref name="memory"/> on the left-most occurrence of an element of separators character array and
|
||||
/// produces the left and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain any of the
|
||||
/// characters in separators, this return false and initializes <paramref name="left"/> to this instance
|
||||
/// and <paramref name="right"/> to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
|
||||
/// </summary>
|
||||
public static bool SplitOne(ReadOnlyMemory<char> memory, char[] separators, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
|
||||
{
|
||||
Contracts.CheckValueOrNull(separators);
|
||||
|
||||
if (memory.IsEmpty || separators == null || separators.Length == 0)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
int index;
|
||||
if (separators.Length == 1)
|
||||
index = memory.Span.IndexOf(separators[0]);
|
||||
else
|
||||
index = memory.Span.IndexOfAny(separators);
|
||||
|
||||
if (index == -1)
|
||||
{
|
||||
left = memory;
|
||||
right = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
left = memory.Slice(0, index);
|
||||
right = memory.Slice(index + 1, memory.Length - index - 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing spaces trimmed. Note that this
|
||||
/// will remove only spaces, not any form of whitespace.
|
||||
/// </summary>
|
||||
public static ReadOnlyMemory<char> TrimSpaces(ReadOnlyMemory<char> memory)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
return memory;
|
||||
|
||||
int ichLim = memory.Length;
|
||||
int ichMin = 0;
|
||||
var span = memory.Span;
|
||||
if (span[ichMin] != ' ' && span[ichLim - 1] != ' ')
|
||||
return memory;
|
||||
|
||||
while (ichMin < ichLim && span[ichMin] == ' ')
|
||||
ichMin++;
|
||||
while (ichMin < ichLim && span[ichLim - 1] == ' ')
|
||||
ichLim--;
|
||||
return memory.Slice(ichMin, ichLim - ichMin);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing whitespace trimmed.
|
||||
/// </summary>
|
||||
public static ReadOnlyMemory<char> TrimWhiteSpace(ReadOnlyMemory<char> memory)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
return memory;
|
||||
|
||||
int ichMin = 0;
|
||||
int ichLim = memory.Length;
|
||||
var span = memory.Span;
|
||||
if (!char.IsWhiteSpace(span[ichMin]) && !char.IsWhiteSpace(span[ichLim - 1]))
|
||||
return memory;
|
||||
|
||||
while (ichMin < ichLim && char.IsWhiteSpace(span[ichMin]))
|
||||
ichMin++;
|
||||
while (ichMin < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
|
||||
ichLim--;
|
||||
|
||||
return memory.Slice(ichMin, ichLim - ichMin);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with trailing whitespace trimmed.
|
||||
/// </summary>
|
||||
public static ReadOnlyMemory<char> TrimEndWhiteSpace(ReadOnlyMemory<char> memory)
|
||||
{
|
||||
if (memory.IsEmpty)
|
||||
return memory;
|
||||
|
||||
int ichLim = memory.Length;
|
||||
var span = memory.Span;
|
||||
if (!char.IsWhiteSpace(span[ichLim - 1]))
|
||||
return memory;
|
||||
|
||||
while (0 < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
|
||||
ichLim--;
|
||||
|
||||
return memory.Slice(0, ichLim);
|
||||
}
|
||||
|
||||
public static void AddLowerCaseToStringBuilder(ReadOnlySpan<char> span, StringBuilder sb)
|
||||
{
|
||||
Contracts.CheckValue(sb, nameof(sb));
|
||||
|
||||
if (!span.IsEmpty)
|
||||
{
|
||||
int min = 0;
|
||||
int j;
|
||||
for (j = min; j < span.Length; j++)
|
||||
{
|
||||
char ch = CharUtils.ToLowerInvariant(span[j]);
|
||||
if (ch != span[j])
|
||||
{
|
||||
sb.AppendSpan(span.Slice(min, j - min)).Append(ch);
|
||||
min = j + 1;
|
||||
}
|
||||
}
|
||||
|
||||
Contracts.Assert(j == span.Length);
|
||||
if (min != j)
|
||||
sb.AppendSpan(span.Slice(min, j - min));
|
||||
}
|
||||
}
|
||||
|
||||
public static StringBuilder AppendMemory(this StringBuilder sb, ReadOnlyMemory<char> memory)
|
||||
{
|
||||
Contracts.CheckValue(sb, nameof(sb));
|
||||
if (!memory.IsEmpty)
|
||||
sb.AppendSpan(memory.Span);
|
||||
|
||||
return sb;
|
||||
}
|
||||
|
||||
public static StringBuilder AppendSpan(this StringBuilder sb, ReadOnlySpan<char> span)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
fixed (char* valueChars = &MemoryMarshal.GetReference(span))
|
||||
{
|
||||
sb.Append(valueChars, span.Length);
|
||||
}
|
||||
}
|
||||
|
||||
return sb;
|
||||
}
|
||||
|
||||
public sealed class ReadonlyMemoryCharComparer : IEqualityComparer<ReadOnlyMemory<char>>
|
||||
{
|
||||
public bool Equals(ReadOnlyMemory<char> x, ReadOnlyMemory<char> y)
|
||||
{
|
||||
return x.Span.SequenceEqual(y.Span);
|
||||
}
|
||||
|
||||
public int GetHashCode(ReadOnlyMemory<char> obj)
|
||||
{
|
||||
return (int)Hashing.HashString(obj.Span);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -6,484 +6,483 @@ using System.Collections.Generic;
|
|||
using Microsoft.ML.Internal.Utilities;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Encapsulates a <see cref="DataViewSchema"/> plus column role mapping information. The purpose of role mappings is to
|
||||
/// provide information on what the intended usage is for. That is: while a given data view may have a column named
|
||||
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
|
||||
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
|
||||
/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
|
||||
/// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
|
||||
/// multiple features columns to consume that information.
|
||||
///
|
||||
/// This class has convenience fields for several common column roles (for example, <see cref="Feature"/>, <see
|
||||
/// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only
|
||||
/// if there is a unique column with the corresponding role. When there are no such columns or more than one such
|
||||
/// column, the field is <c>null</c>. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/>
|
||||
/// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden
|
||||
/// in this schema.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Note that instances of this class are, like instances of <see cref="DataViewSchema"/>, immutable.
|
||||
///
|
||||
/// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For
|
||||
/// that case, please use the <see cref="RoleMappedData"/> class.
|
||||
///
|
||||
/// Note that there is no need for components consuming a <see cref="RoleMappedData"/> or <see cref="RoleMappedSchema"/>
|
||||
/// to make use of every defined mapping. Consuming components are also expected to ignore any <see cref="ColumnRole"/>
|
||||
/// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s)
|
||||
/// mapped from the role are not of the form they require.
|
||||
/// </remarks>
|
||||
/// <seealso cref="ColumnRole"/>
|
||||
/// <seealso cref="RoleMappedData"/>
|
||||
[BestFriend]
|
||||
internal sealed class RoleMappedSchema
|
||||
{
|
||||
private const string FeatureString = "Feature";
|
||||
private const string LabelString = "Label";
|
||||
private const string GroupString = "Group";
|
||||
private const string WeightString = "Weight";
|
||||
private const string NameString = "Name";
|
||||
private const string FeatureContributionsString = "FeatureContributions";
|
||||
|
||||
/// <summary>
|
||||
/// Encapsulates a <see cref="DataViewSchema"/> plus column role mapping information. The purpose of role mappings is to
|
||||
/// provide information on what the intended usage is for. That is: while a given data view may have a column named
|
||||
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
|
||||
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
|
||||
/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
|
||||
/// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
|
||||
/// multiple features columns to consume that information.
|
||||
///
|
||||
/// This class has convenience fields for several common column roles (for example, <see cref="Feature"/>, <see
|
||||
/// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only
|
||||
/// if there is a unique column with the corresponding role. When there are no such columns or more than one such
|
||||
/// column, the field is <c>null</c>. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/>
|
||||
/// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden
|
||||
/// in this schema.
|
||||
/// Instances of this are the keys of a <see cref="RoleMappedSchema"/>. This class also holds some important
|
||||
/// commonly used pre-defined instances available (for example, <see cref="Label"/>, <see cref="Feature"/>) that should
|
||||
/// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom
|
||||
/// roles if approppriate for their task.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Note that instances of this class are, like instances of <see cref="DataViewSchema"/>, immutable.
|
||||
///
|
||||
/// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For
|
||||
/// that case, please use the <see cref="RoleMappedData"/> class.
|
||||
///
|
||||
/// Note that there is no need for components consuming a <see cref="RoleMappedData"/> or <see cref="RoleMappedSchema"/>
|
||||
/// to make use of every defined mapping. Consuming components are also expected to ignore any <see cref="ColumnRole"/>
|
||||
/// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s)
|
||||
/// mapped from the role are not of the form they require.
|
||||
/// </remarks>
|
||||
/// <seealso cref="ColumnRole"/>
|
||||
/// <seealso cref="RoleMappedData"/>
|
||||
[BestFriend]
|
||||
internal sealed class RoleMappedSchema
|
||||
public readonly struct ColumnRole
|
||||
{
|
||||
private const string FeatureString = "Feature";
|
||||
private const string LabelString = "Label";
|
||||
private const string GroupString = "Group";
|
||||
private const string WeightString = "Weight";
|
||||
private const string NameString = "Name";
|
||||
private const string FeatureContributionsString = "FeatureContributions";
|
||||
/// <summary>
|
||||
/// Role for features. Commonly used as the independent variables given to trainers, and scorers.
|
||||
/// </summary>
|
||||
public static ColumnRole Feature => FeatureString;
|
||||
|
||||
/// <summary>
|
||||
/// Instances of this are the keys of a <see cref="RoleMappedSchema"/>. This class also holds some important
|
||||
/// commonly used pre-defined instances available (for example, <see cref="Label"/>, <see cref="Feature"/>) that should
|
||||
/// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom
|
||||
/// roles if approppriate for their task.
|
||||
/// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators.
|
||||
/// </summary>
|
||||
public readonly struct ColumnRole
|
||||
public static ColumnRole Label => LabelString;
|
||||
|
||||
/// <summary>
|
||||
/// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or
|
||||
/// sequence classification, for defining the boundaries of an utterance.
|
||||
/// </summary>
|
||||
public static ColumnRole Group => GroupString;
|
||||
|
||||
/// <summary>
|
||||
/// Role for sample weights. Commonly used to point to a number to make trainers give more weight
|
||||
/// to a particular example.
|
||||
/// </summary>
|
||||
public static ColumnRole Weight => WeightString;
|
||||
|
||||
/// <summary>
|
||||
/// Role for sample names. Useful for informational and tracking purposes when scoring, but typically
|
||||
/// without affecting results.
|
||||
/// </summary>
|
||||
public static ColumnRole Name => NameString;
|
||||
|
||||
// REVIEW: Does this really belong here?
|
||||
/// <summary>
|
||||
/// Role for feature contributions. Useful for specific diagnostic functionality.
|
||||
/// </summary>
|
||||
public static ColumnRole FeatureContributions => FeatureContributionsString;
|
||||
|
||||
/// <summary>
|
||||
/// The string value for the role. Guaranteed to be non-empty.
|
||||
/// </summary>
|
||||
public readonly string Value;
|
||||
|
||||
/// <summary>
|
||||
/// Constructor for the column role.
|
||||
/// </summary>
|
||||
/// <param name="value">The value for the role. Must be non-empty.</param>
|
||||
public ColumnRole(string value)
|
||||
{
|
||||
/// <summary>
|
||||
/// Role for features. Commonly used as the independent variables given to trainers, and scorers.
|
||||
/// </summary>
|
||||
public static ColumnRole Feature => FeatureString;
|
||||
|
||||
/// <summary>
|
||||
/// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators.
|
||||
/// </summary>
|
||||
public static ColumnRole Label => LabelString;
|
||||
|
||||
/// <summary>
|
||||
/// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or
|
||||
/// sequence classification, for defining the boundaries of an utterance.
|
||||
/// </summary>
|
||||
public static ColumnRole Group => GroupString;
|
||||
|
||||
/// <summary>
|
||||
/// Role for sample weights. Commonly used to point to a number to make trainers give more weight
|
||||
/// to a particular example.
|
||||
/// </summary>
|
||||
public static ColumnRole Weight => WeightString;
|
||||
|
||||
/// <summary>
|
||||
/// Role for sample names. Useful for informational and tracking purposes when scoring, but typically
|
||||
/// without affecting results.
|
||||
/// </summary>
|
||||
public static ColumnRole Name => NameString;
|
||||
|
||||
// REVIEW: Does this really belong here?
|
||||
/// <summary>
|
||||
/// Role for feature contributions. Useful for specific diagnostic functionality.
|
||||
/// </summary>
|
||||
public static ColumnRole FeatureContributions => FeatureContributionsString;
|
||||
|
||||
/// <summary>
|
||||
/// The string value for the role. Guaranteed to be non-empty.
|
||||
/// </summary>
|
||||
public readonly string Value;
|
||||
|
||||
/// <summary>
|
||||
/// Constructor for the column role.
|
||||
/// </summary>
|
||||
/// <param name="value">The value for the role. Must be non-empty.</param>
|
||||
public ColumnRole(string value)
|
||||
{
|
||||
Contracts.CheckNonEmpty(value, nameof(value));
|
||||
Value = value;
|
||||
}
|
||||
|
||||
public static implicit operator ColumnRole(string value)
|
||||
=> new ColumnRole(value);
|
||||
|
||||
/// <summary>
|
||||
/// Convenience method for creating a mapping pair from a role to a column name
|
||||
/// for giving to constructors of <see cref="RoleMappedSchema"/> and <see cref="RoleMappedData"/>.
|
||||
/// </summary>
|
||||
/// <param name="name">The column name to map to. Can be <c>null</c>, in which case when used
|
||||
/// to construct a role mapping structure this pair will be ignored</param>
|
||||
/// <returns>A key-value pair with this instance as the key and <paramref name="name"/> as the value</returns>
|
||||
public KeyValuePair<ColumnRole, string> Bind(string name)
|
||||
=> new KeyValuePair<ColumnRole, string>(this, name);
|
||||
Contracts.CheckNonEmpty(value, nameof(value));
|
||||
Value = value;
|
||||
}
|
||||
|
||||
public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, string name)
|
||||
=> new KeyValuePair<ColumnRole, string>(role, name);
|
||||
public static implicit operator ColumnRole(string value)
|
||||
=> new ColumnRole(value);
|
||||
|
||||
/// <summary>
|
||||
/// The source <see cref="Schema"/>.
|
||||
/// Convenience method for creating a mapping pair from a role to a column name
|
||||
/// for giving to constructors of <see cref="RoleMappedSchema"/> and <see cref="RoleMappedData"/>.
|
||||
/// </summary>
|
||||
public DataViewSchema Schema { get; }
|
||||
/// <param name="name">The column name to map to. Can be <c>null</c>, in which case when used
|
||||
/// to construct a role mapping structure this pair will be ignored</param>
|
||||
/// <returns>A key-value pair with this instance as the key and <paramref name="name"/> as the value</returns>
|
||||
public KeyValuePair<ColumnRole, string> Bind(string name)
|
||||
=> new KeyValuePair<ColumnRole, string>(this, name);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Feature"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Feature { get; }
|
||||
public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, string name)
|
||||
=> new KeyValuePair<ColumnRole, string>(role, name);
|
||||
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Label"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Label { get; }
|
||||
/// <summary>
|
||||
/// The source <see cref="Schema"/>.
|
||||
/// </summary>
|
||||
public DataViewSchema Schema { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Group"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Group { get; }
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Feature"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Feature { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Weight"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Weight { get; }
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Label"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Label { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Name"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Name { get; }
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Group"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Group { get; }
|
||||
|
||||
// Maps from role to the associated column infos.
|
||||
private readonly Dictionary<string, IReadOnlyList<DataViewSchema.Column>> _map;
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Weight"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Weight { get; }
|
||||
|
||||
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, IReadOnlyList<DataViewSchema.Column>> map)
|
||||
/// <summary>
|
||||
/// The <see cref="ColumnRole.Name"/> column, when there is exactly one (null otherwise).
|
||||
/// </summary>
|
||||
public DataViewSchema.Column? Name { get; }
|
||||
|
||||
// Maps from role to the associated column infos.
|
||||
private readonly Dictionary<string, IReadOnlyList<DataViewSchema.Column>> _map;
|
||||
|
||||
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, IReadOnlyList<DataViewSchema.Column>> map)
|
||||
{
|
||||
Contracts.AssertValue(schema);
|
||||
Contracts.AssertValue(map);
|
||||
|
||||
Schema = schema;
|
||||
_map = map;
|
||||
foreach (var kvp in _map)
|
||||
{
|
||||
Contracts.AssertValue(schema);
|
||||
Contracts.AssertValue(map);
|
||||
|
||||
Schema = schema;
|
||||
_map = map;
|
||||
foreach (var kvp in _map)
|
||||
{
|
||||
Contracts.Assert(Utils.Size(kvp.Value) > 0);
|
||||
var cols = kvp.Value;
|
||||
Contracts.Assert(Utils.Size(kvp.Value) > 0);
|
||||
var cols = kvp.Value;
|
||||
#if DEBUG
|
||||
foreach (var info in cols)
|
||||
Contracts.Assert(!schema[info.Index].IsHidden, "How did a hidden column sneak in?");
|
||||
foreach (var info in cols)
|
||||
Contracts.Assert(!schema[info.Index].IsHidden, "How did a hidden column sneak in?");
|
||||
#endif
|
||||
if (cols.Count == 1)
|
||||
if (cols.Count == 1)
|
||||
{
|
||||
switch (kvp.Key)
|
||||
{
|
||||
switch (kvp.Key)
|
||||
{
|
||||
case FeatureString:
|
||||
Feature = cols[0];
|
||||
break;
|
||||
case LabelString:
|
||||
Label = cols[0];
|
||||
break;
|
||||
case GroupString:
|
||||
Group = cols[0];
|
||||
break;
|
||||
case WeightString:
|
||||
Weight = cols[0];
|
||||
break;
|
||||
case NameString:
|
||||
Name = cols[0];
|
||||
break;
|
||||
}
|
||||
case FeatureString:
|
||||
Feature = cols[0];
|
||||
break;
|
||||
case LabelString:
|
||||
Label = cols[0];
|
||||
break;
|
||||
case GroupString:
|
||||
Group = cols[0];
|
||||
break;
|
||||
case WeightString:
|
||||
Weight = cols[0];
|
||||
break;
|
||||
case NameString:
|
||||
Name = cols[0];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, List<DataViewSchema.Column>> map)
|
||||
: this(schema, Copy(map))
|
||||
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, List<DataViewSchema.Column>> map)
|
||||
: this(schema, Copy(map))
|
||||
{
|
||||
}
|
||||
|
||||
private static void Add(Dictionary<string, List<DataViewSchema.Column>> map, ColumnRole role, DataViewSchema.Column column)
|
||||
{
|
||||
Contracts.AssertValue(map);
|
||||
Contracts.AssertNonEmpty(role.Value);
|
||||
|
||||
if (!map.TryGetValue(role.Value, out var list))
|
||||
{
|
||||
list = new List<DataViewSchema.Column>();
|
||||
map.Add(role.Value, list);
|
||||
}
|
||||
list.Add(column);
|
||||
}
|
||||
|
||||
private static void Add(Dictionary<string, List<DataViewSchema.Column>> map, ColumnRole role, DataViewSchema.Column column)
|
||||
private static Dictionary<string, List<DataViewSchema.Column>> MapFromNames(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
|
||||
{
|
||||
Contracts.AssertValue(schema);
|
||||
Contracts.AssertValue(roles);
|
||||
|
||||
var map = new Dictionary<string, List<DataViewSchema.Column>>();
|
||||
foreach (var kvp in roles)
|
||||
{
|
||||
Contracts.AssertValue(map);
|
||||
Contracts.AssertNonEmpty(role.Value);
|
||||
|
||||
if (!map.TryGetValue(role.Value, out var list))
|
||||
{
|
||||
list = new List<DataViewSchema.Column>();
|
||||
map.Add(role.Value, list);
|
||||
}
|
||||
list.Add(column);
|
||||
Contracts.AssertNonEmpty(kvp.Key.Value);
|
||||
if (string.IsNullOrEmpty(kvp.Value))
|
||||
continue;
|
||||
var info = schema.GetColumnOrNull(kvp.Value);
|
||||
if (info.HasValue)
|
||||
Add(map, kvp.Key.Value, info.Value);
|
||||
else if (!opt)
|
||||
throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found");
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
private static Dictionary<string, List<DataViewSchema.Column>> MapFromNames(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
|
||||
/// <summary>
|
||||
/// Returns whether there are any columns with the given column role.
|
||||
/// </summary>
|
||||
public bool Has(ColumnRole role)
|
||||
=> _map.ContainsKey(role.Value);
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether there is exactly one column of the given role.
|
||||
/// </summary>
|
||||
public bool HasUnique(ColumnRole role)
|
||||
=> _map.TryGetValue(role.Value, out var cols) && cols.Count == 1;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether there are two or more columns of the given role.
|
||||
/// </summary>
|
||||
public bool HasMultiple(ColumnRole role)
|
||||
=> _map.TryGetValue(role.Value, out var cols) && cols.Count > 1;
|
||||
|
||||
/// <summary>
|
||||
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
|
||||
/// it returns null.
|
||||
/// </summary>
|
||||
public IReadOnlyList<DataViewSchema.Column> GetColumns(ColumnRole role)
|
||||
=> _map.TryGetValue(role.Value, out var list) ? list : null;
|
||||
|
||||
/// <summary>
|
||||
/// An enumerable over all role-column associations within this object.
|
||||
/// </summary>
|
||||
public IEnumerable<KeyValuePair<ColumnRole, DataViewSchema.Column>> GetColumnRoles()
|
||||
{
|
||||
foreach (var roleAndList in _map)
|
||||
{
|
||||
Contracts.AssertValue(schema);
|
||||
Contracts.AssertValue(roles);
|
||||
|
||||
var map = new Dictionary<string, List<DataViewSchema.Column>>();
|
||||
foreach (var kvp in roles)
|
||||
{
|
||||
Contracts.AssertNonEmpty(kvp.Key.Value);
|
||||
if (string.IsNullOrEmpty(kvp.Value))
|
||||
continue;
|
||||
var info = schema.GetColumnOrNull(kvp.Value);
|
||||
if (info.HasValue)
|
||||
Add(map, kvp.Key.Value, info.Value);
|
||||
else if (!opt)
|
||||
throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found");
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether there are any columns with the given column role.
|
||||
/// </summary>
|
||||
public bool Has(ColumnRole role)
|
||||
=> _map.ContainsKey(role.Value);
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether there is exactly one column of the given role.
|
||||
/// </summary>
|
||||
public bool HasUnique(ColumnRole role)
|
||||
=> _map.TryGetValue(role.Value, out var cols) && cols.Count == 1;
|
||||
|
||||
/// <summary>
|
||||
/// Returns whether there are two or more columns of the given role.
|
||||
/// </summary>
|
||||
public bool HasMultiple(ColumnRole role)
|
||||
=> _map.TryGetValue(role.Value, out var cols) && cols.Count > 1;
|
||||
|
||||
/// <summary>
|
||||
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
|
||||
/// it returns null.
|
||||
/// </summary>
|
||||
public IReadOnlyList<DataViewSchema.Column> GetColumns(ColumnRole role)
|
||||
=> _map.TryGetValue(role.Value, out var list) ? list : null;
|
||||
|
||||
/// <summary>
|
||||
/// An enumerable over all role-column associations within this object.
|
||||
/// </summary>
|
||||
public IEnumerable<KeyValuePair<ColumnRole, DataViewSchema.Column>> GetColumnRoles()
|
||||
{
|
||||
foreach (var roleAndList in _map)
|
||||
{
|
||||
foreach (var info in roleAndList.Value)
|
||||
yield return new KeyValuePair<ColumnRole, DataViewSchema.Column>(roleAndList.Key, info);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// An enumerable over all role-column associations within this object.
|
||||
/// </summary>
|
||||
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames()
|
||||
{
|
||||
foreach (var roleAndList in _map)
|
||||
{
|
||||
foreach (var info in roleAndList.Value)
|
||||
yield return new KeyValuePair<ColumnRole, string>(roleAndList.Key, info.Name);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// An enumerable over all role-column associations for the given role. This is a helper function
|
||||
/// for implementing the <see cref="ISchemaBoundMapper.GetInputColumnRoles"/> method.
|
||||
/// </summary>
|
||||
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRole role)
|
||||
{
|
||||
if (_map.TryGetValue(role.Value, out var list))
|
||||
{
|
||||
foreach (var info in list)
|
||||
yield return new KeyValuePair<ColumnRole, string>(role, info.Name);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the <see cref="DataViewSchema.Column"/> corresponding to <paramref name="role"/> if there is
|
||||
/// exactly one such mapping, and otherwise throws an exception.
|
||||
/// </summary>
|
||||
/// <param name="role">The role to look up</param>
|
||||
/// <returns>The column corresponding to that role, assuming there was only one column
|
||||
/// mapped to that</returns>
|
||||
public DataViewSchema.Column GetUniqueColumn(ColumnRole role)
|
||||
{
|
||||
var infos = GetColumns(role);
|
||||
if (Utils.Size(infos) != 1)
|
||||
throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
private static Dictionary<string, IReadOnlyList<DataViewSchema.Column>> Copy(Dictionary<string, List<DataViewSchema.Column>> map)
|
||||
{
|
||||
var copy = new Dictionary<string, IReadOnlyList<DataViewSchema.Column>>(map.Count);
|
||||
foreach (var kvp in map)
|
||||
{
|
||||
Contracts.Assert(Utils.Size(kvp.Value) > 0);
|
||||
var cols = kvp.Value.ToArray();
|
||||
copy.Add(kvp.Key, cols);
|
||||
}
|
||||
return copy;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
|
||||
/// This skips null or empty column-names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="schema">The schema over which roles are defined</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
public RoleMappedSchema(DataViewSchema schema, bool opt = false, params KeyValuePair<ColumnRole, string>[] roles)
|
||||
: this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
|
||||
/// This skips null or empty column names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="schema">The schema over which roles are defined</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedSchema(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
|
||||
: this(Contracts.CheckRef(schema, nameof(schema)),
|
||||
MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt))
|
||||
{
|
||||
}
|
||||
|
||||
private static IEnumerable<KeyValuePair<ColumnRole, string>> PredefinedRolesHelper(
|
||||
string label, string feature, string group, string weight, string name,
|
||||
IEnumerable<KeyValuePair<ColumnRole, string>> custom = null)
|
||||
{
|
||||
if (!string.IsNullOrWhiteSpace(label))
|
||||
yield return ColumnRole.Label.Bind(label);
|
||||
if (!string.IsNullOrWhiteSpace(feature))
|
||||
yield return ColumnRole.Feature.Bind(feature);
|
||||
if (!string.IsNullOrWhiteSpace(group))
|
||||
yield return ColumnRole.Group.Bind(group);
|
||||
if (!string.IsNullOrWhiteSpace(weight))
|
||||
yield return ColumnRole.Weight.Bind(weight);
|
||||
if (!string.IsNullOrWhiteSpace(name))
|
||||
yield return ColumnRole.Name.Bind(name);
|
||||
if (custom != null)
|
||||
{
|
||||
foreach (var role in custom)
|
||||
yield return role;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
|
||||
/// is <c>null</c> or whitespace, it is ignored.
|
||||
/// </summary>
|
||||
/// <param name="schema">The schema over which roles are defined</param>
|
||||
/// <param name="label">The column name that will be mapped to the <see cref="ColumnRole.Label"/> role</param>
|
||||
/// <param name="feature">The column name that will be mapped to the <see cref="ColumnRole.Feature"/> role</param>
|
||||
/// <param name="group">The column name that will be mapped to the <see cref="ColumnRole.Group"/> role</param>
|
||||
/// <param name="weight">The column name that will be mapped to the <see cref="ColumnRole.Weight"/> role</param>
|
||||
/// <param name="name">The column name that will be mapped to the <see cref="ColumnRole.Name"/> role</param>
|
||||
/// <param name="custom">Any additional desired custom column role mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedSchema(DataViewSchema schema, string label, string feature,
|
||||
string group = null, string weight = null, string name = null,
|
||||
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
|
||||
: this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt)
|
||||
{
|
||||
Contracts.CheckValueOrNull(label);
|
||||
Contracts.CheckValueOrNull(feature);
|
||||
Contracts.CheckValueOrNull(group);
|
||||
Contracts.CheckValueOrNull(weight);
|
||||
Contracts.CheckValueOrNull(name);
|
||||
Contracts.CheckValueOrNull(custom);
|
||||
foreach (var info in roleAndList.Value)
|
||||
yield return new KeyValuePair<ColumnRole, DataViewSchema.Column>(roleAndList.Key, info);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
|
||||
/// Note that the schema of <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
|
||||
/// guaranteed to equal the the <see cref="IDataView.Schema"/> of <see cref="Data"/>.
|
||||
/// An enumerable over all role-column associations within this object.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal sealed class RoleMappedData
|
||||
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames()
|
||||
{
|
||||
/// <summary>
|
||||
/// The data.
|
||||
/// </summary>
|
||||
public IDataView Data { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The role mapped schema. Note that <see cref="Schema"/>'s <see cref="RoleMappedSchema.Schema"/> is
|
||||
/// guaranteed to be the same as <see cref="Data"/>'s <see cref="IDataView.Schema"/>.
|
||||
/// </summary>
|
||||
public RoleMappedSchema Schema { get; }
|
||||
|
||||
private RoleMappedData(IDataView data, RoleMappedSchema schema)
|
||||
foreach (var roleAndList in _map)
|
||||
{
|
||||
Contracts.AssertValue(data);
|
||||
Contracts.AssertValue(schema);
|
||||
Contracts.Assert(schema.Schema == data.Schema);
|
||||
Data = data;
|
||||
Schema = schema;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
|
||||
/// This skips null or empty column-names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="data">The data over which roles are defined</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair<RoleMappedSchema.ColumnRole, string>[] roles)
|
||||
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
|
||||
/// This skips null or empty column-names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="data">The schema over which roles are defined</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedData(IDataView data, IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles, bool opt = false)
|
||||
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
|
||||
/// is <c>null</c> or whitespace, it is ignored.
|
||||
/// </summary>
|
||||
/// <param name="data">The data over which roles are defined</param>
|
||||
/// <param name="label">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Label"/> role</param>
|
||||
/// <param name="feature">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Feature"/> role</param>
|
||||
/// <param name="group">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Group"/> role</param>
|
||||
/// <param name="weight">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Weight"/> role</param>
|
||||
/// <param name="name">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Name"/> role</param>
|
||||
/// <param name="custom">Any additional desired custom column role mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedData(IDataView data, string label, string feature,
|
||||
string group = null, string weight = null, string name = null,
|
||||
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
|
||||
: this(Contracts.CheckRef(data, nameof(data)),
|
||||
new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt))
|
||||
{
|
||||
Contracts.CheckValueOrNull(label);
|
||||
Contracts.CheckValueOrNull(feature);
|
||||
Contracts.CheckValueOrNull(group);
|
||||
Contracts.CheckValueOrNull(weight);
|
||||
Contracts.CheckValueOrNull(name);
|
||||
Contracts.CheckValueOrNull(custom);
|
||||
foreach (var info in roleAndList.Value)
|
||||
yield return new KeyValuePair<ColumnRole, string>(roleAndList.Key, info.Name);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// An enumerable over all role-column associations for the given role. This is a helper function
|
||||
/// for implementing the <see cref="ISchemaBoundMapper.GetInputColumnRoles"/> method.
|
||||
/// </summary>
|
||||
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRole role)
|
||||
{
|
||||
if (_map.TryGetValue(role.Value, out var list))
|
||||
{
|
||||
foreach (var info in list)
|
||||
yield return new KeyValuePair<ColumnRole, string>(role, info.Name);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the <see cref="DataViewSchema.Column"/> corresponding to <paramref name="role"/> if there is
|
||||
/// exactly one such mapping, and otherwise throws an exception.
|
||||
/// </summary>
|
||||
/// <param name="role">The role to look up</param>
|
||||
/// <returns>The column corresponding to that role, assuming there was only one column
|
||||
/// mapped to that</returns>
|
||||
public DataViewSchema.Column GetUniqueColumn(ColumnRole role)
|
||||
{
|
||||
var infos = GetColumns(role);
|
||||
if (Utils.Size(infos) != 1)
|
||||
throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
private static Dictionary<string, IReadOnlyList<DataViewSchema.Column>> Copy(Dictionary<string, List<DataViewSchema.Column>> map)
|
||||
{
|
||||
var copy = new Dictionary<string, IReadOnlyList<DataViewSchema.Column>>(map.Count);
|
||||
foreach (var kvp in map)
|
||||
{
|
||||
Contracts.Assert(Utils.Size(kvp.Value) > 0);
|
||||
var cols = kvp.Value.ToArray();
|
||||
copy.Add(kvp.Key, cols);
|
||||
}
|
||||
return copy;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
|
||||
/// This skips null or empty column-names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="schema">The schema over which roles are defined</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
public RoleMappedSchema(DataViewSchema schema, bool opt = false, params KeyValuePair<ColumnRole, string>[] roles)
|
||||
: this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
|
||||
/// This skips null or empty column names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="schema">The schema over which roles are defined</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedSchema(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
|
||||
: this(Contracts.CheckRef(schema, nameof(schema)),
|
||||
MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt))
|
||||
{
|
||||
}
|
||||
|
||||
private static IEnumerable<KeyValuePair<ColumnRole, string>> PredefinedRolesHelper(
|
||||
string label, string feature, string group, string weight, string name,
|
||||
IEnumerable<KeyValuePair<ColumnRole, string>> custom = null)
|
||||
{
|
||||
if (!string.IsNullOrWhiteSpace(label))
|
||||
yield return ColumnRole.Label.Bind(label);
|
||||
if (!string.IsNullOrWhiteSpace(feature))
|
||||
yield return ColumnRole.Feature.Bind(feature);
|
||||
if (!string.IsNullOrWhiteSpace(group))
|
||||
yield return ColumnRole.Group.Bind(group);
|
||||
if (!string.IsNullOrWhiteSpace(weight))
|
||||
yield return ColumnRole.Weight.Bind(weight);
|
||||
if (!string.IsNullOrWhiteSpace(name))
|
||||
yield return ColumnRole.Name.Bind(name);
|
||||
if (custom != null)
|
||||
{
|
||||
foreach (var role in custom)
|
||||
yield return role;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
|
||||
/// is <c>null</c> or whitespace, it is ignored.
|
||||
/// </summary>
|
||||
/// <param name="schema">The schema over which roles are defined</param>
|
||||
/// <param name="label">The column name that will be mapped to the <see cref="ColumnRole.Label"/> role</param>
|
||||
/// <param name="feature">The column name that will be mapped to the <see cref="ColumnRole.Feature"/> role</param>
|
||||
/// <param name="group">The column name that will be mapped to the <see cref="ColumnRole.Group"/> role</param>
|
||||
/// <param name="weight">The column name that will be mapped to the <see cref="ColumnRole.Weight"/> role</param>
|
||||
/// <param name="name">The column name that will be mapped to the <see cref="ColumnRole.Name"/> role</param>
|
||||
/// <param name="custom">Any additional desired custom column role mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedSchema(DataViewSchema schema, string label, string feature,
|
||||
string group = null, string weight = null, string name = null,
|
||||
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
|
||||
: this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt)
|
||||
{
|
||||
Contracts.CheckValueOrNull(label);
|
||||
Contracts.CheckValueOrNull(feature);
|
||||
Contracts.CheckValueOrNull(group);
|
||||
Contracts.CheckValueOrNull(weight);
|
||||
Contracts.CheckValueOrNull(name);
|
||||
Contracts.CheckValueOrNull(custom);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
|
||||
/// Note that the schema of <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
|
||||
/// guaranteed to equal the the <see cref="IDataView.Schema"/> of <see cref="Data"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal sealed class RoleMappedData
|
||||
{
|
||||
/// <summary>
|
||||
/// The data.
|
||||
/// </summary>
|
||||
public IDataView Data { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The role mapped schema. Note that <see cref="Schema"/>'s <see cref="RoleMappedSchema.Schema"/> is
|
||||
/// guaranteed to be the same as <see cref="Data"/>'s <see cref="IDataView.Schema"/>.
|
||||
/// </summary>
|
||||
public RoleMappedSchema Schema { get; }
|
||||
|
||||
private RoleMappedData(IDataView data, RoleMappedSchema schema)
|
||||
{
|
||||
Contracts.AssertValue(data);
|
||||
Contracts.AssertValue(schema);
|
||||
Contracts.Assert(schema.Schema == data.Schema);
|
||||
Data = data;
|
||||
Schema = schema;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
|
||||
/// This skips null or empty column-names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="data">The data over which roles are defined</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair<RoleMappedSchema.ColumnRole, string>[] roles)
|
||||
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
|
||||
/// This skips null or empty column-names. It will also skip column-names that are not
|
||||
/// found in the schema if <paramref name="opt"/> is true.
|
||||
/// </summary>
|
||||
/// <param name="data">The schema over which roles are defined</param>
|
||||
/// <param name="roles">The column role to column name mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedData(IDataView data, IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles, bool opt = false)
|
||||
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
|
||||
/// is <c>null</c> or whitespace, it is ignored.
|
||||
/// </summary>
|
||||
/// <param name="data">The data over which roles are defined</param>
|
||||
/// <param name="label">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Label"/> role</param>
|
||||
/// <param name="feature">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Feature"/> role</param>
|
||||
/// <param name="group">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Group"/> role</param>
|
||||
/// <param name="weight">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Weight"/> role</param>
|
||||
/// <param name="name">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Name"/> role</param>
|
||||
/// <param name="custom">Any additional desired custom column role mappings</param>
|
||||
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
|
||||
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
|
||||
/// but if <c>true</c> such values will be ignored</param>
|
||||
public RoleMappedData(IDataView data, string label, string feature,
|
||||
string group = null, string weight = null, string name = null,
|
||||
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
|
||||
: this(Contracts.CheckRef(data, nameof(data)),
|
||||
new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt))
|
||||
{
|
||||
Contracts.CheckValueOrNull(label);
|
||||
Contracts.CheckValueOrNull(feature);
|
||||
Contracts.CheckValueOrNull(group);
|
||||
Contracts.CheckValueOrNull(weight);
|
||||
Contracts.CheckValueOrNull(name);
|
||||
Contracts.CheckValueOrNull(custom);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,78 +4,77 @@
|
|||
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
// REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes
|
||||
// ownership of the channel so the derived classes don't have to.
|
||||
|
||||
/// <summary>
|
||||
/// Base class for creating a cursor with default tracking of <see cref="Position"/>. All calls to <see cref="MoveNext"/>
|
||||
/// will be seen by subclasses of this cursor. For a cursor that has an input cursor and does not need notification on
|
||||
/// <see cref="MoveNext"/>, use <see cref="SynchronizedCursorBase"/> instead.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class RootCursorBase : DataViewRowCursor
|
||||
{
|
||||
// REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes
|
||||
// ownership of the channel so the derived classes don't have to.
|
||||
protected readonly IChannel Ch;
|
||||
private long _position;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Base class for creating a cursor with default tracking of <see cref="Position"/>. All calls to <see cref="MoveNext"/>
|
||||
/// will be seen by subclasses of this cursor. For a cursor that has an input cursor and does not need notification on
|
||||
/// <see cref="MoveNext"/>, use <see cref="SynchronizedCursorBase"/> instead.
|
||||
/// Zero-based position of the cursor.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class RootCursorBase : DataViewRowCursor
|
||||
public sealed override long Position => _position;
|
||||
|
||||
/// <summary>
|
||||
/// Convenience property for checking whether the current state of the cursor is one where data can be fetched.
|
||||
/// </summary>
|
||||
protected bool IsGood => _position >= 0;
|
||||
|
||||
/// <summary>
|
||||
/// Creates an instance of the <see cref="RootCursorBase"/> class
|
||||
/// </summary>
|
||||
/// <param name="provider">Channel provider</param>
|
||||
protected RootCursorBase(IChannelProvider provider)
|
||||
{
|
||||
protected readonly IChannel Ch;
|
||||
private long _position;
|
||||
private bool _disposed;
|
||||
Contracts.CheckValue(provider, nameof(provider));
|
||||
Ch = provider.Start("Cursor");
|
||||
|
||||
/// <summary>
|
||||
/// Zero-based position of the cursor.
|
||||
/// </summary>
|
||||
public sealed override long Position => _position;
|
||||
_position = -1;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convenience property for checking whether the current state of the cursor is one where data can be fetched.
|
||||
/// </summary>
|
||||
protected bool IsGood => _position >= 0;
|
||||
|
||||
/// <summary>
|
||||
/// Creates an instance of the <see cref="RootCursorBase"/> class
|
||||
/// </summary>
|
||||
/// <param name="provider">Channel provider</param>
|
||||
protected RootCursorBase(IChannelProvider provider)
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
if (disposing)
|
||||
{
|
||||
Contracts.CheckValue(provider, nameof(provider));
|
||||
Ch = provider.Start("Cursor");
|
||||
|
||||
Ch.Dispose();
|
||||
_position = -1;
|
||||
}
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
if (disposing)
|
||||
{
|
||||
Ch.Dispose();
|
||||
_position = -1;
|
||||
}
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
|
||||
}
|
||||
|
||||
public sealed override bool MoveNext()
|
||||
{
|
||||
if (_disposed)
|
||||
return false;
|
||||
|
||||
if (MoveNextCore())
|
||||
{
|
||||
_position++;
|
||||
return true;
|
||||
}
|
||||
|
||||
Dispose();
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Core implementation of <see cref="MoveNext"/>, called if no prior call to this method
|
||||
/// has returned <see langword="false"/>.
|
||||
/// </summary>
|
||||
protected abstract bool MoveNextCore();
|
||||
}
|
||||
|
||||
public sealed override bool MoveNext()
|
||||
{
|
||||
if (_disposed)
|
||||
return false;
|
||||
|
||||
if (MoveNextCore())
|
||||
{
|
||||
_position++;
|
||||
return true;
|
||||
}
|
||||
|
||||
Dispose();
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Core implementation of <see cref="MoveNext"/>, called if no prior call to this method
|
||||
/// has returned <see langword="false"/>.
|
||||
/// </summary>
|
||||
protected abstract bool MoveNextCore();
|
||||
}
|
||||
|
|
|
@ -4,26 +4,25 @@
|
|||
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
{
|
||||
[BestFriend]
|
||||
internal static class SchemaExtensions
|
||||
{
|
||||
public static DataViewSchema MakeSchema(IEnumerable<DataViewSchema.DetachedColumn> columns)
|
||||
{
|
||||
var builder = new DataViewSchema.Builder();
|
||||
builder.AddColumns(columns);
|
||||
return builder.ToSchema();
|
||||
}
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Legacy method to get the column index.
|
||||
/// DO NOT USE: use <see cref="DataViewSchema.GetColumnOrNull"/> instead.
|
||||
/// </summary>
|
||||
public static bool TryGetColumnIndex(this DataViewSchema schema, string name, out int col)
|
||||
{
|
||||
col = schema.GetColumnOrNull(name)?.Index ?? -1;
|
||||
return col >= 0;
|
||||
}
|
||||
[BestFriend]
|
||||
internal static class SchemaExtensions
|
||||
{
|
||||
public static DataViewSchema MakeSchema(IEnumerable<DataViewSchema.DetachedColumn> columns)
|
||||
{
|
||||
var builder = new DataViewSchema.Builder();
|
||||
builder.AddColumns(columns);
|
||||
return builder.ToSchema();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Legacy method to get the column index.
|
||||
/// DO NOT USE: use <see cref="DataViewSchema.GetColumnOrNull"/> instead.
|
||||
/// </summary>
|
||||
public static bool TryGetColumnIndex(this DataViewSchema schema, string name, out int col)
|
||||
{
|
||||
col = schema.GetColumnOrNull(name)?.Index ?? -1;
|
||||
return col >= 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,266 +6,265 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using Microsoft.ML.EntryPoints;
|
||||
|
||||
namespace Microsoft.ML.Runtime
|
||||
namespace Microsoft.ML.Runtime;
|
||||
|
||||
/// <summary>
|
||||
/// Instances of this class are used to set up a bundle of named delegates. These
|
||||
/// delegates are registered through <see cref="Register{TRet}"/> and its overloads.
|
||||
/// Once all registrations are done, <see cref="Publish"/> is called and a message
|
||||
/// of type <see cref="Bundle"/> is sent through the input channel
|
||||
/// provider. The intended use case is that any information surfaced through these
|
||||
/// delegates will be published in some fashion, with the target scenario being
|
||||
/// that the library will publish some sort of restful API.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
|
||||
{
|
||||
// See ServerChannel.md for a more elaborate discussion of high level usage and design.
|
||||
private readonly IChannelProvider _chp;
|
||||
private readonly string _identifier;
|
||||
|
||||
// This holds the running collection of named delegates, if any. The dictionary itself
|
||||
// is lazily initialized only when a listener
|
||||
private Dictionary<string, Delegate> _toPublish;
|
||||
private Action<Bundle> _onPublish;
|
||||
private Bundle _published;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Instances of this class are used to set up a bundle of named delegates. These
|
||||
/// delegates are registered through <see cref="Register{TRet}"/> and its overloads.
|
||||
/// Once all registrations are done, <see cref="Publish"/> is called and a message
|
||||
/// of type <see cref="Bundle"/> is sent through the input channel
|
||||
/// provider. The intended use case is that any information surfaced through these
|
||||
/// delegates will be published in some fashion, with the target scenario being
|
||||
/// that the library will publish some sort of restful API.
|
||||
/// Returns either this object, or <c>null</c> if there are no listeners on this server
|
||||
/// channel. This can be used in conjunction with the <c>?.</c> operator to have more
|
||||
/// performant though more robust calls to <see cref="Register{TRet}"/> and
|
||||
/// <see cref="Publish"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
|
||||
private ServerChannel ThisIfActiveOrNull => _toPublish == null ? null : this;
|
||||
|
||||
private ServerChannel(IChannelProvider provider, string idenfier)
|
||||
{
|
||||
// See ServerChannel.md for a more elaborate discussion of high level usage and design.
|
||||
private readonly IChannelProvider _chp;
|
||||
private readonly string _identifier;
|
||||
Contracts.AssertValue(provider);
|
||||
_chp = provider;
|
||||
_chp.AssertNonWhiteSpace(idenfier);
|
||||
_identifier = idenfier;
|
||||
}
|
||||
|
||||
// This holds the running collection of named delegates, if any. The dictionary itself
|
||||
// is lazily initialized only when a listener
|
||||
private Dictionary<string, Delegate> _toPublish;
|
||||
private Action<Bundle> _onPublish;
|
||||
private Bundle _published;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Returns either this object, or <c>null</c> if there are no listeners on this server
|
||||
/// channel. This can be used in conjunction with the <c>?.</c> operator to have more
|
||||
/// performant though more robust calls to <see cref="Register{TRet}"/> and
|
||||
/// <see cref="Publish"/>.
|
||||
/// </summary>
|
||||
private ServerChannel ThisIfActiveOrNull => _toPublish == null ? null : this;
|
||||
|
||||
private ServerChannel(IChannelProvider provider, string idenfier)
|
||||
/// <summary>
|
||||
/// Starts a new server channel.
|
||||
/// </summary>
|
||||
/// <param name="provider">The channel provider, on which to send
|
||||
/// the notification that a server is being constructed</param>
|
||||
/// <param name="identifier">A semi-unique identifier for this
|
||||
/// "bundle" that is being constructed</param>
|
||||
/// <returns>The constructed server channel, or <c>null</c> if there
|
||||
/// was no listeners for server channels registered on <paramref name="provider"/></returns>
|
||||
public static ServerChannel Start(IChannelProvider provider, string identifier)
|
||||
{
|
||||
Contracts.CheckValue(provider, nameof(provider));
|
||||
provider.CheckNonWhiteSpace(identifier, nameof(identifier));
|
||||
using (var pipe = provider.StartPipe<IPendingBundleNotification>("Server"))
|
||||
{
|
||||
Contracts.AssertValue(provider);
|
||||
_chp = provider;
|
||||
_chp.AssertNonWhiteSpace(idenfier);
|
||||
_identifier = idenfier;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Starts a new server channel.
|
||||
/// </summary>
|
||||
/// <param name="provider">The channel provider, on which to send
|
||||
/// the notification that a server is being constructed</param>
|
||||
/// <param name="identifier">A semi-unique identifier for this
|
||||
/// "bundle" that is being constructed</param>
|
||||
/// <returns>The constructed server channel, or <c>null</c> if there
|
||||
/// was no listeners for server channels registered on <paramref name="provider"/></returns>
|
||||
public static ServerChannel Start(IChannelProvider provider, string identifier)
|
||||
{
|
||||
Contracts.CheckValue(provider, nameof(provider));
|
||||
provider.CheckNonWhiteSpace(identifier, nameof(identifier));
|
||||
using (var pipe = provider.StartPipe<IPendingBundleNotification>("Server"))
|
||||
{
|
||||
var sc = new ServerChannel(provider, identifier);
|
||||
pipe.Send(sc);
|
||||
return sc.ThisIfActiveOrNull;
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
_disposed = true;
|
||||
_published?.Done();
|
||||
}
|
||||
}
|
||||
|
||||
private void RegisterCore(string name, Delegate func)
|
||||
{
|
||||
_chp.CheckNonEmpty(name, nameof(name));
|
||||
_chp.CheckValue(func, nameof(func));
|
||||
_chp.Check(_published == null, "Cannot expose more interfaces once a server channel has been published");
|
||||
_chp.AssertValue(_toPublish);
|
||||
|
||||
_toPublish.Add(name, func);
|
||||
}
|
||||
|
||||
public void Register<TRet>(string name, Func<TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
public void Register<T1, TRet>(string name, Func<T1, TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
public void Register<T1, T2, TRet>(string name, Func<T1, T2, TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
public void Register<T1, T2, T3, TRet>(string name, Func<T1, T2, T3, TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Finalizes all registrations of delegates, and pipes the bundle of objects
|
||||
/// in a <see cref="Bundle"/> up through the pipe to be consumed by any
|
||||
/// listeners.
|
||||
/// </summary>
|
||||
public void Publish()
|
||||
{
|
||||
_chp.Assert((_toPublish == null) == (_onPublish == null));
|
||||
if (_toPublish == null)
|
||||
return;
|
||||
_chp.Check(_published == null, "Cannot republish once a server channel has been published");
|
||||
_published = new Bundle(this);
|
||||
_onPublish(_published);
|
||||
}
|
||||
|
||||
public void Acknowledge(Action<Bundle> toDo)
|
||||
{
|
||||
_chp.CheckValue(toDo, nameof(toDo));
|
||||
_chp.Assert((_onPublish == null) == (_toPublish == null));
|
||||
if (_toPublish == null)
|
||||
_toPublish = new Dictionary<string, Delegate>();
|
||||
_onPublish += toDo;
|
||||
_chp.AssertValue(_onPublish);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Entry point factory for creating <see cref="IServer"/> instances.
|
||||
/// </summary>
|
||||
[TlcModule.ComponentKind("Server")]
|
||||
public interface IServerFactory : IComponentFactory<IChannel, IServer>
|
||||
{
|
||||
new IServer CreateComponent(IHostEnvironment env, IChannel ch);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Classes that want to publish the bundles from server channels in some fashion should implement
|
||||
/// this interface. The intended simple use case is that this will be some form of in-process web
|
||||
/// server, and then when disposed, they should stop themselves.
|
||||
///
|
||||
/// Note that the primary communication with the server from the client code's perspective is not
|
||||
/// through method calls on this interface, but rather communication through an
|
||||
/// <see cref="IPipe{IPendingBundleNotification}"/> that the server will listen to throughout its
|
||||
/// lifetime.
|
||||
/// </summary>
|
||||
public interface IServer : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// This should return the base address where the server is. If this server is not actually
|
||||
/// serving content at any URL, this property should be null.
|
||||
/// </summary>
|
||||
Uri BaseAddress { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates what might be considered a good "default" server factory, if possible,
|
||||
/// or <c>null</c> if no good default was possible. A <c>null</c> value could be returned,
|
||||
/// for example, if a user opted to remove all implementations of <see cref="IServer"/> and
|
||||
/// the associated <see cref="IServerFactory"/> for security reasons.
|
||||
/// </summary>
|
||||
public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
// REVIEW: There should be a better way. There currently isn't,
|
||||
// but there should be. This is pretty horrifying, but it is preferable to
|
||||
// the alternative of having core components depend on an actual server
|
||||
// implementation, since we want those to be removable because of security
|
||||
// concerns in certain environments (since not everyone will be wild about
|
||||
// web servers popping up everywhere).
|
||||
ComponentCatalog.ComponentInfo component;
|
||||
if (!env.ComponentCatalog.TryFindComponent(typeof(IServerFactory), "mini", out component))
|
||||
return null;
|
||||
IServerFactory factory = (IServerFactory)Activator.CreateInstance(component.ArgumentType);
|
||||
var field = factory.GetType().GetField("Port");
|
||||
if (field?.FieldType != typeof(int))
|
||||
return null;
|
||||
field.SetValue(factory, 12345);
|
||||
return factory;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// When a <see cref="ServerChannel"/> is created, the creation method will send an implementation
|
||||
/// is a notification sent through an <see cref="IPipe{IPendingBundleNotification}"/>, to indicate that
|
||||
/// a <see cref="Bundle"/> may be pending soon. Listeners that want to receive the bundle to
|
||||
/// expose it, for example, a web service, should register this interest by passing in an action to be called.
|
||||
/// If no listener registers interest, the server channel that sent the notification will act
|
||||
/// differently by, say, acting as a no-op w.r.t. client calls to it.
|
||||
/// </summary>
|
||||
public interface IPendingBundleNotification
|
||||
{
|
||||
/// <summary>
|
||||
/// Any publisher of the named delegates will call this method, upon receiving an instance
|
||||
/// of this object through the pipe. This method serves two purposes: firstly it detects
|
||||
/// whether anyone is even interested in publishing anything at all, so that we can just
|
||||
/// ignore any input delegates in the case where no one is listening (which, we must expect,
|
||||
/// is the majority of scenarios). The second is that it provides an action to call, once
|
||||
/// all publishing is complete, and <see cref="Publish"/> has been called by the client code.
|
||||
/// </summary>
|
||||
/// <param name="toDo">The callback to perform when all named delegates have been registered,
|
||||
/// and <see cref="Publish"/> is called.</param>
|
||||
void Acknowledge(Action<Bundle> toDo);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The final bundle of published named delegates that a listener can serve.
|
||||
/// </summary>
|
||||
public sealed class Bundle
|
||||
{
|
||||
/// <summary>
|
||||
/// This contains a name to delegate mappings. The delegates contained herein are gauranteed to be
|
||||
/// some variety of <see cref="Func{TResult}"/>, <see cref="Func{T1, TResult}"/>,
|
||||
/// <see cref="Func{T1, T2, TResult}"/>, etc.
|
||||
/// </summary>
|
||||
public readonly IReadOnlyDictionary<string, Delegate> NameToFuncs;
|
||||
|
||||
/// <summary>
|
||||
/// This should be a more-or-less unique identifier for the type of API this bundle is producing.
|
||||
/// Its intended use is that it will form part of the URL for the RESTful API, so to the extent that
|
||||
/// it contains multiple tokens they must be slash delimited.
|
||||
/// </summary>
|
||||
public readonly string Identifier;
|
||||
|
||||
internal Action Done;
|
||||
|
||||
internal Bundle(ServerChannel sch)
|
||||
{
|
||||
Contracts.AssertValue(sch);
|
||||
|
||||
NameToFuncs = sch._toPublish;
|
||||
Identifier = sch._identifier;
|
||||
}
|
||||
|
||||
public void AddDoneAction(Action onDone)
|
||||
{
|
||||
Done += onDone;
|
||||
}
|
||||
var sc = new ServerChannel(provider, identifier);
|
||||
pipe.Send(sc);
|
||||
return sc.ThisIfActiveOrNull;
|
||||
}
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal static class ServerChannelUtilities
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
_disposed = true;
|
||||
_published?.Done();
|
||||
}
|
||||
}
|
||||
|
||||
private void RegisterCore(string name, Delegate func)
|
||||
{
|
||||
_chp.CheckNonEmpty(name, nameof(name));
|
||||
_chp.CheckValue(func, nameof(func));
|
||||
_chp.Check(_published == null, "Cannot expose more interfaces once a server channel has been published");
|
||||
_chp.AssertValue(_toPublish);
|
||||
|
||||
_toPublish.Add(name, func);
|
||||
}
|
||||
|
||||
public void Register<TRet>(string name, Func<TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
public void Register<T1, TRet>(string name, Func<T1, TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
public void Register<T1, T2, TRet>(string name, Func<T1, T2, TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
public void Register<T1, T2, T3, TRet>(string name, Func<T1, T2, T3, TRet> func)
|
||||
{
|
||||
if (_toPublish != null)
|
||||
RegisterCore(name, func);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Finalizes all registrations of delegates, and pipes the bundle of objects
|
||||
/// in a <see cref="Bundle"/> up through the pipe to be consumed by any
|
||||
/// listeners.
|
||||
/// </summary>
|
||||
public void Publish()
|
||||
{
|
||||
_chp.Assert((_toPublish == null) == (_onPublish == null));
|
||||
if (_toPublish == null)
|
||||
return;
|
||||
_chp.Check(_published == null, "Cannot republish once a server channel has been published");
|
||||
_published = new Bundle(this);
|
||||
_onPublish(_published);
|
||||
}
|
||||
|
||||
public void Acknowledge(Action<Bundle> toDo)
|
||||
{
|
||||
_chp.CheckValue(toDo, nameof(toDo));
|
||||
_chp.Assert((_onPublish == null) == (_toPublish == null));
|
||||
if (_toPublish == null)
|
||||
_toPublish = new Dictionary<string, Delegate>();
|
||||
_onPublish += toDo;
|
||||
_chp.AssertValue(_onPublish);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Entry point factory for creating <see cref="IServer"/> instances.
|
||||
/// </summary>
|
||||
[TlcModule.ComponentKind("Server")]
|
||||
public interface IServerFactory : IComponentFactory<IChannel, IServer>
|
||||
{
|
||||
new IServer CreateComponent(IHostEnvironment env, IChannel ch);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Classes that want to publish the bundles from server channels in some fashion should implement
|
||||
/// this interface. The intended simple use case is that this will be some form of in-process web
|
||||
/// server, and then when disposed, they should stop themselves.
|
||||
///
|
||||
/// Note that the primary communication with the server from the client code's perspective is not
|
||||
/// through method calls on this interface, but rather communication through an
|
||||
/// <see cref="IPipe{IPendingBundleNotification}"/> that the server will listen to throughout its
|
||||
/// lifetime.
|
||||
/// </summary>
|
||||
public interface IServer : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// Convenience method for <see cref="ServerChannel.Start"/> that looks more idiomatic to typical
|
||||
/// channel creation methods on <see cref="IChannelProvider"/>.
|
||||
/// This should return the base address where the server is. If this server is not actually
|
||||
/// serving content at any URL, this property should be null.
|
||||
/// </summary>
|
||||
/// <param name="provider">The channel provider.</param>
|
||||
/// <param name="identifier">This is an identifier of the "type" of bundle that is being published,
|
||||
/// and should form a path with forward-slash '/' delimiters.</param>
|
||||
/// <returns>The newly created server channel, or <c>null</c> if there was no listener for
|
||||
/// server channels on <paramref name="provider"/>.</returns>
|
||||
public static ServerChannel StartServerChannel(this IChannelProvider provider, string identifier)
|
||||
Uri BaseAddress { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates what might be considered a good "default" server factory, if possible,
|
||||
/// or <c>null</c> if no good default was possible. A <c>null</c> value could be returned,
|
||||
/// for example, if a user opted to remove all implementations of <see cref="IServer"/> and
|
||||
/// the associated <see cref="IServerFactory"/> for security reasons.
|
||||
/// </summary>
|
||||
public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
// REVIEW: There should be a better way. There currently isn't,
|
||||
// but there should be. This is pretty horrifying, but it is preferable to
|
||||
// the alternative of having core components depend on an actual server
|
||||
// implementation, since we want those to be removable because of security
|
||||
// concerns in certain environments (since not everyone will be wild about
|
||||
// web servers popping up everywhere).
|
||||
ComponentCatalog.ComponentInfo component;
|
||||
if (!env.ComponentCatalog.TryFindComponent(typeof(IServerFactory), "mini", out component))
|
||||
return null;
|
||||
IServerFactory factory = (IServerFactory)Activator.CreateInstance(component.ArgumentType);
|
||||
var field = factory.GetType().GetField("Port");
|
||||
if (field?.FieldType != typeof(int))
|
||||
return null;
|
||||
field.SetValue(factory, 12345);
|
||||
return factory;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// When a <see cref="ServerChannel"/> is created, the creation method will send an implementation
|
||||
/// is a notification sent through an <see cref="IPipe{IPendingBundleNotification}"/>, to indicate that
|
||||
/// a <see cref="Bundle"/> may be pending soon. Listeners that want to receive the bundle to
|
||||
/// expose it, for example, a web service, should register this interest by passing in an action to be called.
|
||||
/// If no listener registers interest, the server channel that sent the notification will act
|
||||
/// differently by, say, acting as a no-op w.r.t. client calls to it.
|
||||
/// </summary>
|
||||
public interface IPendingBundleNotification
|
||||
{
|
||||
/// <summary>
|
||||
/// Any publisher of the named delegates will call this method, upon receiving an instance
|
||||
/// of this object through the pipe. This method serves two purposes: firstly it detects
|
||||
/// whether anyone is even interested in publishing anything at all, so that we can just
|
||||
/// ignore any input delegates in the case where no one is listening (which, we must expect,
|
||||
/// is the majority of scenarios). The second is that it provides an action to call, once
|
||||
/// all publishing is complete, and <see cref="Publish"/> has been called by the client code.
|
||||
/// </summary>
|
||||
/// <param name="toDo">The callback to perform when all named delegates have been registered,
|
||||
/// and <see cref="Publish"/> is called.</param>
|
||||
void Acknowledge(Action<Bundle> toDo);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The final bundle of published named delegates that a listener can serve.
|
||||
/// </summary>
|
||||
public sealed class Bundle
|
||||
{
|
||||
/// <summary>
|
||||
/// This contains a name to delegate mappings. The delegates contained herein are gauranteed to be
|
||||
/// some variety of <see cref="Func{TResult}"/>, <see cref="Func{T1, TResult}"/>,
|
||||
/// <see cref="Func{T1, T2, TResult}"/>, etc.
|
||||
/// </summary>
|
||||
public readonly IReadOnlyDictionary<string, Delegate> NameToFuncs;
|
||||
|
||||
/// <summary>
|
||||
/// This should be a more-or-less unique identifier for the type of API this bundle is producing.
|
||||
/// Its intended use is that it will form part of the URL for the RESTful API, so to the extent that
|
||||
/// it contains multiple tokens they must be slash delimited.
|
||||
/// </summary>
|
||||
public readonly string Identifier;
|
||||
|
||||
internal Action Done;
|
||||
|
||||
internal Bundle(ServerChannel sch)
|
||||
{
|
||||
Contracts.CheckValue(provider, nameof(provider));
|
||||
Contracts.CheckNonWhiteSpace(identifier, nameof(identifier));
|
||||
return ServerChannel.Start(provider, identifier);
|
||||
Contracts.AssertValue(sch);
|
||||
|
||||
NameToFuncs = sch._toPublish;
|
||||
Identifier = sch._identifier;
|
||||
}
|
||||
|
||||
public void AddDoneAction(Action onDone)
|
||||
{
|
||||
Done += onDone;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[BestFriend]
|
||||
internal static class ServerChannelUtilities
|
||||
{
|
||||
/// <summary>
|
||||
/// Convenience method for <see cref="ServerChannel.Start"/> that looks more idiomatic to typical
|
||||
/// channel creation methods on <see cref="IChannelProvider"/>.
|
||||
/// </summary>
|
||||
/// <param name="provider">The channel provider.</param>
|
||||
/// <param name="identifier">This is an identifier of the "type" of bundle that is being published,
|
||||
/// and should form a path with forward-slash '/' delimiters.</param>
|
||||
/// <returns>The newly created server channel, or <c>null</c> if there was no listener for
|
||||
/// server channels on <paramref name="provider"/>.</returns>
|
||||
public static ServerChannel StartServerChannel(this IChannelProvider provider, string identifier)
|
||||
{
|
||||
Contracts.CheckValue(provider, nameof(provider));
|
||||
Contracts.CheckNonWhiteSpace(identifier, nameof(identifier));
|
||||
return ServerChannel.Start(provider, identifier);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,69 +4,68 @@
|
|||
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Base class for creating a cursor on top of another cursor that does not add or remove rows.
|
||||
/// It forces one-to-one correspondence between items in the input cursor and this cursor.
|
||||
/// It delegates all <see cref="DataViewRowCursor"/> functionality except Dispose() to the root cursor.
|
||||
/// Dispose is virtual with the default implementation delegating to the input cursor.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class SynchronizedCursorBase : DataViewRowCursor
|
||||
{
|
||||
protected readonly IChannel Ch;
|
||||
|
||||
/// <summary>
|
||||
/// Base class for creating a cursor on top of another cursor that does not add or remove rows.
|
||||
/// It forces one-to-one correspondence between items in the input cursor and this cursor.
|
||||
/// It delegates all <see cref="DataViewRowCursor"/> functionality except Dispose() to the root cursor.
|
||||
/// Dispose is virtual with the default implementation delegating to the input cursor.
|
||||
/// The synchronized cursor base, as it merely passes through requests for all "positional" calls (including
|
||||
/// <see cref="MoveNext"/>, <see cref="Position"/>, <see cref="Batch"/>, and so forth), offers an opportunity
|
||||
/// for optimization for "wrapping" cursors (which are themselves often <see cref="SynchronizedCursorBase"/>
|
||||
/// implementors) to get this root cursor. But, this can only be done by exposing this root cursor, as we do here.
|
||||
/// Internal code should be quite careful in using this as the potential for misuse is quite high.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class SynchronizedCursorBase : DataViewRowCursor
|
||||
internal readonly DataViewRowCursor Root;
|
||||
private bool _disposed;
|
||||
|
||||
protected DataViewRowCursor Input { get; }
|
||||
|
||||
public sealed override long Position => Root.Position;
|
||||
|
||||
public sealed override long Batch => Root.Batch;
|
||||
|
||||
/// <summary>
|
||||
/// Convenience property for checking whether the cursor is in a good state where values
|
||||
/// can be retrieved, that is, whenever <see cref="Position"/> is non-negative.
|
||||
/// </summary>
|
||||
protected bool IsGood => Position >= 0;
|
||||
|
||||
protected SynchronizedCursorBase(IChannelProvider provider, DataViewRowCursor input)
|
||||
{
|
||||
protected readonly IChannel Ch;
|
||||
Contracts.AssertValue(provider);
|
||||
Ch = provider.Start("Cursor");
|
||||
|
||||
/// <summary>
|
||||
/// The synchronized cursor base, as it merely passes through requests for all "positional" calls (including
|
||||
/// <see cref="MoveNext"/>, <see cref="Position"/>, <see cref="Batch"/>, and so forth), offers an opportunity
|
||||
/// for optimization for "wrapping" cursors (which are themselves often <see cref="SynchronizedCursorBase"/>
|
||||
/// implementors) to get this root cursor. But, this can only be done by exposing this root cursor, as we do here.
|
||||
/// Internal code should be quite careful in using this as the potential for misuse is quite high.
|
||||
/// </summary>
|
||||
internal readonly DataViewRowCursor Root;
|
||||
private bool _disposed;
|
||||
|
||||
protected DataViewRowCursor Input { get; }
|
||||
|
||||
public sealed override long Position => Root.Position;
|
||||
|
||||
public sealed override long Batch => Root.Batch;
|
||||
|
||||
/// <summary>
|
||||
/// Convenience property for checking whether the cursor is in a good state where values
|
||||
/// can be retrieved, that is, whenever <see cref="Position"/> is non-negative.
|
||||
/// </summary>
|
||||
protected bool IsGood => Position >= 0;
|
||||
|
||||
protected SynchronizedCursorBase(IChannelProvider provider, DataViewRowCursor input)
|
||||
{
|
||||
Contracts.AssertValue(provider);
|
||||
Ch = provider.Start("Cursor");
|
||||
|
||||
Ch.AssertValue(input);
|
||||
Input = input;
|
||||
// If this thing happens to be itself an instance of this class (which, practically, it will
|
||||
// be in the majority of situations), we can treat the input as likewise being a passthrough,
|
||||
// thereby saving lots of "nested" calls on the stack when doing common operations like movement.
|
||||
Root = Input is SynchronizedCursorBase syncInput ? syncInput.Root : input;
|
||||
}
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
if (disposing)
|
||||
{
|
||||
Input.Dispose();
|
||||
Ch.Dispose();
|
||||
}
|
||||
base.Dispose(disposing);
|
||||
_disposed = true;
|
||||
}
|
||||
|
||||
public sealed override bool MoveNext() => Root.MoveNext();
|
||||
|
||||
public sealed override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
|
||||
Ch.AssertValue(input);
|
||||
Input = input;
|
||||
// If this thing happens to be itself an instance of this class (which, practically, it will
|
||||
// be in the majority of situations), we can treat the input as likewise being a passthrough,
|
||||
// thereby saving lots of "nested" calls on the stack when doing common operations like movement.
|
||||
Root = Input is SynchronizedCursorBase syncInput ? syncInput.Root : input;
|
||||
}
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
if (disposing)
|
||||
{
|
||||
Input.Dispose();
|
||||
Ch.Dispose();
|
||||
}
|
||||
base.Dispose(disposing);
|
||||
_disposed = true;
|
||||
}
|
||||
|
||||
public sealed override bool MoveNext() => Root.MoveNext();
|
||||
|
||||
public sealed override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
|
||||
}
|
||||
|
|
|
@ -5,61 +5,60 @@
|
|||
using System;
|
||||
using Microsoft.ML.Runtime;
|
||||
|
||||
namespace Microsoft.ML.Data
|
||||
namespace Microsoft.ML.Data;
|
||||
|
||||
/// <summary>
|
||||
/// Convenient base class for <see cref="DataViewRow"/> implementors that wrap a single <see cref="DataViewRow"/>
|
||||
/// as their input. The <see cref="DataViewRow.Position"/>, <see cref="DataViewRow.Batch"/>, and <see cref="DataViewRow.GetIdGetter"/>
|
||||
/// are taken from this <see cref="Input"/>.
|
||||
/// </summary>
|
||||
[BestFriend]
|
||||
internal abstract class WrappingRow : DataViewRow
|
||||
{
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Convenient base class for <see cref="DataViewRow"/> implementors that wrap a single <see cref="DataViewRow"/>
|
||||
/// as their input. The <see cref="DataViewRow.Position"/>, <see cref="DataViewRow.Batch"/>, and <see cref="DataViewRow.GetIdGetter"/>
|
||||
/// are taken from this <see cref="Input"/>.
|
||||
/// The wrapped input row.
|
||||
/// </summary>
|
||||
protected DataViewRow Input { get; }
|
||||
|
||||
public sealed override long Batch => Input.Batch;
|
||||
public sealed override long Position => Input.Position;
|
||||
public override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
|
||||
|
||||
[BestFriend]
|
||||
internal abstract class WrappingRow : DataViewRow
|
||||
private protected WrappingRow(DataViewRow input)
|
||||
{
|
||||
private bool _disposed;
|
||||
Contracts.AssertValue(input);
|
||||
Input = input;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The wrapped input row.
|
||||
/// </summary>
|
||||
protected DataViewRow Input { get; }
|
||||
/// <summary>
|
||||
/// This override of the dispose method by default only calls <see cref="Input"/>'s
|
||||
/// <see cref="IDisposable.Dispose"/> method, but subclasses can enable additional functionality
|
||||
/// via the <see cref="DisposeCore(bool)"/> functionality.
|
||||
/// </summary>
|
||||
/// <param name="disposing"></param>
|
||||
protected sealed override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
// Since the input was created first, and this instance may depend on it, we should
|
||||
// dispose local resources first before potentially disposing the input row resources.
|
||||
DisposeCore(disposing);
|
||||
if (disposing)
|
||||
Input.Dispose();
|
||||
_disposed = true;
|
||||
}
|
||||
|
||||
public sealed override long Batch => Input.Batch;
|
||||
public sealed override long Position => Input.Position;
|
||||
public override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
|
||||
|
||||
[BestFriend]
|
||||
private protected WrappingRow(DataViewRow input)
|
||||
{
|
||||
Contracts.AssertValue(input);
|
||||
Input = input;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This override of the dispose method by default only calls <see cref="Input"/>'s
|
||||
/// <see cref="IDisposable.Dispose"/> method, but subclasses can enable additional functionality
|
||||
/// via the <see cref="DisposeCore(bool)"/> functionality.
|
||||
/// </summary>
|
||||
/// <param name="disposing"></param>
|
||||
protected sealed override void Dispose(bool disposing)
|
||||
{
|
||||
if (_disposed)
|
||||
return;
|
||||
// Since the input was created first, and this instance may depend on it, we should
|
||||
// dispose local resources first before potentially disposing the input row resources.
|
||||
DisposeCore(disposing);
|
||||
if (disposing)
|
||||
Input.Dispose();
|
||||
_disposed = true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Called from <see cref="Dispose(bool)"/> with <see langword="true"/> in the case where
|
||||
/// that method has never been called before, and right after <see cref="Input"/> has been
|
||||
/// disposed. The default implementation does nothing.
|
||||
/// </summary>
|
||||
/// <param name="disposing">Whether this was called through the dispose path, as opposed
|
||||
/// to the finalizer path.</param>
|
||||
protected virtual void DisposeCore(bool disposing)
|
||||
{
|
||||
}
|
||||
/// <summary>
|
||||
/// Called from <see cref="Dispose(bool)"/> with <see langword="true"/> in the case where
|
||||
/// that method has never been called before, and right after <see cref="Input"/> has been
|
||||
/// disposed. The default implementation does nothing.
|
||||
/// </summary>
|
||||
/// <param name="disposing">Whether this was called through the dispose path, as opposed
|
||||
/// to the finalizer path.</param>
|
||||
protected virtual void DisposeCore(bool disposing)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче