File-scoped namespaces in files under `Data` (`Microsoft.ML.Core`) (#6789)

Co-authored-by: Lehonti Ramos <john@doe>
This commit is contained in:
Lehonti Ramos 2023-08-31 06:21:45 +02:00 коммит произвёл GitHub
Родитель 34389b63e5
Коммит aaf226c7e7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
31 изменённых файлов: 5622 добавлений и 5653 удалений

Просмотреть файл

@ -4,29 +4,28 @@
using System;
namespace Microsoft.ML.Data
{
[BestFriend]
internal static class AnnotationBuilderExtensions
{
/// <summary>
/// Add slot names annotation.
/// </summary>
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the slot names.</param>
/// <param name="size">The size of the slot names vector.</param>
/// <param name="getter">The getter delegate for the slot names.</param>
public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter)
=> builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter);
namespace Microsoft.ML.Data;
/// <summary>
/// Add key values annotation.
/// </summary>
/// <typeparam name="TValue">The value type of key values.</typeparam>
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the key values.</param>
/// <param name="size">The size of key values vector.</param>
/// <param name="valueType">The value type of key values. Its raw type must match <typeparamref name="TValue"/>.</param>
/// <param name="getter">The getter delegate for the key values.</param>
public static void AddKeyValues<TValue>(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter<VBuffer<TValue>> getter)
=> builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter);
}
[BestFriend]
internal static class AnnotationBuilderExtensions
{
/// <summary>
/// Add slot names annotation.
/// </summary>
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the slot names.</param>
/// <param name="size">The size of the slot names vector.</param>
/// <param name="getter">The getter delegate for the slot names.</param>
public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter)
=> builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter);
/// <summary>
/// Add key values annotation.
/// </summary>
/// <typeparam name="TValue">The value type of key values.</typeparam>
/// <param name="builder">The <see cref="DataViewSchema.Annotations.Builder"/> to which to add the key values.</param>
/// <param name="size">The size of key values vector.</param>
/// <param name="valueType">The value type of key values. Its raw type must match <typeparamref name="TValue"/>.</param>
/// <param name="getter">The getter delegate for the key values.</param>
public static void AddKeyValues<TValue>(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter<VBuffer<TValue>> getter)
=> builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter);
}

Просмотреть файл

@ -9,493 +9,492 @@ using System.Threading;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Utilities for implementing and using the annotation API of <see cref="DataViewSchema"/>.
/// </summary>
[BestFriend]
internal static class AnnotationUtils
{
/// <summary>
/// Utilities for implementing and using the annotation API of <see cref="DataViewSchema"/>.
/// This class lists the canonical annotation kinds
/// </summary>
[BestFriend]
internal static class AnnotationUtils
public static class Kinds
{
/// <summary>
/// This class lists the canonical annotation kinds
/// Annotation kind for names associated with slots/positions in a vector-valued column.
/// The associated annotation type is typically fixed-sized vector of Text.
/// </summary>
public static class Kinds
{
/// <summary>
/// Annotation kind for names associated with slots/positions in a vector-valued column.
/// The associated annotation type is typically fixed-sized vector of Text.
/// </summary>
public const string SlotNames = "SlotNames";
/// <summary>
/// Annotation kind for values associated with the key indices when the column type's item type
/// is a key type. The associated annotation type is typically fixed-sized vector of a primitive
/// type. The primitive type is frequently Text, but can be anything.
/// </summary>
public const string KeyValues = "KeyValues";
/// <summary>
/// Annotation kind for sets of score columns. The value is typically a <see cref="KeyDataViewType"/> with raw type U4.
/// </summary>
public const string ScoreColumnSetId = "ScoreColumnSetId";
/// <summary>
/// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification".
/// The value is typically a ReadOnlyMemory&lt;char&gt;.
/// </summary>
public const string ScoreColumnKind = "ScoreColumnKind";
/// <summary>
/// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory.
/// </summary>
public const string ScoreValueKind = "ScoreValueKind";
/// <summary>
/// Annotation kind that indicates if a column is normalized. The value is typically a Bool.
/// </summary>
public const string IsNormalized = "IsNormalized";
/// <summary>
/// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool.
/// Not to be confused with IsHidden() that determines if a column is masked.
/// </summary>
public const string IsUserVisible = "IsUserVisible";
/// <summary>
/// Annotation kind for the label values used in training to be used for the predicted label.
/// The value is typically a fixed-sized vector of Text.
/// </summary>
public const string TrainingLabelValues = "TrainingLabelValues";
/// <summary>
/// Annotation kind that indicates the ranges within a column that are categorical features.
/// The value is a vector type of ints with dimension of two. The first dimension
/// represents the number of categorical features and second dimension represents the range
/// and is of size two. The range has start and end index(both inclusive) of categorical
/// slots within that column.
/// </summary>
public const string CategoricalSlotRanges = "CategoricalSlotRanges";
}
public const string SlotNames = "SlotNames";
/// <summary>
/// This class holds all pre-defined string values that can be found in canonical annotations
/// Annotation kind for values associated with the key indices when the column type's item type
/// is a key type. The associated annotation type is typically fixed-sized vector of a primitive
/// type. The primitive type is frequently Text, but can be anything.
/// </summary>
public static class Const
{
public static class ScoreColumnKind
{
public const string BinaryClassification = "BinaryClassification";
public const string MulticlassClassification = "MulticlassClassification";
public const string Regression = "Regression";
public const string Ranking = "Ranking";
public const string Clustering = "Clustering";
public const string MultiOutputRegression = "MultiOutputRegression";
public const string AnomalyDetection = "AnomalyDetection";
public const string SequenceClassification = "SequenceClassification";
public const string QuantileRegression = "QuantileRegression";
public const string Recommender = "Recommender";
public const string ItemSimilarity = "ItemSimilarity";
public const string FeatureContribution = "FeatureContribution";
}
public static class ScoreValueKind
{
public const string Score = "Score";
public const string PredictedLabel = "PredictedLabel";
public const string Probability = "Probability";
}
}
public const string KeyValues = "KeyValues";
/// <summary>
/// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below.
/// Annotation kind for sets of score columns. The value is typically a <see cref="KeyDataViewType"/> with raw type U4.
/// </summary>
public delegate void AnnotationGetter<TValue>(int col, ref TValue dst);
public const string ScoreColumnSetId = "ScoreColumnSetId";
/// <summary>
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
/// Annotation kind that indicates the prediction kind as a string. For example, "BinaryClassification".
/// The value is typically a ReadOnlyMemory&lt;char&gt;.
/// </summary>
public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation");
public const string ScoreColumnKind = "ScoreColumnKind";
/// <summary>
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
/// Annotation kind that indicates the value kind of the score column as a string. For example, "Score", "PredictedLabel", "Probability". The value is typically a ReadOnlyMemory.
/// </summary>
public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation");
public const string ScoreValueKind = "ScoreValueKind";
/// <summary>
/// Helper to marshal a call to GetAnnotation{TValue} to a specific type.
/// Annotation kind that indicates if a column is normalized. The value is typically a Bool.
/// </summary>
public static void Marshal<THave, TNeed>(this AnnotationGetter<THave> getter, int col, ref TNeed dst)
{
Contracts.CheckValue(getter, nameof(getter));
if (typeof(TNeed) != typeof(THave))
throw ExceptGetAnnotation();
var get = (AnnotationGetter<TNeed>)(Delegate)getter;
get(col, ref dst);
}
public const string IsNormalized = "IsNormalized";
/// <summary>
/// Returns a vector type with item type text and the given size. The size must be positive.
/// This is a standard type for annotation consisting of multiple text values, eg SlotNames.
/// Annotation kind that indicates if a column is visible to the users. The value is typically a Bool.
/// Not to be confused with IsHidden() that determines if a column is masked.
/// </summary>
public static VectorDataViewType GetNamesType(int size)
{
Contracts.CheckParam(size > 0, nameof(size), "must be known size");
return new VectorDataViewType(TextDataViewType.Instance, size);
}
public const string IsUserVisible = "IsUserVisible";
/// <summary>
/// Returns a vector type with item type int and the given size.
/// The range count must be a positive integer.
/// This is a standard type for annotation consisting of multiple int values that represent
/// categorical slot ranges with in a column.
/// Annotation kind for the label values used in training to be used for the predicted label.
/// The value is typically a fixed-sized vector of Text.
/// </summary>
public static VectorDataViewType GetCategoricalType(int rangeCount)
{
Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2);
}
private static volatile KeyDataViewType _scoreColumnSetIdType;
public const string TrainingLabelValues = "TrainingLabelValues";
/// <summary>
/// The type of the ScoreColumnSetId annotation.
/// Annotation kind that indicates the ranges within a column that are categorical features.
/// The value is a vector type of ints with dimension of two. The first dimension
/// represents the number of categorical features and second dimension represents the range
/// and is of size two. The range has start and end index(both inclusive) of categorical
/// slots within that column.
/// </summary>
public static KeyDataViewType ScoreColumnSetIdType
public const string CategoricalSlotRanges = "CategoricalSlotRanges";
}
/// <summary>
/// This class holds all pre-defined string values that can be found in canonical annotations
/// </summary>
public static class Const
{
public static class ScoreColumnKind
{
get
{
return _scoreColumnSetIdType ??
Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ??
_scoreColumnSetIdType;
}
public const string BinaryClassification = "BinaryClassification";
public const string MulticlassClassification = "MulticlassClassification";
public const string Regression = "Regression";
public const string Ranking = "Ranking";
public const string Clustering = "Clustering";
public const string MultiOutputRegression = "MultiOutputRegression";
public const string AnomalyDetection = "AnomalyDetection";
public const string SequenceClassification = "SequenceClassification";
public const string QuantileRegression = "QuantileRegression";
public const string Recommender = "Recommender";
public const string ItemSimilarity = "ItemSimilarity";
public const string FeatureContribution = "FeatureContribution";
}
/// <summary>
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col).
/// </summary>
public static KeyValuePair<string, DataViewType> GetSlotNamesPair(int size)
public static class ScoreValueKind
{
return GetNamesType(size).GetPair(Kinds.SlotNames);
}
/// <summary>
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes
/// that the values of the key type are Text.
/// </summary>
public static KeyValuePair<string, DataViewType> GetKeyNamesPair(int size)
{
return GetNamesType(size).GetPair(Kinds.KeyValues);
}
/// <summary>
/// Given a type and annotation kind string, returns a key-value pair. This is useful when
/// implementing GetAnnotationTypes(col).
/// </summary>
public static KeyValuePair<string, DataViewType> GetPair(this DataViewType type, string kind)
{
Contracts.CheckValue(type, nameof(type));
return new KeyValuePair<string, DataViewType>(kind, type);
}
// REVIEW: This should be in some general utility code.
/// <summary>
/// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes.
/// </summary>
public static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] head)
{
return head.Concat(tail);
}
/// <summary>
/// Returns the max value for the specified annotation kind.
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
/// colMax will be set to the first column that has the max value for the specified annotation.
/// If no column has the specified annotation, colMax is set to -1 and the method returns zero.
/// The filter function is called for each column, passing in the schema and the column index, and returns
/// true if the column should be considered, false if the column should be skipped.
/// </summary>
public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func<DataViewSchema, int, bool> filterFunc = null)
{
uint max = 0;
colMax = -1;
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint))
continue;
if (filterFunc != null && !filterFunc(schema, col))
continue;
uint value = 0;
schema[col].Annotations.GetValue(annotationKind, ref value);
if (max < value)
{
max = value;
colMax = col;
}
}
return max;
}
/// <summary>
/// Returns the set of column ids which match the value of specified annotation kind.
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
/// </summary>
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, uint value)
{
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
if (columnType is KeyDataViewType && columnType.RawType == typeof(uint))
{
uint val = 0;
schema[col].Annotations.GetValue(annotationKind, ref val);
if (val == value)
yield return col;
}
}
}
/// <summary>
/// Returns the set of column ids which match the value of specified annotation kind.
/// The annotation type should be of type text.
/// </summary>
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, string value)
{
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
if (columnType is TextDataViewType)
{
ReadOnlyMemory<char> val = default;
schema[col].Annotations.GetValue(annotationKind, ref val);
if (ReadOnlyMemoryUtils.EqualsStr(value, val))
yield return col;
}
}
}
/// <summary>
/// Returns <c>true</c> if the specified column:
/// * has a SlotNames annotation
/// * annotation type is VBuffer&lt;ReadOnlyMemory&lt;char&gt;&gt; of length <paramref name="vectorSize"/>.
/// </summary>
public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
{
if (vectorSize == 0)
return false;
var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames);
return
metaColumn != null
&& metaColumn.Value.Type is VectorDataViewType vectorType
&& vectorType.Size == vectorSize
&& vectorType.ItemType is TextDataViewType;
}
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer<ReadOnlyMemory<char>> slotNames)
{
Contracts.CheckValueOrNull(schema);
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
IReadOnlyList<DataViewSchema.Column> list = schema?.GetColumns(role);
if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
else
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
}
public static bool NeedsSlotNames(this SchemaShape.Column col)
{
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
&& metaCol.ItemType is TextDataViewType;
}
/// <summary>
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> annotation indicated by
/// the schema shape.
/// </summary>
/// <param name="column">The schema shape column to query</param>
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> annotation
/// of a scalar <see cref="BooleanDataViewType"/> type, which we assume, if set, should be <c>true</c>.</returns>
public static bool IsNormalized(this SchemaShape.Column column)
{
Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly");
return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
&& metaCol.ItemType == BooleanDataViewType.Instance;
}
/// <summary>
/// Returns whether a column has the <see cref="Kinds.SlotNames"/> annotation indicated by
/// the schema shape.
/// </summary>
/// <param name="col">The schema shape column to query</param>
/// <returns>True if and only if the column is a definite sized vector type, has the
/// <see cref="Kinds.SlotNames"/> annotation of definite sized vectors of text.</returns>
public static bool HasSlotNames(this SchemaShape.Column col)
{
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Kind == SchemaShape.Column.VectorKind.Vector
&& col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
&& metaCol.ItemType == TextDataViewType.Instance;
}
/// <summary>
/// Tries to get the annotation kind of the specified type for a column.
/// </summary>
/// <typeparam name="T">The raw type of the annotation, should match the PrimitiveType type</typeparam>
/// <param name="schema">The schema</param>
/// <param name="type">The type of the annotation</param>
/// <param name="kind">The annotation kind</param>
/// <param name="col">The column</param>
/// <param name="value">The value to return, if successful</param>
/// <returns>True if the annotation of the right type exists, false otherwise</returns>
public static bool TryGetAnnotation<T>(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckValue(type, nameof(type));
var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
if (!type.Equals(annotationType))
return false;
schema[col].Annotations.GetValue(kind, ref value);
return true;
}
/// <summary>
/// The categoricalFeatures is a vector of the indices of categorical features slots.
/// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
/// So if its value is the range of numbers: 0,2,3,4,8,9
/// look at it as [0,2],[3,4],[8,9].
/// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
/// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
/// </summary>
public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(colIndex >= 0, nameof(colIndex));
bool isValid = false;
categoricalFeatures = null;
if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0))
return isValid;
var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
if (type?.RawType == typeof(VBuffer<int>))
{
VBuffer<int> catIndices = default(VBuffer<int>);
schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
VBufferUtils.Densify(ref catIndices);
int columnSlotsCount = vecType.Size;
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
{
int previousEndIndex = -1;
isValid = true;
var catIndicesValues = catIndices.GetValues();
for (int i = 0; i < catIndicesValues.Length; i += 2)
{
if (catIndicesValues[i] > catIndicesValues[i + 1] ||
catIndicesValues[i] <= previousEndIndex ||
catIndicesValues[i] >= columnSlotsCount ||
catIndicesValues[i + 1] >= columnSlotsCount)
{
isValid = false;
break;
}
previousEndIndex = catIndicesValues[i + 1];
}
if (isValid)
categoricalFeatures = catIndicesValues.ToArray();
}
}
return isValid;
}
/// <summary>
/// Produces sequence of columns that are generated by trainer estimators.
/// </summary>
/// <param name="isNormalized">whether we should also append 'IsNormalized' (typically for probability column)</param>
public static IEnumerable<SchemaShape.Column> GetTrainerOutputAnnotation(bool isNormalized = false)
{
var cols = new List<SchemaShape.Column>();
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
if (isNormalized)
cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
return cols;
}
/// <summary>
/// Produces annotations for the score column generated by trainer estimators for multiclass classification.
/// If input LabelColumn is not available it produces slotnames annotation by default.
/// </summary>
/// <param name="labelColumn">Label column.</param>
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn != null && labelColumn.Value.IsKey)
{
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
{
if (metaCol.ItemType is TextDataViewType)
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false));
}
}
cols.AddRange(GetTrainerOutputAnnotation());
return cols;
}
private sealed class AnnotationRow : DataViewRow
{
private readonly DataViewSchema.Annotations _annotations;
public AnnotationRow(DataViewSchema.Annotations annotations)
{
Contracts.AssertValue(annotations);
_annotations = annotations;
}
public override DataViewSchema Schema => _annotations.Schema;
public override long Position => 0;
public override long Batch => 0;
/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => _annotations.GetGetter<TValue>(column);
public override ValueGetter<DataViewRowId> GetIdGetter() => (ref DataViewRowId dst) => dst = default;
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column) => true;
}
/// <summary>
/// Presents a <see cref="DataViewSchema.Annotations"/> as a an <see cref="DataViewRow"/>.
/// </summary>
/// <param name="annotations">The annotations to wrap.</param>
/// <returns>A row that wraps an input annotations.</returns>
[BestFriend]
internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations)
{
Contracts.CheckValue(annotations, nameof(annotations));
return new AnnotationRow(annotations);
public const string Score = "Score";
public const string PredictedLabel = "PredictedLabel";
public const string Probability = "Probability";
}
}
/// <summary>
/// Helper delegate for marshaling from generic land to specific types. Used by the Marshal method below.
/// </summary>
public delegate void AnnotationGetter<TValue>(int col, ref TValue dst);
/// <summary>
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
/// </summary>
public static Exception ExceptGetAnnotation() => Contracts.Except("Invalid call to GetAnnotation");
/// <summary>
/// Returns a standard exception for responding to an invalid call to GetAnnotation.
/// </summary>
public static Exception ExceptGetAnnotation(this IExceptionContext ctx) => ctx.Except("Invalid call to GetAnnotation");
/// <summary>
/// Helper to marshal a call to GetAnnotation{TValue} to a specific type.
/// </summary>
public static void Marshal<THave, TNeed>(this AnnotationGetter<THave> getter, int col, ref TNeed dst)
{
Contracts.CheckValue(getter, nameof(getter));
if (typeof(TNeed) != typeof(THave))
throw ExceptGetAnnotation();
var get = (AnnotationGetter<TNeed>)(Delegate)getter;
get(col, ref dst);
}
/// <summary>
/// Returns a vector type with item type text and the given size. The size must be positive.
/// This is a standard type for annotation consisting of multiple text values, eg SlotNames.
/// </summary>
public static VectorDataViewType GetNamesType(int size)
{
Contracts.CheckParam(size > 0, nameof(size), "must be known size");
return new VectorDataViewType(TextDataViewType.Instance, size);
}
/// <summary>
/// Returns a vector type with item type int and the given size.
/// The range count must be a positive integer.
/// This is a standard type for annotation consisting of multiple int values that represent
/// categorical slot ranges with in a column.
/// </summary>
public static VectorDataViewType GetCategoricalType(int rangeCount)
{
Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2);
}
private static volatile KeyDataViewType _scoreColumnSetIdType;
/// <summary>
/// The type of the ScoreColumnSetId annotation.
/// </summary>
public static KeyDataViewType ScoreColumnSetIdType
{
get
{
return _scoreColumnSetIdType ??
Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ??
_scoreColumnSetIdType;
}
}
/// <summary>
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col).
/// </summary>
public static KeyValuePair<string, DataViewType> GetSlotNamesPair(int size)
{
return GetNamesType(size).GetPair(Kinds.SlotNames);
}
/// <summary>
/// Returns a key-value pair useful when implementing GetAnnotationTypes(col). This assumes
/// that the values of the key type are Text.
/// </summary>
public static KeyValuePair<string, DataViewType> GetKeyNamesPair(int size)
{
return GetNamesType(size).GetPair(Kinds.KeyValues);
}
/// <summary>
/// Given a type and annotation kind string, returns a key-value pair. This is useful when
/// implementing GetAnnotationTypes(col).
/// </summary>
public static KeyValuePair<string, DataViewType> GetPair(this DataViewType type, string kind)
{
Contracts.CheckValue(type, nameof(type));
return new KeyValuePair<string, DataViewType>(kind, type);
}
// REVIEW: This should be in some general utility code.
/// <summary>
/// Prepends a params array to an enumerable. Useful when implementing GetAnnotationTypes.
/// </summary>
public static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] head)
{
return head.Concat(tail);
}
/// <summary>
/// Returns the max value for the specified annotation kind.
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
/// colMax will be set to the first column that has the max value for the specified annotation.
/// If no column has the specified annotation, colMax is set to -1 and the method returns zero.
/// The filter function is called for each column, passing in the schema and the column index, and returns
/// true if the column should be considered, false if the column should be skipped.
/// </summary>
public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colMax, string annotationKind, Func<DataViewSchema, int, bool> filterFunc = null)
{
uint max = 0;
colMax = -1;
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint))
continue;
if (filterFunc != null && !filterFunc(schema, col))
continue;
uint value = 0;
schema[col].Annotations.GetValue(annotationKind, ref value);
if (max < value)
{
max = value;
colMax = col;
}
}
return max;
}
/// <summary>
/// Returns the set of column ids which match the value of specified annotation kind.
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
/// </summary>
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, uint value)
{
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
if (columnType is KeyDataViewType && columnType.RawType == typeof(uint))
{
uint val = 0;
schema[col].Annotations.GetValue(annotationKind, ref val);
if (val == value)
yield return col;
}
}
}
/// <summary>
/// Returns the set of column ids which match the value of specified annotation kind.
/// The annotation type should be of type text.
/// </summary>
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, string value)
{
for (int col = 0; col < schema.Count; col++)
{
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
if (columnType is TextDataViewType)
{
ReadOnlyMemory<char> val = default;
schema[col].Annotations.GetValue(annotationKind, ref val);
if (ReadOnlyMemoryUtils.EqualsStr(value, val))
yield return col;
}
}
}
/// <summary>
/// Returns <c>true</c> if the specified column:
/// * has a SlotNames annotation
/// * annotation type is VBuffer&lt;ReadOnlyMemory&lt;char&gt;&gt; of length <paramref name="vectorSize"/>.
/// </summary>
public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
{
if (vectorSize == 0)
return false;
var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames);
return
metaColumn != null
&& metaColumn.Value.Type is VectorDataViewType vectorType
&& vectorType.Size == vectorSize
&& vectorType.ItemType is TextDataViewType;
}
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer<ReadOnlyMemory<char>> slotNames)
{
Contracts.CheckValueOrNull(schema);
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
IReadOnlyList<DataViewSchema.Column> list = schema?.GetColumns(role);
if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize))
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
else
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
}
public static bool NeedsSlotNames(this SchemaShape.Column col)
{
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
&& metaCol.ItemType is TextDataViewType;
}
/// <summary>
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> annotation indicated by
/// the schema shape.
/// </summary>
/// <param name="column">The schema shape column to query</param>
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> annotation
/// of a scalar <see cref="BooleanDataViewType"/> type, which we assume, if set, should be <c>true</c>.</returns>
public static bool IsNormalized(this SchemaShape.Column column)
{
Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly");
return column.Annotations.TryFindColumn(Kinds.IsNormalized, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
&& metaCol.ItemType == BooleanDataViewType.Instance;
}
/// <summary>
/// Returns whether a column has the <see cref="Kinds.SlotNames"/> annotation indicated by
/// the schema shape.
/// </summary>
/// <param name="col">The schema shape column to query</param>
/// <returns>True if and only if the column is a definite sized vector type, has the
/// <see cref="Kinds.SlotNames"/> annotation of definite sized vectors of text.</returns>
public static bool HasSlotNames(this SchemaShape.Column col)
{
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Kind == SchemaShape.Column.VectorKind.Vector
&& col.Annotations.TryFindColumn(Kinds.SlotNames, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
&& metaCol.ItemType == TextDataViewType.Instance;
}
/// <summary>
/// Tries to get the annotation kind of the specified type for a column.
/// </summary>
/// <typeparam name="T">The raw type of the annotation, should match the PrimitiveType type</typeparam>
/// <param name="schema">The schema</param>
/// <param name="type">The type of the annotation</param>
/// <param name="kind">The annotation kind</param>
/// <param name="col">The column</param>
/// <param name="value">The value to return, if successful</param>
/// <returns>True if the annotation of the right type exists, false otherwise</returns>
public static bool TryGetAnnotation<T>(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckValue(type, nameof(type));
var annotationType = schema[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
if (!type.Equals(annotationType))
return false;
schema[col].Annotations.GetValue(kind, ref value);
return true;
}
/// <summary>
/// The categoricalFeatures is a vector of the indices of categorical features slots.
/// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
/// So if its value is the range of numbers: 0,2,3,4,8,9
/// look at it as [0,2],[3,4],[8,9].
/// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
/// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
/// </summary>
public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(colIndex >= 0, nameof(colIndex));
bool isValid = false;
categoricalFeatures = null;
if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0))
return isValid;
var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
if (type?.RawType == typeof(VBuffer<int>))
{
VBuffer<int> catIndices = default(VBuffer<int>);
schema[colIndex].Annotations.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
VBufferUtils.Densify(ref catIndices);
int columnSlotsCount = vecType.Size;
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
{
int previousEndIndex = -1;
isValid = true;
var catIndicesValues = catIndices.GetValues();
for (int i = 0; i < catIndicesValues.Length; i += 2)
{
if (catIndicesValues[i] > catIndicesValues[i + 1] ||
catIndicesValues[i] <= previousEndIndex ||
catIndicesValues[i] >= columnSlotsCount ||
catIndicesValues[i + 1] >= columnSlotsCount)
{
isValid = false;
break;
}
previousEndIndex = catIndicesValues[i + 1];
}
if (isValid)
categoricalFeatures = catIndicesValues.ToArray();
}
}
return isValid;
}
/// <summary>
/// Produces sequence of columns that are generated by trainer estimators.
/// </summary>
/// <param name="isNormalized">whether we should also append 'IsNormalized' (typically for probability column)</param>
public static IEnumerable<SchemaShape.Column> GetTrainerOutputAnnotation(bool isNormalized = false)
{
var cols = new List<SchemaShape.Column>();
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
if (isNormalized)
cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
return cols;
}
/// <summary>
/// Produces annotations for the score column generated by trainer estimators for multiclass classification.
/// If input LabelColumn is not available it produces slotnames annotation by default.
/// </summary>
/// <param name="labelColumn">Label column.</param>
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn != null && labelColumn.Value.IsKey)
{
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
{
if (metaCol.ItemType is TextDataViewType)
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false));
}
}
cols.AddRange(GetTrainerOutputAnnotation());
return cols;
}
private sealed class AnnotationRow : DataViewRow
{
private readonly DataViewSchema.Annotations _annotations;
public AnnotationRow(DataViewSchema.Annotations annotations)
{
Contracts.AssertValue(annotations);
_annotations = annotations;
}
public override DataViewSchema Schema => _annotations.Schema;
public override long Position => 0;
public override long Batch => 0;
/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => _annotations.GetGetter<TValue>(column);
public override ValueGetter<DataViewRowId> GetIdGetter() => (ref DataViewRowId dst) => dst = default;
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column) => true;
}
/// <summary>
/// Presents a <see cref="DataViewSchema.Annotations"/> as a an <see cref="DataViewRow"/>.
/// </summary>
/// <param name="annotations">The annotations to wrap.</param>
/// <returns>A row that wraps an input annotations.</returns>
[BestFriend]
internal static DataViewRow AnnotationsAsRow(DataViewSchema.Annotations annotations)
{
Contracts.CheckValue(annotations, nameof(annotations));
return new AnnotationRow(annotations);
}
}

Просмотреть файл

@ -5,165 +5,164 @@
using System;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Extension methods related to the ColumnType class.
/// </summary>
[BestFriend]
internal static class ColumnTypeExtensions
{
/// <summary>
/// Extension methods related to the ColumnType class.
/// Whether this type is a standard scalar type completely determined by its <see cref="DataViewType.RawType"/>
/// (not a <see cref="KeyDataViewType"/> or <see cref="StructuredDataViewType"/>, etc).
/// </summary>
[BestFriend]
internal static class ColumnTypeExtensions
public static bool IsStandardScalar(this DataViewType columnType) =>
(columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) ||
(columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) ||
(columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType);
/// <summary>
/// Zero return means it's not a key type.
/// </summary>
public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0;
/// <summary>
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
/// Zero return means it's not a key type.
/// </summary>
public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
{
/// <summary>
/// Whether this type is a standard scalar type completely determined by its <see cref="DataViewType.RawType"/>
/// (not a <see cref="KeyDataViewType"/> or <see cref="StructuredDataViewType"/>, etc).
/// </summary>
public static bool IsStandardScalar(this DataViewType columnType) =>
(columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) ||
(columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) ||
(columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType);
ulong count = columnType.GetKeyCount();
ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
return (int)count;
}
/// <summary>
/// Zero return means it's not a key type.
/// </summary>
public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0;
/// <summary>
/// For non-vector types, this returns the column type itself (i.e., return <paramref name="columnType"/>).
/// For vector types, this returns the type of the items stored as values in vector.
/// </summary>
public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType;
/// <summary>
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
/// Zero return means it's not a key type.
/// </summary>
public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
{
ulong count = columnType.GetKeyCount();
ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
return (int)count;
}
/// <summary>
/// Zero return means either it's not a vector or the size is unknown.
/// </summary>
public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0;
/// <summary>
/// For non-vector types, this returns the column type itself (i.e., return <paramref name="columnType"/>).
/// For vector types, this returns the type of the items stored as values in vector.
/// </summary>
public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType;
/// <summary>
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
/// For known sized vectors, it returns size.
/// </summary>
public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1;
/// <summary>
/// Zero return means either it's not a vector or the size is unknown.
/// </summary>
public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0;
/// <summary>
/// Whether this is a vector type with known size. Returns false for non-vector types.
/// Equivalent to <c><see cref="GetVectorSize"/> &gt; 0</c>.
/// </summary>
public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0;
/// <summary>
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
/// For known sized vectors, it returns size.
/// </summary>
public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1;
/// <summary>
/// Gets the equivalent <see cref="InternalDataKind"/> for the <paramref name="columnType"/>'s RawType.
/// This can return default(<see cref="InternalDataKind"/>) if the RawType doesn't have a corresponding
/// <see cref="InternalDataKind"/>.
/// </summary>
public static InternalDataKind GetRawKind(this DataViewType columnType)
{
columnType.RawType.TryGetDataKind(out InternalDataKind result);
return result;
}
/// <summary>
/// Whether this is a vector type with known size. Returns false for non-vector types.
/// Equivalent to <c><see cref="GetVectorSize"/> &gt; 0</c>.
/// </summary>
public static bool IsKnownSizeVector(this DataViewType columnType) => columnType.GetVectorSize() > 0;
/// <summary>
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
/// returns true if current and other vector types have the same size and item type.
/// </summary>
public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
{
if (other == null)
return false;
/// <summary>
/// Gets the equivalent <see cref="InternalDataKind"/> for the <paramref name="columnType"/>'s RawType.
/// This can return default(<see cref="InternalDataKind"/>) if the RawType doesn't have a corresponding
/// <see cref="InternalDataKind"/>.
/// </summary>
public static InternalDataKind GetRawKind(this DataViewType columnType)
{
columnType.RawType.TryGetDataKind(out InternalDataKind result);
return result;
}
if (columnType.Equals(other))
return true;
/// <summary>
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
/// returns true if current and other vector types have the same size and item type.
/// </summary>
public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
{
if (other == null)
return false;
// For vector types, we don't care about the factoring of the dimensions.
if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
return false;
if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
return false;
return vectorType.Size == otherVectorType.Size;
}
if (columnType.Equals(other))
return true;
public static PrimitiveDataViewType PrimitiveTypeFromType(Type type)
{
if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
return TextDataViewType.Instance;
if (type == typeof(bool))
return BooleanDataViewType.Instance;
if (type == typeof(TimeSpan))
return TimeSpanDataViewType.Instance;
if (type == typeof(DateTime))
return DateTimeDataViewType.Instance;
if (type == typeof(DateTimeOffset))
return DateTimeOffsetDataViewType.Instance;
if (type == typeof(DataViewRowId))
return RowIdDataViewType.Instance;
return NumberTypeFromType(type);
}
// For vector types, we don't care about the factoring of the dimensions.
if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
return false;
if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
return false;
return vectorType.Size == otherVectorType.Size;
}
public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind)
{
if (kind == InternalDataKind.TX)
return TextDataViewType.Instance;
if (kind == InternalDataKind.BL)
return BooleanDataViewType.Instance;
if (kind == InternalDataKind.TS)
return TimeSpanDataViewType.Instance;
if (kind == InternalDataKind.DT)
return DateTimeDataViewType.Instance;
if (kind == InternalDataKind.DZ)
return DateTimeOffsetDataViewType.Instance;
if (kind == InternalDataKind.UG)
return RowIdDataViewType.Instance;
return NumberTypeFromKind(kind);
}
public static PrimitiveDataViewType PrimitiveTypeFromType(Type type)
{
if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
return TextDataViewType.Instance;
if (type == typeof(bool))
return BooleanDataViewType.Instance;
if (type == typeof(TimeSpan))
return TimeSpanDataViewType.Instance;
if (type == typeof(DateTime))
return DateTimeDataViewType.Instance;
if (type == typeof(DateTimeOffset))
return DateTimeOffsetDataViewType.Instance;
if (type == typeof(DataViewRowId))
return RowIdDataViewType.Instance;
return NumberTypeFromType(type);
}
public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind)
{
if (kind == InternalDataKind.TX)
return TextDataViewType.Instance;
if (kind == InternalDataKind.BL)
return BooleanDataViewType.Instance;
if (kind == InternalDataKind.TS)
return TimeSpanDataViewType.Instance;
if (kind == InternalDataKind.DT)
return DateTimeDataViewType.Instance;
if (kind == InternalDataKind.DZ)
return DateTimeOffsetDataViewType.Instance;
if (kind == InternalDataKind.UG)
return RowIdDataViewType.Instance;
public static NumberDataViewType NumberTypeFromType(Type type)
{
InternalDataKind kind;
if (type.TryGetDataKind(out kind))
return NumberTypeFromKind(kind);
}
public static NumberDataViewType NumberTypeFromType(Type type)
Contracts.Assert(false);
throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}");
}
private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind)
{
switch (kind)
{
InternalDataKind kind;
if (type.TryGetDataKind(out kind))
return NumberTypeFromKind(kind);
Contracts.Assert(false);
throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}");
case InternalDataKind.I1:
return NumberDataViewType.SByte;
case InternalDataKind.U1:
return NumberDataViewType.Byte;
case InternalDataKind.I2:
return NumberDataViewType.Int16;
case InternalDataKind.U2:
return NumberDataViewType.UInt16;
case InternalDataKind.I4:
return NumberDataViewType.Int32;
case InternalDataKind.U4:
return NumberDataViewType.UInt32;
case InternalDataKind.I8:
return NumberDataViewType.Int64;
case InternalDataKind.U8:
return NumberDataViewType.UInt64;
case InternalDataKind.R4:
return NumberDataViewType.Single;
case InternalDataKind.R8:
return NumberDataViewType.Double;
}
private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind)
{
switch (kind)
{
case InternalDataKind.I1:
return NumberDataViewType.SByte;
case InternalDataKind.U1:
return NumberDataViewType.Byte;
case InternalDataKind.I2:
return NumberDataViewType.Int16;
case InternalDataKind.U2:
return NumberDataViewType.UInt16;
case InternalDataKind.I4:
return NumberDataViewType.Int32;
case InternalDataKind.U4:
return NumberDataViewType.UInt32;
case InternalDataKind.I8:
return NumberDataViewType.Int64;
case InternalDataKind.U8:
return NumberDataViewType.UInt64;
case InternalDataKind.R4:
return NumberDataViewType.Single;
case InternalDataKind.R8:
return NumberDataViewType.Double;
}
Contracts.Assert(false);
throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}");
}
Contracts.Assert(false);
throw new InvalidOperationException($"Bad data kind in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromKind)}: {kind}");
}
}

Просмотреть файл

@ -4,380 +4,379 @@
using System;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Specifies a simple data type.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// Some transforms use the default value and/or missing value of the data types.
/// The table below shows the default value definition for each of the data types.
///
/// | Type | Default Value | IsDefault Indicator |
/// | -- | -- | -- |
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory<char>` | <xref:System.ReadOnlyMemory`1.IsEmpty*> |
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` |
/// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value |
///
/// The table below shows the missing value definition for each of the data types.
///
/// | Type | Missing Value | IsMissing Indicator |
/// | -- | -- | -- |
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` |
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` |
/// | <xref:Microsoft.ML.Data.DataKind.Single> | <xref:System.Single.NaN> | <xref:System.Single.IsNaN(System.Single)> |
/// | <xref:Microsoft.ML.Data.DataKind.Double> | <xref:System.Double.NaN> | <xref:System.Double.IsNaN(System.Double)> |
/// | All other types | Not defined | Always `false` |
///
/// ]]>
/// </format>
/// </remarks>
// Data type specifiers mainly used in creating text loader and type converter.
public enum DataKind : byte
{
/// <summary>1-byte integer, type of <see cref="System.SByte"/>.</summary>
SByte = 1,
/// <summary>1-byte unsigned integer, type of <see cref="System.Byte"/>.</summary>
Byte = 2,
/// <summary>2-byte integer, type of <see cref="System.Int16"/>.</summary>
Int16 = 3,
/// <summary>2-byte unsigned integer, type of <see cref="System.UInt16"/>.</summary>
UInt16 = 4,
/// <summary>4-byte integer, type of <see cref="System.Int32"/>.</summary>
Int32 = 5,
/// <summary>4-byte unsigned integer, type of <see cref="System.UInt32"/>.</summary>
UInt32 = 6,
/// <summary>8-byte integer, type of <see cref="System.Int64"/>.</summary>
Int64 = 7,
/// <summary>8-byte unsigned integer, type of <see cref="System.UInt64"/>.</summary>
UInt64 = 8,
/// <summary>4-byte floating-point number, type of <see cref="System.Single"/>.</summary>
Single = 9,
/// <summary>8-byte floating-point number, type of <see cref="System.Double"/>.</summary>
Double = 10,
/// <summary>
/// Specifies a simple data type.
/// string, type of <see cref="System.ReadOnlyMemory{T}"/>, where T is <see cref="char"/>.
/// Also compatible with <see cref="System.String"/>.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// Some transforms use the default value and/or missing value of the data types.
/// The table below shows the default value definition for each of the data types.
///
/// | Type | Default Value | IsDefault Indicator |
/// | -- | -- | -- |
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Empty or `null` string (both result in empty `System.ReadOnlyMemory<char>` | <xref:System.ReadOnlyMemory`1.IsEmpty*> |
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | Not defined | Always `false` |
/// | All other types | Default value of the corresponding system type as defined by .NET standard. In C#, default value expression `default(T)` provides that value. | Equality test with the default value |
///
/// The table below shows the missing value definition for each of the data types.
///
/// | Type | Missing Value | IsMissing Indicator |
/// | -- | -- | -- |
/// | <xref:Microsoft.ML.Data.DataKind.String> or [text](xref:Microsoft.ML.Data.TextDataViewType) | Not defined | Always `false` |
/// | [Key](xref:Microsoft.ML.Data.KeyDataViewType) type (supported by the unsigned integer types in `DataKind`) | `0` | Equality test with `0` |
/// | <xref:Microsoft.ML.Data.DataKind.Single> | <xref:System.Single.NaN> | <xref:System.Single.IsNaN(System.Single)> |
/// | <xref:Microsoft.ML.Data.DataKind.Double> | <xref:System.Double.NaN> | <xref:System.Double.IsNaN(System.Double)> |
/// | All other types | Not defined | Always `false` |
///
/// ]]>
/// </format>
/// </remarks>
// Data type specifiers mainly used in creating text loader and type converter.
public enum DataKind : byte
{
/// <summary>1-byte integer, type of <see cref="System.SByte"/>.</summary>
SByte = 1,
/// <summary>1-byte unsigned integer, type of <see cref="System.Byte"/>.</summary>
Byte = 2,
/// <summary>2-byte integer, type of <see cref="System.Int16"/>.</summary>
Int16 = 3,
/// <summary>2-byte unsigned integer, type of <see cref="System.UInt16"/>.</summary>
UInt16 = 4,
/// <summary>4-byte integer, type of <see cref="System.Int32"/>.</summary>
Int32 = 5,
/// <summary>4-byte unsigned integer, type of <see cref="System.UInt32"/>.</summary>
UInt32 = 6,
/// <summary>8-byte integer, type of <see cref="System.Int64"/>.</summary>
Int64 = 7,
/// <summary>8-byte unsigned integer, type of <see cref="System.UInt64"/>.</summary>
UInt64 = 8,
/// <summary>4-byte floating-point number, type of <see cref="System.Single"/>.</summary>
Single = 9,
/// <summary>8-byte floating-point number, type of <see cref="System.Double"/>.</summary>
Double = 10,
/// <summary>
/// string, type of <see cref="System.ReadOnlyMemory{T}"/>, where T is <see cref="char"/>.
/// Also compatible with <see cref="System.String"/>.
/// </summary>
String = 11,
/// <summary>boolean variable type, type of <see cref="System.Boolean"/>.</summary>
Boolean = 12,
/// <summary>type of <see cref="System.TimeSpan"/>.</summary>
TimeSpan = 13,
/// <summary>type of <see cref="System.DateTime"/>.</summary>
DateTime = 14,
/// <summary>type of <see cref="System.DateTimeOffset"/>.</summary>
DateTimeOffset = 15,
}
String = 11,
/// <summary>boolean variable type, type of <see cref="System.Boolean"/>.</summary>
Boolean = 12,
/// <summary>type of <see cref="System.TimeSpan"/>.</summary>
TimeSpan = 13,
/// <summary>type of <see cref="System.DateTime"/>.</summary>
DateTime = 14,
/// <summary>type of <see cref="System.DateTimeOffset"/>.</summary>
DateTimeOffset = 15,
}
/// <summary>
/// Data type specifier used in command line. <see cref="InternalDataKind"/> is the underlying version of <see cref="DataKind"/>
/// used for command line and entry point BC.
/// </summary>
[BestFriend]
internal enum InternalDataKind : byte
{
// Notes:
// * These values are serialized, so changing them breaks binary formats.
// * We intentionally skip zero.
// * Some code depends on sizeof(DataKind) == sizeof(byte).
/// <summary>
/// Data type specifier used in command line. <see cref="InternalDataKind"/> is the underlying version of <see cref="DataKind"/>
/// used for command line and entry point BC.
/// </summary>
[BestFriend]
internal enum InternalDataKind : byte
{
// Notes:
// * These values are serialized, so changing them breaks binary formats.
// * We intentionally skip zero.
// * Some code depends on sizeof(DataKind) == sizeof(byte).
I1 = DataKind.SByte,
U1 = DataKind.Byte,
I2 = DataKind.Int16,
U2 = DataKind.UInt16,
I4 = DataKind.Int32,
U4 = DataKind.UInt32,
I8 = DataKind.Int64,
U8 = DataKind.UInt64,
R4 = DataKind.Single,
R8 = DataKind.Double,
Num = R4,
I1 = DataKind.SByte,
U1 = DataKind.Byte,
I2 = DataKind.Int16,
U2 = DataKind.UInt16,
I4 = DataKind.Int32,
U4 = DataKind.UInt32,
I8 = DataKind.Int64,
U8 = DataKind.UInt64,
R4 = DataKind.Single,
R8 = DataKind.Double,
Num = R4,
TX = DataKind.String,
TX = DataKind.String,
#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independent of C# naming conventions.
TXT = TX,
Text = TX,
TXT = TX,
Text = TX,
BL = DataKind.Boolean,
Bool = BL,
BL = DataKind.Boolean,
Bool = BL,
TS = DataKind.TimeSpan,
TimeSpan = TS,
DT = DataKind.DateTime,
DateTime = DT,
DZ = DataKind.DateTimeOffset,
DateTimeZone = DZ,
TS = DataKind.TimeSpan,
TimeSpan = TS,
DT = DataKind.DateTime,
DateTime = DT,
DZ = DataKind.DateTimeOffset,
DateTimeZone = DZ,
UG = 16, // Unsigned 16-byte integer.
U16 = UG,
UG = 16, // Unsigned 16-byte integer.
U16 = UG,
#pragma warning restore MSML_GeneralName
}
/// <summary>
/// Extension methods related to the DataKind enum.
/// </summary>
[BestFriend]
internal static class InternalDataKindExtensions
{
public const InternalDataKind KindMin = InternalDataKind.I1;
public const InternalDataKind KindLim = InternalDataKind.U16 + 1;
public const int KindCount = KindLim - KindMin;
/// <summary>
/// Maps a DataKind to a value suitable for indexing into an array of size KindCount.
/// </summary>
public static int ToIndex(this InternalDataKind kind)
{
return kind - KindMin;
}
/// <summary>
/// Extension methods related to the DataKind enum.
/// Maps from an index into an array of size KindCount to the corresponding DataKind
/// </summary>
[BestFriend]
internal static class InternalDataKindExtensions
public static InternalDataKind FromIndex(int index)
{
public const InternalDataKind KindMin = InternalDataKind.I1;
public const InternalDataKind KindLim = InternalDataKind.U16 + 1;
public const int KindCount = KindLim - KindMin;
Contracts.Check(0 <= index && index < KindCount);
return (InternalDataKind)(index + (int)KindMin);
}
/// <summary>
/// Maps a DataKind to a value suitable for indexing into an array of size KindCount.
/// </summary>
public static int ToIndex(this InternalDataKind kind)
/// <summary>
/// This function converts <paramref name="dataKind"/> to <see cref="InternalDataKind"/>.
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, the conversion is straightforward.
/// </summary>
public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind;
/// <summary>
/// This function converts <paramref name="kind"/> to <see cref="DataKind"/>.
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, we should check if <paramref name="kind"/>
/// can be found in <see cref="DataKind"/>.
/// </summary>
public static DataKind ToDataKind(this InternalDataKind kind)
{
Contracts.Check(kind != InternalDataKind.UG);
return (DataKind)kind;
}
/// <summary>
/// For integer DataKinds, this returns the maximum legal value. For un-supported kinds,
/// it returns zero.
/// </summary>
public static ulong ToMaxInt(this InternalDataKind kind)
{
switch (kind)
{
return kind - KindMin;
}
/// <summary>
/// Maps from an index into an array of size KindCount to the corresponding DataKind
/// </summary>
public static InternalDataKind FromIndex(int index)
{
Contracts.Check(0 <= index && index < KindCount);
return (InternalDataKind)(index + (int)KindMin);
}
/// <summary>
/// This function converts <paramref name="dataKind"/> to <see cref="InternalDataKind"/>.
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, the conversion is straightforward.
/// </summary>
public static InternalDataKind ToInternalDataKind(this DataKind dataKind) => (InternalDataKind)dataKind;
/// <summary>
/// This function converts <paramref name="kind"/> to <see cref="DataKind"/>.
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, we should check if <paramref name="kind"/>
/// can be found in <see cref="DataKind"/>.
/// </summary>
public static DataKind ToDataKind(this InternalDataKind kind)
{
Contracts.Check(kind != InternalDataKind.UG);
return (DataKind)kind;
}
/// <summary>
/// For integer DataKinds, this returns the maximum legal value. For un-supported kinds,
/// it returns zero.
/// </summary>
public static ulong ToMaxInt(this InternalDataKind kind)
{
switch (kind)
{
case InternalDataKind.I1:
return (ulong)sbyte.MaxValue;
case InternalDataKind.U1:
return byte.MaxValue;
case InternalDataKind.I2:
return (ulong)short.MaxValue;
case InternalDataKind.U2:
return ushort.MaxValue;
case InternalDataKind.I4:
return int.MaxValue;
case InternalDataKind.U4:
return uint.MaxValue;
case InternalDataKind.I8:
return long.MaxValue;
case InternalDataKind.U8:
return ulong.MaxValue;
}
return 0;
}
/// <summary>
/// For integer Types, this returns the maximum legal value. For un-supported Types,
/// it returns zero.
/// </summary>
public static ulong ToMaxInt(this Type type)
{
if (type == typeof(sbyte))
case InternalDataKind.I1:
return (ulong)sbyte.MaxValue;
else if (type == typeof(byte))
case InternalDataKind.U1:
return byte.MaxValue;
else if (type == typeof(short))
case InternalDataKind.I2:
return (ulong)short.MaxValue;
else if (type == typeof(ushort))
case InternalDataKind.U2:
return ushort.MaxValue;
else if (type == typeof(int))
case InternalDataKind.I4:
return int.MaxValue;
else if (type == typeof(uint))
case InternalDataKind.U4:
return uint.MaxValue;
else if (type == typeof(long))
case InternalDataKind.I8:
return long.MaxValue;
else if (type == typeof(ulong))
case InternalDataKind.U8:
return ulong.MaxValue;
return 0;
}
/// <summary>
/// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
/// it returns one.
/// </summary>
public static long ToMinInt(this InternalDataKind kind)
return 0;
}
/// <summary>
/// For integer Types, this returns the maximum legal value. For un-supported Types,
/// it returns zero.
/// </summary>
public static ulong ToMaxInt(this Type type)
{
if (type == typeof(sbyte))
return (ulong)sbyte.MaxValue;
else if (type == typeof(byte))
return byte.MaxValue;
else if (type == typeof(short))
return (ulong)short.MaxValue;
else if (type == typeof(ushort))
return ushort.MaxValue;
else if (type == typeof(int))
return int.MaxValue;
else if (type == typeof(uint))
return uint.MaxValue;
else if (type == typeof(long))
return long.MaxValue;
else if (type == typeof(ulong))
return ulong.MaxValue;
return 0;
}
/// <summary>
/// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
/// it returns one.
/// </summary>
public static long ToMinInt(this InternalDataKind kind)
{
switch (kind)
{
switch (kind)
{
case InternalDataKind.I1:
return sbyte.MinValue;
case InternalDataKind.U1:
return byte.MinValue;
case InternalDataKind.I2:
return short.MinValue;
case InternalDataKind.U2:
return ushort.MinValue;
case InternalDataKind.I4:
return int.MinValue;
case InternalDataKind.U4:
return uint.MinValue;
case InternalDataKind.I8:
return long.MinValue;
case InternalDataKind.U8:
return 0;
}
return 1;
case InternalDataKind.I1:
return sbyte.MinValue;
case InternalDataKind.U1:
return byte.MinValue;
case InternalDataKind.I2:
return short.MinValue;
case InternalDataKind.U2:
return ushort.MinValue;
case InternalDataKind.I4:
return int.MinValue;
case InternalDataKind.U4:
return uint.MinValue;
case InternalDataKind.I8:
return long.MinValue;
case InternalDataKind.U8:
return 0;
}
/// <summary>
/// Maps a DataKind to the associated .Net representation type.
/// </summary>
public static Type ToType(this InternalDataKind kind)
return 1;
}
/// <summary>
/// Maps a DataKind to the associated .Net representation type.
/// </summary>
public static Type ToType(this InternalDataKind kind)
{
switch (kind)
{
switch (kind)
{
case InternalDataKind.I1:
return typeof(sbyte);
case InternalDataKind.U1:
return typeof(byte);
case InternalDataKind.I2:
return typeof(short);
case InternalDataKind.U2:
return typeof(ushort);
case InternalDataKind.I4:
return typeof(int);
case InternalDataKind.U4:
return typeof(uint);
case InternalDataKind.I8:
return typeof(long);
case InternalDataKind.U8:
return typeof(ulong);
case InternalDataKind.R4:
return typeof(Single);
case InternalDataKind.R8:
return typeof(Double);
case InternalDataKind.TX:
return typeof(ReadOnlyMemory<char>);
case InternalDataKind.BL:
return typeof(bool);
case InternalDataKind.TS:
return typeof(TimeSpan);
case InternalDataKind.DT:
return typeof(DateTime);
case InternalDataKind.DZ:
return typeof(DateTimeOffset);
case InternalDataKind.UG:
return typeof(DataViewRowId);
}
return null;
case InternalDataKind.I1:
return typeof(sbyte);
case InternalDataKind.U1:
return typeof(byte);
case InternalDataKind.I2:
return typeof(short);
case InternalDataKind.U2:
return typeof(ushort);
case InternalDataKind.I4:
return typeof(int);
case InternalDataKind.U4:
return typeof(uint);
case InternalDataKind.I8:
return typeof(long);
case InternalDataKind.U8:
return typeof(ulong);
case InternalDataKind.R4:
return typeof(Single);
case InternalDataKind.R8:
return typeof(Double);
case InternalDataKind.TX:
return typeof(ReadOnlyMemory<char>);
case InternalDataKind.BL:
return typeof(bool);
case InternalDataKind.TS:
return typeof(TimeSpan);
case InternalDataKind.DT:
return typeof(DateTime);
case InternalDataKind.DZ:
return typeof(DateTimeOffset);
case InternalDataKind.UG:
return typeof(DataViewRowId);
}
/// <summary>
/// Try to map a System.Type to a corresponding DataKind value.
/// </summary>
public static bool TryGetDataKind(this Type type, out InternalDataKind kind)
return null;
}
/// <summary>
/// Try to map a System.Type to a corresponding DataKind value.
/// </summary>
public static bool TryGetDataKind(this Type type, out InternalDataKind kind)
{
Contracts.CheckValueOrNull(type);
// REVIEW: Make this more efficient. Should we have a global dictionary?
if (type == typeof(sbyte))
kind = InternalDataKind.I1;
else if (type == typeof(byte))
kind = InternalDataKind.U1;
else if (type == typeof(short))
kind = InternalDataKind.I2;
else if (type == typeof(ushort))
kind = InternalDataKind.U2;
else if (type == typeof(int))
kind = InternalDataKind.I4;
else if (type == typeof(uint))
kind = InternalDataKind.U4;
else if (type == typeof(long))
kind = InternalDataKind.I8;
else if (type == typeof(ulong))
kind = InternalDataKind.U8;
else if (type == typeof(Single))
kind = InternalDataKind.R4;
else if (type == typeof(Double))
kind = InternalDataKind.R8;
else if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
kind = InternalDataKind.TX;
else if (type == typeof(bool))
kind = InternalDataKind.BL;
else if (type == typeof(TimeSpan))
kind = InternalDataKind.TS;
else if (type == typeof(DateTime))
kind = InternalDataKind.DT;
else if (type == typeof(DateTimeOffset))
kind = InternalDataKind.DZ;
else if (type == typeof(DataViewRowId))
kind = InternalDataKind.UG;
else
{
Contracts.CheckValueOrNull(type);
// REVIEW: Make this more efficient. Should we have a global dictionary?
if (type == typeof(sbyte))
kind = InternalDataKind.I1;
else if (type == typeof(byte))
kind = InternalDataKind.U1;
else if (type == typeof(short))
kind = InternalDataKind.I2;
else if (type == typeof(ushort))
kind = InternalDataKind.U2;
else if (type == typeof(int))
kind = InternalDataKind.I4;
else if (type == typeof(uint))
kind = InternalDataKind.U4;
else if (type == typeof(long))
kind = InternalDataKind.I8;
else if (type == typeof(ulong))
kind = InternalDataKind.U8;
else if (type == typeof(Single))
kind = InternalDataKind.R4;
else if (type == typeof(Double))
kind = InternalDataKind.R8;
else if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
kind = InternalDataKind.TX;
else if (type == typeof(bool))
kind = InternalDataKind.BL;
else if (type == typeof(TimeSpan))
kind = InternalDataKind.TS;
else if (type == typeof(DateTime))
kind = InternalDataKind.DT;
else if (type == typeof(DateTimeOffset))
kind = InternalDataKind.DZ;
else if (type == typeof(DataViewRowId))
kind = InternalDataKind.UG;
else
{
kind = default(InternalDataKind);
return false;
}
return true;
kind = default(InternalDataKind);
return false;
}
/// <summary>
/// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable
/// and is also slow, so use this instead.
/// </summary>
public static string GetString(this InternalDataKind kind)
return true;
}
/// <summary>
/// Get the canonical string for a DataKind. Note that using DataKind.ToString() is not stable
/// and is also slow, so use this instead.
/// </summary>
public static string GetString(this InternalDataKind kind)
{
switch (kind)
{
switch (kind)
{
case InternalDataKind.I1:
return "I1";
case InternalDataKind.I2:
return "I2";
case InternalDataKind.I4:
return "I4";
case InternalDataKind.I8:
return "I8";
case InternalDataKind.U1:
return "U1";
case InternalDataKind.U2:
return "U2";
case InternalDataKind.U4:
return "U4";
case InternalDataKind.U8:
return "U8";
case InternalDataKind.R4:
return "R4";
case InternalDataKind.R8:
return "R8";
case InternalDataKind.BL:
return "BL";
case InternalDataKind.TX:
return "TX";
case InternalDataKind.TS:
return "TS";
case InternalDataKind.DT:
return "DT";
case InternalDataKind.DZ:
return "DZ";
case InternalDataKind.UG:
return "UG";
}
return "";
case InternalDataKind.I1:
return "I1";
case InternalDataKind.I2:
return "I2";
case InternalDataKind.I4:
return "I4";
case InternalDataKind.I8:
return "I8";
case InternalDataKind.U1:
return "U1";
case InternalDataKind.U2:
return "U2";
case InternalDataKind.U4:
return "U4";
case InternalDataKind.U8:
return "U8";
case InternalDataKind.R4:
return "R4";
case InternalDataKind.R8:
return "R8";
case InternalDataKind.BL:
return "BL";
case InternalDataKind.TX:
return "TX";
case InternalDataKind.TS:
return "TS";
case InternalDataKind.DT:
return "DT";
case InternalDataKind.DZ:
return "DZ";
case InternalDataKind.UG:
return "UG";
}
return "";
}
}

Просмотреть файл

@ -2,17 +2,16 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Command
{
/// <summary>
/// The signature for commands.
/// </summary>
[BestFriend]
internal delegate void SignatureCommand();
namespace Microsoft.ML.Command;
[BestFriend]
internal interface ICommand
{
void Run();
}
/// <summary>
/// The signature for commands.
/// </summary>
[BestFriend]
internal delegate void SignatureCommand();
[BestFriend]
internal interface ICommand
{
void Run();
}

Просмотреть файл

@ -8,315 +8,314 @@ using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
namespace Microsoft.ML
namespace Microsoft.ML;
/// <summary>
/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
/// This is more relaxed than the proper <see cref="DataViewSchema"/>, since it's only a subset of the columns,
/// and also since it doesn't specify exact <see cref="DataViewType"/>'s for vectors and keys.
/// </summary>
public sealed class SchemaShape : IReadOnlyList<SchemaShape.Column>
{
/// <summary>
/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
/// This is more relaxed than the proper <see cref="DataViewSchema"/>, since it's only a subset of the columns,
/// and also since it doesn't specify exact <see cref="DataViewType"/>'s for vectors and keys.
/// </summary>
public sealed class SchemaShape : IReadOnlyList<SchemaShape.Column>
private readonly Column[] _columns;
private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty<Column>());
public int Count => _columns.Count();
public Column this[int index] => _columns[index];
public struct Column
{
private readonly Column[] _columns;
private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty<Column>());
public int Count => _columns.Count();
public Column this[int index] => _columns[index];
public struct Column
public enum VectorKind
{
public enum VectorKind
{
Scalar,
Vector,
VariableVector
}
/// <summary>
/// The column name.
/// </summary>
public readonly string Name;
/// <summary>
/// The type of the column: scalar, fixed vector or variable vector.
/// </summary>
public readonly VectorKind Kind;
/// <summary>
/// The 'raw' type of column item: must be a primitive type or a structured type.
/// </summary>
public readonly DataViewType ItemType;
/// <summary>
/// The flag whether the column is actually a key. If yes, <see cref="ItemType"/> is representing
/// the underlying primitive type.
/// </summary>
public readonly bool IsKey;
/// <summary>
/// The annotations that are present for this column.
/// </summary>
public readonly SchemaShape Annotations;
[BestFriend]
internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(annotations);
Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
Name = name;
Kind = vecKind;
ItemType = itemType;
IsKey = isKey;
Annotations = annotations ?? _empty;
}
/// <summary>
/// Returns whether <paramref name="source"/> is a valid input, if this object represents a
/// requirement.
///
/// Namely, it returns true iff:
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
/// - The columns of <see cref="Annotations"/> of <paramref name="source"/> is a superset of our <see cref="Annotations"/> columns.
/// - Each such annotation column is itself compatible with the input annotation column.
/// </summary>
[BestFriend]
internal bool IsCompatibleWith(Column source)
{
Contracts.Check(source.IsValid, nameof(source));
if (Name != source.Name)
return false;
if (Kind != source.Kind)
return false;
if (!ItemType.Equals(source.ItemType))
return false;
if (IsKey != source.IsKey)
return false;
foreach (var annotationCol in Annotations)
{
if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol))
return false;
if (!annotationCol.IsCompatibleWith(inputAnnotationCol))
return false;
}
return true;
}
[BestFriend]
internal string GetTypeString()
{
string result = ItemType.ToString();
if (IsKey)
result = $"Key<{result}>";
if (Kind == VectorKind.Vector)
result = $"Vector<{result}>";
else if (Kind == VectorKind.VariableVector)
result = $"VarVector<{result}>";
return result;
}
/// <summary>
/// Return if this structure is not identical to the default value of <see cref="Column"/>. If true,
/// it means this structure is initialized properly and therefore considered as valid.
/// </summary>
[BestFriend]
internal bool IsValid => Name != null;
}
public SchemaShape(IEnumerable<Column> columns)
{
Contracts.CheckValue(columns, nameof(columns));
_columns = columns.ToArray();
Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
Scalar,
Vector,
VariableVector
}
/// <summary>
/// Given a <paramref name="type"/>, extract the type parameters that describe this type
/// as a <see cref="SchemaShape"/>'s column type.
/// The column name.
/// </summary>
/// <param name="type">The actual column type to process.</param>
/// <param name="vecKind">The vector kind of <paramref name="type"/>.</param>
/// <param name="itemType">The item type of <paramref name="type"/>.</param>
/// <param name="isKey">Whether <paramref name="type"/> (or its item type) is a key.</param>
[BestFriend]
internal static void GetColumnTypeShape(DataViewType type,
out Column.VectorKind vecKind,
out DataViewType itemType,
out bool isKey)
{
if (type is VectorDataViewType vectorType)
{
if (vectorType.IsKnownSize)
{
vecKind = Column.VectorKind.Vector;
}
else
{
vecKind = Column.VectorKind.VariableVector;
}
public readonly string Name;
itemType = vectorType.ItemType;
/// <summary>
/// The type of the column: scalar, fixed vector or variable vector.
/// </summary>
public readonly VectorKind Kind;
/// <summary>
/// The 'raw' type of column item: must be a primitive type or a structured type.
/// </summary>
public readonly DataViewType ItemType;
/// <summary>
/// The flag whether the column is actually a key. If yes, <see cref="ItemType"/> is representing
/// the underlying primitive type.
/// </summary>
public readonly bool IsKey;
/// <summary>
/// The annotations that are present for this column.
/// </summary>
public readonly SchemaShape Annotations;
[BestFriend]
internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(annotations);
Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
Name = name;
Kind = vecKind;
ItemType = itemType;
IsKey = isKey;
Annotations = annotations ?? _empty;
}
/// <summary>
/// Returns whether <paramref name="source"/> is a valid input, if this object represents a
/// requirement.
///
/// Namely, it returns true iff:
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
/// - The columns of <see cref="Annotations"/> of <paramref name="source"/> is a superset of our <see cref="Annotations"/> columns.
/// - Each such annotation column is itself compatible with the input annotation column.
/// </summary>
[BestFriend]
internal bool IsCompatibleWith(Column source)
{
Contracts.Check(source.IsValid, nameof(source));
if (Name != source.Name)
return false;
if (Kind != source.Kind)
return false;
if (!ItemType.Equals(source.ItemType))
return false;
if (IsKey != source.IsKey)
return false;
foreach (var annotationCol in Annotations)
{
if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol))
return false;
if (!annotationCol.IsCompatibleWith(inputAnnotationCol))
return false;
}
return true;
}
[BestFriend]
internal string GetTypeString()
{
string result = ItemType.ToString();
if (IsKey)
result = $"Key<{result}>";
if (Kind == VectorKind.Vector)
result = $"Vector<{result}>";
else if (Kind == VectorKind.VariableVector)
result = $"VarVector<{result}>";
return result;
}
/// <summary>
/// Return if this structure is not identical to the default value of <see cref="Column"/>. If true,
/// it means this structure is initialized properly and therefore considered as valid.
/// </summary>
[BestFriend]
internal bool IsValid => Name != null;
}
public SchemaShape(IEnumerable<Column> columns)
{
Contracts.CheckValue(columns, nameof(columns));
_columns = columns.ToArray();
Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
}
/// <summary>
/// Given a <paramref name="type"/>, extract the type parameters that describe this type
/// as a <see cref="SchemaShape"/>'s column type.
/// </summary>
/// <param name="type">The actual column type to process.</param>
/// <param name="vecKind">The vector kind of <paramref name="type"/>.</param>
/// <param name="itemType">The item type of <paramref name="type"/>.</param>
/// <param name="isKey">Whether <paramref name="type"/> (or its item type) is a key.</param>
[BestFriend]
internal static void GetColumnTypeShape(DataViewType type,
out Column.VectorKind vecKind,
out DataViewType itemType,
out bool isKey)
{
if (type is VectorDataViewType vectorType)
{
if (vectorType.IsKnownSize)
{
vecKind = Column.VectorKind.Vector;
}
else
{
vecKind = Column.VectorKind.Scalar;
itemType = type;
vecKind = Column.VectorKind.VariableVector;
}
isKey = itemType is KeyDataViewType;
if (isKey)
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
itemType = vectorType.ItemType;
}
/// <summary>
/// Create a schema shape out of the fully defined schema.
/// </summary>
[BestFriend]
internal static SchemaShape Create(DataViewSchema schema)
else
{
Contracts.CheckValue(schema, nameof(schema));
var cols = new List<Column>();
for (int iCol = 0; iCol < schema.Count; iCol++)
{
if (!schema[iCol].IsHidden)
{
// First create the annotations.
var mCols = new List<Column>();
foreach (var annotationColumn in schema[iCol].Annotations.Schema)
{
GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey));
}
var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
// Next create the single column.
GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations));
}
}
return new SchemaShape(cols);
vecKind = Column.VectorKind.Scalar;
itemType = type;
}
/// <summary>
/// Returns if there is a column with a specified <paramref name="name"/> and if so stores it in <paramref name="column"/>.
/// </summary>
[BestFriend]
internal bool TryFindColumn(string name, out Column column)
{
Contracts.CheckValue(name, nameof(name));
column = _columns.FirstOrDefault(x => x.Name == name);
return column.IsValid;
}
public IEnumerator<Column> GetEnumerator() => ((IEnumerable<Column>)_columns).GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
isKey = itemType is KeyDataViewType;
if (isKey)
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
}
/// <summary>
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// Create a schema shape out of the fully defined schema.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
public interface IDataLoader<in TSource> : ICanSaveModel
{
/// <summary>
/// Produce the data view from the specified input.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
/// </summary>
IDataView Load(TSource input);
/// <summary>
/// The output schema of the loader.
/// </summary>
DataViewSchema GetOutputSchema();
}
/// <summary>
/// Sometimes we need to 'fit' an <see cref="IDataLoader{TIn}"/>.
/// A DataLoader estimator is the object that does it.
/// </summary>
public interface IDataLoaderEstimator<in TSource, out TLoader>
where TLoader : IDataLoader<TSource>
{
// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
// yet, so why complicate matters?
/// <summary>
/// Train and return a data loader.
/// </summary>
TLoader Fit(TSource input);
/// <summary>
/// The 'promise' of the output schema.
/// It will be used for schema propagation.
/// </summary>
SchemaShape GetOutputSchema();
}
/// <summary>
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
/// </summary>
public interface ITransformer : ICanSaveModel
{
/// <summary>
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
/// </summary>
DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
/// <summary>
/// Take the data in, make transformations, output the data.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
/// </summary>
IDataView Transform(IDataView input);
/// <summary>
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
/// appropriate schema.
/// </summary>
bool IsRowToRowMapper { get; }
/// <summary>
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
/// is <c>false</c>, then an exception should be thrown. If the input schema is in any way
/// unsuitable for constructing the mapper, an exception should likewise be thrown.
/// </summary>
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
/// <returns>The row to row mapper.</returns>
IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema);
}
[BestFriend]
internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer
internal static SchemaShape Create(DataViewSchema schema)
{
IDataView TransformForTrainingPipeline(IDataView input);
Contracts.CheckValue(schema, nameof(schema));
var cols = new List<Column>();
for (int iCol = 0; iCol < schema.Count; iCol++)
{
if (!schema[iCol].IsHidden)
{
// First create the annotations.
var mCols = new List<Column>();
foreach (var annotationColumn in schema[iCol].Annotations.Schema)
{
GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey));
}
var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
// Next create the single column.
GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations));
}
}
return new SchemaShape(cols);
}
/// <summary>
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
/// a transformer.
/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="DataViewSchema"/>.
/// Returns if there is a column with a specified <paramref name="name"/> and if so stores it in <paramref name="column"/>.
/// </summary>
public interface IEstimator<out TTransformer>
where TTransformer : ITransformer
[BestFriend]
internal bool TryFindColumn(string name, out Column column)
{
/// <summary>
/// Train and return a transformer.
/// </summary>
TTransformer Fit(IDataView input);
/// <summary>
/// Schema propagation for estimators.
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
/// </summary>
SchemaShape GetOutputSchema(SchemaShape inputSchema);
Contracts.CheckValue(name, nameof(name));
column = _columns.FirstOrDefault(x => x.Name == name);
return column.IsValid;
}
public IEnumerator<Column> GetEnumerator() => ((IEnumerable<Column>)_columns).GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}
/// <summary>
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
public interface IDataLoader<in TSource> : ICanSaveModel
{
/// <summary>
/// Produce the data view from the specified input.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
/// </summary>
IDataView Load(TSource input);
/// <summary>
/// The output schema of the loader.
/// </summary>
DataViewSchema GetOutputSchema();
}
/// <summary>
/// Sometimes we need to 'fit' an <see cref="IDataLoader{TIn}"/>.
/// A DataLoader estimator is the object that does it.
/// </summary>
public interface IDataLoaderEstimator<in TSource, out TLoader>
where TLoader : IDataLoader<TSource>
{
// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
// yet, so why complicate matters?
/// <summary>
/// Train and return a data loader.
/// </summary>
TLoader Fit(TSource input);
/// <summary>
/// The 'promise' of the output schema.
/// It will be used for schema propagation.
/// </summary>
SchemaShape GetOutputSchema();
}
/// <summary>
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
/// </summary>
public interface ITransformer : ICanSaveModel
{
/// <summary>
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
/// </summary>
DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
/// <summary>
/// Take the data in, make transformations, output the data.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
/// </summary>
IDataView Transform(IDataView input);
/// <summary>
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
/// appropriate schema.
/// </summary>
bool IsRowToRowMapper { get; }
/// <summary>
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
/// is <c>false</c>, then an exception should be thrown. If the input schema is in any way
/// unsuitable for constructing the mapper, an exception should likewise be thrown.
/// </summary>
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
/// <returns>The row to row mapper.</returns>
IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema);
}
[BestFriend]
internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer
{
IDataView TransformForTrainingPipeline(IDataView input);
}
/// <summary>
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
/// a transformer.
/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="DataViewSchema"/>.
/// </summary>
public interface IEstimator<out TTransformer>
where TTransformer : ITransformer
{
/// <summary>
/// Train and return a transformer.
/// </summary>
TTransformer Fit(IDataView input);
/// <summary>
/// Schema propagation for estimators.
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
/// </summary>
SchemaShape GetOutputSchema(SchemaShape inputSchema);
}

Просмотреть файл

@ -8,191 +8,190 @@ using System.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// A file handle.
/// </summary>
public interface IFileHandle : IDisposable
{
/// <summary>
/// A file handle.
/// Returns whether CreateWriteStream is expected to succeed. Typically, once
/// CreateWriteStream has been called once, this will forever more return false.
/// </summary>
public interface IFileHandle : IDisposable
{
/// <summary>
/// Returns whether CreateWriteStream is expected to succeed. Typically, once
/// CreateWriteStream has been called once, this will forever more return false.
/// </summary>
bool CanWrite { get; }
/// <summary>
/// Returns whether OpenReadStream is expected to succeed.
/// </summary>
bool CanRead { get; }
/// <summary>
/// Create a writable stream for this file handle.
/// </summary>
Stream CreateWriteStream();
/// <summary>
/// Open a readable stream for this file handle.
/// </summary>
Stream OpenReadStream();
}
bool CanWrite { get; }
/// <summary>
/// A simple disk-based file handle.
/// Returns whether OpenReadStream is expected to succeed.
/// </summary>
public sealed class SimpleFileHandle : IFileHandle
bool CanRead { get; }
/// <summary>
/// Create a writable stream for this file handle.
/// </summary>
Stream CreateWriteStream();
/// <summary>
/// Open a readable stream for this file handle.
/// </summary>
Stream OpenReadStream();
}
/// <summary>
/// A simple disk-based file handle.
/// </summary>
public sealed class SimpleFileHandle : IFileHandle
{
private readonly string _fullPath;
// Exception context.
private readonly IExceptionContext _ectx;
private readonly object _lock;
// Whether to delete the file when this is disposed.
private readonly bool _autoDelete;
// Whether this file has contents. This is false if the file needs CreateWriteStream to be
// called (before OpenReadStream can be called).
private bool _wrote;
// If non-null, the active write stream. This should be disposed before the first OpenReadStream call.
private Stream _streamWrite;
// This contains the potentially active read streams. This is set to null once this file
// handle has been disposed.
private List<Stream> _streams;
private bool IsDisposed => _streams == null;
public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
{
private readonly string _fullPath;
Contracts.CheckValue(ectx, nameof(ectx));
ectx.CheckNonEmpty(path, nameof(path));
// Exception context.
private readonly IExceptionContext _ectx;
_ectx = ectx;
_fullPath = Path.GetFullPath(path);
private readonly object _lock;
_autoDelete = autoDelete;
// Whether to delete the file when this is disposed.
private readonly bool _autoDelete;
// The file has already been written to iff needsWrite is false.
_wrote = !needsWrite;
// Whether this file has contents. This is false if the file needs CreateWriteStream to be
// called (before OpenReadStream can be called).
private bool _wrote;
// If non-null, the active write stream. This should be disposed before the first OpenReadStream call.
private Stream _streamWrite;
// REVIEW: Should this do some basic validation? Eg, for output files, ensure that
// the directory exists (and perhaps even create an empty file); for input files, ensure
// that the file exists (and perhaps even attempt to open it).
// This contains the potentially active read streams. This is set to null once this file
// handle has been disposed.
private List<Stream> _streams;
_lock = new object();
_streams = new List<Stream>();
}
private bool IsDisposed => _streams == null;
public bool CanWrite => !_wrote && !IsDisposed;
public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
{
Contracts.CheckValue(ectx, nameof(ectx));
ectx.CheckNonEmpty(path, nameof(path));
public bool CanRead => _wrote && !IsDisposed;
_ectx = ectx;
_fullPath = Path.GetFullPath(path);
_autoDelete = autoDelete;
// The file has already been written to iff needsWrite is false.
_wrote = !needsWrite;
// REVIEW: Should this do some basic validation? Eg, for output files, ensure that
// the directory exists (and perhaps even create an empty file); for input files, ensure
// that the file exists (and perhaps even attempt to open it).
_lock = new object();
_streams = new List<Stream>();
}
public bool CanWrite => !_wrote && !IsDisposed;
public bool CanRead => _wrote && !IsDisposed;
public void Dispose()
{
lock (_lock)
{
if (IsDisposed)
return;
Contracts.Assert(_streams != null);
// REVIEW: Is it safe to dispose these streams? What if they are
// being used on other threads? Does that matter?
if (_streamWrite != null)
{
try
{
_streamWrite.CloseEx();
_streamWrite.Dispose();
}
catch
{
// REVIEW: What should we do here?
Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!");
}
_streamWrite = null;
}
foreach (var stream in _streams)
{
try
{
stream.CloseEx();
stream.Dispose();
}
catch
{
// REVIEW: What should we do here?
Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!");
}
}
_streams = null;
Contracts.Assert(IsDisposed);
if (_autoDelete)
{
try
{
// Finally, delete the file.
File.Delete(_fullPath);
}
catch
{
// REVIEW: What should we do here?
Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!");
}
}
}
}
private void CheckNotDisposed()
public void Dispose()
{
lock (_lock)
{
if (IsDisposed)
throw _ectx.Except("SimpleFileHandle has already been disposed");
}
return;
public Stream CreateWriteStream()
{
lock (_lock)
Contracts.Assert(_streams != null);
// REVIEW: Is it safe to dispose these streams? What if they are
// being used on other threads? Does that matter?
if (_streamWrite != null)
{
CheckNotDisposed();
if (_wrote)
throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle");
Contracts.Assert(_streamWrite == null);
_streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write);
_wrote = true;
return _streamWrite;
}
}
public Stream OpenReadStream()
{
lock (_lock)
{
CheckNotDisposed();
if (!_wrote)
throw _ectx.Except("SimpleFileHandle hasn't been written yet");
if (_streamWrite != null)
try
{
if (_streamWrite.CanWrite)
throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed");
_streamWrite = null;
_streamWrite.CloseEx();
_streamWrite.Dispose();
}
// Drop read streams that have already been disposed.
_streams.RemoveAll(s => !s.CanRead);
var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read);
_streams.Add(stream);
return stream;
catch
{
// REVIEW: What should we do here?
Contracts.Assert(false, "Closing a SimpleFileHandle write stream failed!");
}
_streamWrite = null;
}
foreach (var stream in _streams)
{
try
{
stream.CloseEx();
stream.Dispose();
}
catch
{
// REVIEW: What should we do here?
Contracts.Assert(false, "Closing a SimpleFileHandle read stream failed!");
}
}
_streams = null;
Contracts.Assert(IsDisposed);
if (_autoDelete)
{
try
{
// Finally, delete the file.
File.Delete(_fullPath);
}
catch
{
// REVIEW: What should we do here?
Contracts.Assert(false, "Deleting a SimpleFileHandle physical file failed!");
}
}
}
}
private void CheckNotDisposed()
{
if (IsDisposed)
throw _ectx.Except("SimpleFileHandle has already been disposed");
}
public Stream CreateWriteStream()
{
lock (_lock)
{
CheckNotDisposed();
if (_wrote)
throw _ectx.Except("CreateWriteStream called multiple times on SimpleFileHandle");
Contracts.Assert(_streamWrite == null);
_streamWrite = new FileStream(_fullPath, FileMode.Create, FileAccess.Write);
_wrote = true;
return _streamWrite;
}
}
public Stream OpenReadStream()
{
lock (_lock)
{
CheckNotDisposed();
if (!_wrote)
throw _ectx.Except("SimpleFileHandle hasn't been written yet");
if (_streamWrite != null)
{
if (_streamWrite.CanWrite)
throw _ectx.Except("Write stream for SimpleFileHandle hasn't been disposed");
_streamWrite = null;
}
// Drop read streams that have already been disposed.
_streams.RemoveAll(s => !s.CanRead);
var stream = new FileStream(_fullPath, FileMode.Open, FileAccess.Read, FileShare.Read);
_streams.Add(stream);
return stream;
}
}
}

Просмотреть файл

@ -5,319 +5,318 @@
using System;
using Microsoft.ML.Data;
namespace Microsoft.ML.Runtime
namespace Microsoft.ML.Runtime;
/// <summary>
/// A channel provider can create new channels and generic information pipes.
/// </summary>
public interface IChannelProvider : IExceptionContext
{
/// <summary>
/// A channel provider can create new channels and generic information pipes.
/// Start a standard message channel.
/// </summary>
public interface IChannelProvider : IExceptionContext
{
/// <summary>
/// Start a standard message channel.
/// </summary>
IChannel Start(string name);
IChannel Start(string name);
/// <summary>
/// Start a generic information pipe.
/// </summary>
IPipe<TMessage> StartPipe<TMessage>(string name);
/// <summary>
/// Start a generic information pipe.
/// </summary>
IPipe<TMessage> StartPipe<TMessage>(string name);
}
/// <summary>
/// Utility class for IHostEnvironment
/// </summary>
[BestFriend]
internal static class HostEnvironmentExtensions
{
/// <summary>
/// Return a file handle for an input "file".
/// </summary>
public static IFileHandle OpenInputFile(this IHostEnvironment env, string path)
{
Contracts.AssertValue(env);
Contracts.CheckNonWhiteSpace(path, nameof(path));
return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false);
}
/// <summary>
/// Utility class for IHostEnvironment
/// Create an output "file" and return a handle to it.
/// </summary>
[BestFriend]
internal static class HostEnvironmentExtensions
public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path)
{
/// <summary>
/// Return a file handle for an input "file".
/// </summary>
public static IFileHandle OpenInputFile(this IHostEnvironment env, string path)
{
Contracts.AssertValue(env);
Contracts.CheckNonWhiteSpace(path, nameof(path));
return new SimpleFileHandle(env, path, needsWrite: false, autoDelete: false);
}
/// <summary>
/// Create an output "file" and return a handle to it.
/// </summary>
public static IFileHandle CreateOutputFile(this IHostEnvironment env, string path)
{
Contracts.AssertValue(env);
Contracts.CheckNonWhiteSpace(path, nameof(path));
return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false);
}
}
/// <summary>
/// The host environment interface creates hosts for components. Note that the methods of
/// this interface should be called from the main thread for the environment. To get an environment
/// to service another thread, call Fork and pass the return result to that thread.
/// </summary>
public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
{
/// <summary>
/// Create a host with the given registration name.
/// </summary>
IHost Register(string name, int? seed = null, bool? verbose = null);
/// <summary>
/// The catalog of loadable components (<see cref="LoadableClassAttribute"/>) that are available in this host.
/// </summary>
ComponentCatalog ComponentCatalog { get; }
}
[BestFriend]
internal interface ICancelable
{
/// <summary>
/// Signal to stop execution in all the hosts.
/// </summary>
void CancelExecution();
/// <summary>
/// Flag which indicates host execution has been stopped.
/// </summary>
bool IsCanceled { get; }
}
[BestFriend]
internal interface IHostEnvironmentInternal : IHostEnvironment
{
/// <summary>
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
/// </summary>
int? Seed { get; }
/// <summary>
/// The location for the temp files created by ML.NET
/// </summary>
string TempFilePath { get; set; }
/// <summary>
/// Allow falling back to run on CPU if couldn't run on GPU.
/// </summary>
bool FallbackToCpu { get; set; }
/// <summary>
/// GPU device ID to run execution on, <see langword="null" /> to run on CPU.
/// </summary>
int? GpuDeviceId { get; set; }
}
/// <summary>
/// A host is coupled to a component and provides random number generation and concurrency guidance.
/// Note that the random number generation, like the host environment methods, should be accessed only
/// from the main thread for the component.
/// </summary>
public interface IHost : IHostEnvironment
{
/// <summary>
/// The random number generator issued to this component. Note that random number
/// generators are NOT thread safe.
/// </summary>
Random Rand { get; }
}
/// <summary>
/// A generic information pipe. Note that pipes are disposable. Generally, Done should
/// be called before disposing to signal a normal shut-down of the pipe, as opposed
/// to an aborted completion.
/// </summary>
public interface IPipe<TMessage> : IExceptionContext, IDisposable
{
/// <summary>
/// The caller relinquishes ownership of the <paramref name="msg"/> object.
/// </summary>
void Send(TMessage msg);
}
/// <summary>
/// The kinds of standard channel messages.
/// Note: These values should never be changed. We can add new kinds, but don't change these values.
/// Other code bases, including native code for other projects depends on these values.
/// </summary>
public enum ChannelMessageKind
{
Trace = 0,
Info = 1,
Warning = 2,
Error = 3
}
/// <summary>
/// A flag that can be attached to a message or exception to indicate that
/// it has a certain class of sensitive data. By default, messages should be
/// specified as being of unknown sensitivity, which is to say, every
/// sensitivity flag is turned on, corresponding to <see cref="Unknown"/>.
/// Messages that are totally safe should be marked as <see cref="None"/>.
/// However, if, say, one prints out data from a file (for example, this might
/// be done when expressing parse errors), it should be flagged in that case
/// with <see cref="UserData"/>.
/// </summary>
[Flags]
public enum MessageSensitivity
{
/// <summary>
/// For non-sensitive data.
/// </summary>
None = 0,
/// <summary>
/// For messages that may contain user-data from data files.
/// </summary>
UserData = 0x1,
/// <summary>
/// For messages that contain information like column names from datasets.
/// Note that, despite being part of the schema, annotations should be treated
/// as user data, since it is often derived from user data. Note also that
/// types, despite being part of the schema, are not considered "sensitive"
/// as such, in the same way that column names might be.
/// </summary>
Schema = 0x2,
// REVIEW: Other potentially sensitive things might include
// stack traces in certain environments.
/// <summary>
/// The default value, unknown, is treated as if everything is sensitive.
/// </summary>
Unknown = ~None,
/// <summary>
/// An alias for <see cref="Unknown"/>, so it is functionally the same, except
/// semantically it communicates the idea that we want all bits set.
/// </summary>
All = Unknown,
}
/// <summary>
/// A channel message.
/// </summary>
public readonly struct ChannelMessage
{
public readonly ChannelMessageKind Kind;
public readonly MessageSensitivity Sensitivity;
private readonly string _message;
private readonly object[] _args;
/// <summary>
/// Line endings may not be normalized.
/// </summary>
public string Message => _args != null ? string.Format(_message, _args) : _message;
[BestFriend]
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
{
Contracts.CheckNonEmpty(message, nameof(message));
Kind = kind;
Sensitivity = sensitivity;
_message = message;
_args = null;
}
[BestFriend]
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
{
Contracts.CheckNonEmpty(fmt, nameof(fmt));
Contracts.CheckNonEmpty(args, nameof(args));
Kind = kind;
Sensitivity = sensitivity;
_message = fmt;
_args = args;
}
}
/// <summary>
/// A standard communication channel.
/// </summary>
public interface IChannel : IPipe<ChannelMessage>
{
void Trace(MessageSensitivity sensitivity, string fmt);
void Trace(MessageSensitivity sensitivity, string fmt, params object[] args);
void Error(MessageSensitivity sensitivity, string fmt);
void Error(MessageSensitivity sensitivity, string fmt, params object[] args);
void Warning(MessageSensitivity sensitivity, string fmt);
void Warning(MessageSensitivity sensitivity, string fmt, params object[] args);
void Info(MessageSensitivity sensitivity, string fmt);
void Info(MessageSensitivity sensitivity, string fmt, params object[] args);
}
/// <summary>
/// General utility extension methods for objects in the "host" universe, i.e.,
/// <see cref="IHostEnvironment"/>, <see cref="IHost"/>, and <see cref="IChannel"/>
/// that do not belong in more specific areas, for example, <see cref="Contracts"/> or
/// component creation.
/// </summary>
[BestFriend]
internal static class HostExtensions
{
public static T Apply<T>(this IHost host, string channelName, Func<IChannel, T> func)
{
T t;
using (var ch = host.Start(channelName))
{
t = func(ch);
}
return t;
}
/// <summary>
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Trace(this IChannel ch, string fmt)
=> ch.Trace(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Trace(this IChannel ch, string fmt, params object[] args)
=> ch.Trace(MessageSensitivity.Unknown, fmt, args);
/// <summary>
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Error(this IChannel ch, string fmt)
=> ch.Error(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Error(this IChannel ch, string fmt, params object[] args)
=> ch.Error(MessageSensitivity.Unknown, fmt, args);
/// <summary>
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Warning(this IChannel ch, string fmt)
=> ch.Warning(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Warning(this IChannel ch, string fmt, params object[] args)
=> ch.Warning(MessageSensitivity.Unknown, fmt, args);
/// <summary>
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Info(this IChannel ch, string fmt)
=> ch.Info(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Info(this IChannel ch, string fmt, params object[] args)
=> ch.Info(MessageSensitivity.Unknown, fmt, args);
Contracts.AssertValue(env);
Contracts.CheckNonWhiteSpace(path, nameof(path));
return new SimpleFileHandle(env, path, needsWrite: true, autoDelete: false);
}
}
/// <summary>
/// The host environment interface creates hosts for components. Note that the methods of
/// this interface should be called from the main thread for the environment. To get an environment
/// to service another thread, call Fork and pass the return result to that thread.
/// </summary>
public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
{
/// <summary>
/// Create a host with the given registration name.
/// </summary>
IHost Register(string name, int? seed = null, bool? verbose = null);
/// <summary>
/// The catalog of loadable components (<see cref="LoadableClassAttribute"/>) that are available in this host.
/// </summary>
ComponentCatalog ComponentCatalog { get; }
}
[BestFriend]
internal interface ICancelable
{
/// <summary>
/// Signal to stop execution in all the hosts.
/// </summary>
void CancelExecution();
/// <summary>
/// Flag which indicates host execution has been stopped.
/// </summary>
bool IsCanceled { get; }
}
[BestFriend]
internal interface IHostEnvironmentInternal : IHostEnvironment
{
/// <summary>
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
/// </summary>
int? Seed { get; }
/// <summary>
/// The location for the temp files created by ML.NET
/// </summary>
string TempFilePath { get; set; }
/// <summary>
/// Allow falling back to run on CPU if couldn't run on GPU.
/// </summary>
bool FallbackToCpu { get; set; }
/// <summary>
/// GPU device ID to run execution on, <see langword="null" /> to run on CPU.
/// </summary>
int? GpuDeviceId { get; set; }
}
/// <summary>
/// A host is coupled to a component and provides random number generation and concurrency guidance.
/// Note that the random number generation, like the host environment methods, should be accessed only
/// from the main thread for the component.
/// </summary>
public interface IHost : IHostEnvironment
{
/// <summary>
/// The random number generator issued to this component. Note that random number
/// generators are NOT thread safe.
/// </summary>
Random Rand { get; }
}
/// <summary>
/// A generic information pipe. Note that pipes are disposable. Generally, Done should
/// be called before disposing to signal a normal shut-down of the pipe, as opposed
/// to an aborted completion.
/// </summary>
public interface IPipe<TMessage> : IExceptionContext, IDisposable
{
/// <summary>
/// The caller relinquishes ownership of the <paramref name="msg"/> object.
/// </summary>
void Send(TMessage msg);
}
/// <summary>
/// The kinds of standard channel messages.
/// Note: These values should never be changed. We can add new kinds, but don't change these values.
/// Other code bases, including native code for other projects depends on these values.
/// </summary>
public enum ChannelMessageKind
{
Trace = 0,
Info = 1,
Warning = 2,
Error = 3
}
/// <summary>
/// A flag that can be attached to a message or exception to indicate that
/// it has a certain class of sensitive data. By default, messages should be
/// specified as being of unknown sensitivity, which is to say, every
/// sensitivity flag is turned on, corresponding to <see cref="Unknown"/>.
/// Messages that are totally safe should be marked as <see cref="None"/>.
/// However, if, say, one prints out data from a file (for example, this might
/// be done when expressing parse errors), it should be flagged in that case
/// with <see cref="UserData"/>.
/// </summary>
[Flags]
public enum MessageSensitivity
{
/// <summary>
/// For non-sensitive data.
/// </summary>
None = 0,
/// <summary>
/// For messages that may contain user-data from data files.
/// </summary>
UserData = 0x1,
/// <summary>
/// For messages that contain information like column names from datasets.
/// Note that, despite being part of the schema, annotations should be treated
/// as user data, since it is often derived from user data. Note also that
/// types, despite being part of the schema, are not considered "sensitive"
/// as such, in the same way that column names might be.
/// </summary>
Schema = 0x2,
// REVIEW: Other potentially sensitive things might include
// stack traces in certain environments.
/// <summary>
/// The default value, unknown, is treated as if everything is sensitive.
/// </summary>
Unknown = ~None,
/// <summary>
/// An alias for <see cref="Unknown"/>, so it is functionally the same, except
/// semantically it communicates the idea that we want all bits set.
/// </summary>
All = Unknown,
}
/// <summary>
/// A channel message.
/// </summary>
public readonly struct ChannelMessage
{
public readonly ChannelMessageKind Kind;
public readonly MessageSensitivity Sensitivity;
private readonly string _message;
private readonly object[] _args;
/// <summary>
/// Line endings may not be normalized.
/// </summary>
public string Message => _args != null ? string.Format(_message, _args) : _message;
[BestFriend]
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
{
Contracts.CheckNonEmpty(message, nameof(message));
Kind = kind;
Sensitivity = sensitivity;
_message = message;
_args = null;
}
[BestFriend]
internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
{
Contracts.CheckNonEmpty(fmt, nameof(fmt));
Contracts.CheckNonEmpty(args, nameof(args));
Kind = kind;
Sensitivity = sensitivity;
_message = fmt;
_args = args;
}
}
/// <summary>
/// A standard communication channel.
/// </summary>
public interface IChannel : IPipe<ChannelMessage>
{
void Trace(MessageSensitivity sensitivity, string fmt);
void Trace(MessageSensitivity sensitivity, string fmt, params object[] args);
void Error(MessageSensitivity sensitivity, string fmt);
void Error(MessageSensitivity sensitivity, string fmt, params object[] args);
void Warning(MessageSensitivity sensitivity, string fmt);
void Warning(MessageSensitivity sensitivity, string fmt, params object[] args);
void Info(MessageSensitivity sensitivity, string fmt);
void Info(MessageSensitivity sensitivity, string fmt, params object[] args);
}
/// <summary>
/// General utility extension methods for objects in the "host" universe, i.e.,
/// <see cref="IHostEnvironment"/>, <see cref="IHost"/>, and <see cref="IChannel"/>
/// that do not belong in more specific areas, for example, <see cref="Contracts"/> or
/// component creation.
/// </summary>
[BestFriend]
internal static class HostExtensions
{
public static T Apply<T>(this IHost host, string channelName, Func<IChannel, T> func)
{
T t;
using (var ch = host.Start(channelName))
{
t = func(ch);
}
return t;
}
/// <summary>
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Trace(this IChannel ch, string fmt)
=> ch.Trace(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Trace(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Trace(this IChannel ch, string fmt, params object[] args)
=> ch.Trace(MessageSensitivity.Unknown, fmt, args);
/// <summary>
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Error(this IChannel ch, string fmt)
=> ch.Error(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Error(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Error(this IChannel ch, string fmt, params object[] args)
=> ch.Error(MessageSensitivity.Unknown, fmt, args);
/// <summary>
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Warning(this IChannel ch, string fmt)
=> ch.Warning(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Warning(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Warning(this IChannel ch, string fmt, params object[] args)
=> ch.Warning(MessageSensitivity.Unknown, fmt, args);
/// <summary>
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string)"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Info(this IChannel ch, string fmt)
=> ch.Info(MessageSensitivity.Unknown, fmt);
/// <summary>
/// Convenience variant of <see cref="IChannel.Info(MessageSensitivity, string, object[])"/>
/// setting <see cref="MessageSensitivity.Unknown"/>.
/// </summary>
public static void Info(this IChannel ch, string fmt, params object[] args)
=> ch.Info(MessageSensitivity.Unknown, fmt, args);
}

Просмотреть файл

@ -5,140 +5,139 @@
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Runtime
namespace Microsoft.ML.Runtime;
/// <summary>
/// This is a factory interface for <see cref="IProgressChannel"/>.
/// Both <see cref="IHostEnvironment"/> and <see cref="IProgressChannel"/> implement this interface,
/// to allow for nested progress reporters.
///
/// REVIEW: make <see cref="IChannelProvider"/> implement this, instead of the environment?
/// </summary>
public interface IProgressChannelProvider
{
/// <summary>
/// This is a factory interface for <see cref="IProgressChannel"/>.
/// Both <see cref="IHostEnvironment"/> and <see cref="IProgressChannel"/> implement this interface,
/// to allow for nested progress reporters.
/// Create a progress channel for a computation named <paramref name="name"/>.
/// </summary>
IProgressChannel StartProgressChannel(string name);
}
/// <summary>
/// A common interface for progress reporting.
/// It is expected that the progress channel interface is used from only one thread.
///
/// Supported workflow:
/// 1) Create the channel via <see cref="IProgressChannelProvider.StartProgressChannel"/>.
/// 2) Call <see cref="SetHeader"/> as many times as desired (including 0).
/// Each call to <see cref="SetHeader"/> supersedes the previous one.
/// 3) Report checkpoints (0 or more) by calling <see cref="Checkpoint"/>.
/// 4) Repeat steps 2-3 as often as necessary.
/// 5) Dispose the channel.
/// </summary>
public interface IProgressChannel : IProgressChannelProvider, IDisposable
{
/// <summary>
/// Set up the reporting structure:
/// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported.
/// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress.
///
/// REVIEW: make <see cref="IChannelProvider"/> implement this, instead of the environment?
/// It is acceptable to call <see cref="SetHeader"/> multiple times (or none), regardless of whether the calculation is running
/// or not. Because of synchronization, the computation should not deny calls to the 'old' <paramref name="fillAction"/>
/// delegates even after a new one is provided.
/// </summary>
public interface IProgressChannelProvider
{
/// <summary>
/// Create a progress channel for a computation named <paramref name="name"/>.
/// </summary>
IProgressChannel StartProgressChannel(string name);
}
/// <param name="header">The header object.</param>
/// <param name="fillAction">The delegate to provide actual progress. The <see cref="IProgressEntry"/> parameter of
/// the delegate will correspond to the provided <paramref name="header"/>.</param>
void SetHeader(ProgressHeader header, Action<IProgressEntry> fillAction);
/// <summary>
/// A common interface for progress reporting.
/// It is expected that the progress channel interface is used from only one thread.
/// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener,
/// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated
/// at certain moments ('checkpoints') of the computation.
///
/// Supported workflow:
/// 1) Create the channel via <see cref="IProgressChannelProvider.StartProgressChannel"/>.
/// 2) Call <see cref="SetHeader"/> as many times as desired (including 0).
/// Each call to <see cref="SetHeader"/> supersedes the previous one.
/// 3) Report checkpoints (0 or more) by calling <see cref="Checkpoint"/>.
/// 4) Repeat steps 2-3 as often as necessary.
/// 5) Dispose the channel.
/// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint
/// every iteration.
///
/// The only parameter, <paramref name="values"/>, is interpreted in the following fashion:
/// * First MetricNames.Length items, if present, are metrics.
/// * Subsequent ProgressNames.Length items, if present, are progress units.
/// * Subsequent ProgressNames.Length items, if present, are progress limits.
/// * If any more values remain, an exception is thrown.
/// </summary>
public interface IProgressChannel : IProgressChannelProvider, IDisposable
{
/// <summary>
/// Set up the reporting structure:
/// - Set the 'header' of the progress reports, defining which progress units and metrics are going to be reported.
/// - Provide a thread-safe delegate to be invoked whenever anyone needs to know the progress.
///
/// It is acceptable to call <see cref="SetHeader"/> multiple times (or none), regardless of whether the calculation is running
/// or not. Because of synchronization, the computation should not deny calls to the 'old' <paramref name="fillAction"/>
/// delegates even after a new one is provided.
/// </summary>
/// <param name="header">The header object.</param>
/// <param name="fillAction">The delegate to provide actual progress. The <see cref="IProgressEntry"/> parameter of
/// the delegate will correspond to the provided <paramref name="header"/>.</param>
void SetHeader(ProgressHeader header, Action<IProgressEntry> fillAction);
/// <param name="values">The metrics, progress units and progress limits.</param>
void Checkpoint(params Double?[] values);
}
/// <summary>
/// Submit a 'checkpoint' entry. These entries are guaranteed to be delivered to the progress listener,
/// if it is interested. Typically, this would contain some intermediate metrics, that are only calculated
/// at certain moments ('checkpoints') of the computation.
///
/// For example, SDCA may report a checkpoint every time it computes the loss, or LBFGS may report a checkpoint
/// every iteration.
///
/// The only parameter, <paramref name="values"/>, is interpreted in the following fashion:
/// * First MetricNames.Length items, if present, are metrics.
/// * Subsequent ProgressNames.Length items, if present, are progress units.
/// * Subsequent ProgressNames.Length items, if present, are progress limits.
/// * If any more values remain, an exception is thrown.
/// </summary>
/// <param name="values">The metrics, progress units and progress limits.</param>
void Checkpoint(params Double?[] values);
/// <summary>
/// This is the 'header' of the progress report.
/// </summary>
public sealed class ProgressHeader
{
/// <summary>
/// These are the names of the progress 'units', from the least granular to the most granular.
/// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}.
/// Will never be null, but can be empty.
/// </summary>
public readonly IReadOnlyList<string> UnitNames;
/// <summary>
/// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc.
/// Will never be null, but can be empty.
/// </summary>
public readonly IReadOnlyList<string> MetricNames;
/// <summary>
/// Initialize the header. This will take ownership of the arrays.
/// Both arrays can be null, even simultaneously. This 'empty' header indicated that the calculation doesn't report
/// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any
/// progress or metrics to report, it is always better to report them.
/// </summary>
/// <param name="metricNames">The metrics that the calculation reports. These are completely independent, and there
/// is no contract on whether the metric values should increase or not. As naming convention, <paramref name="metricNames"/>
/// can have multiple words with spaces, and should be title-cased.</param>
/// <param name="unitNames">The names of the progress units, listed from least granular to most granular.
/// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.).
/// As naming convention, <paramref name="unitNames"/> should be lower-cased and typically plural
/// (for example, iterations, clusters, examples). </param>
public ProgressHeader(string[] metricNames, string[] unitNames)
{
Contracts.CheckValueOrNull(unitNames);
Contracts.CheckValueOrNull(metricNames);
UnitNames = unitNames ?? new string[0];
MetricNames = metricNames ?? new string[0];
}
/// <summary>
/// This is the 'header' of the progress report.
/// A constructor for no metrics, just progress units. As naming convention, <paramref name="unitNames"/> should be lower-cased
/// and typically plural (for example, iterations, clusters, examples).
/// </summary>
public sealed class ProgressHeader
public ProgressHeader(params string[] unitNames)
: this(null, unitNames)
{
/// <summary>
/// These are the names of the progress 'units', from the least granular to the most granular.
/// For example, neural network might have {'epoch', 'example'} and FastTree might have {'tree', 'split', 'feature'}.
/// Will never be null, but can be empty.
/// </summary>
public readonly IReadOnlyList<string> UnitNames;
/// <summary>
/// These are the names of the reported metrics. For example, this could be the 'loss', 'weight updates/sec' etc.
/// Will never be null, but can be empty.
/// </summary>
public readonly IReadOnlyList<string> MetricNames;
/// <summary>
/// Initialize the header. This will take ownership of the arrays.
/// Both arrays can be null, even simultaneously. This 'empty' header indicated that the calculation doesn't report
/// any units of progress, but the tracker can still track start, stop and elapsed time. Of course, if there's any
/// progress or metrics to report, it is always better to report them.
/// </summary>
/// <param name="metricNames">The metrics that the calculation reports. These are completely independent, and there
/// is no contract on whether the metric values should increase or not. As naming convention, <paramref name="metricNames"/>
/// can have multiple words with spaces, and should be title-cased.</param>
/// <param name="unitNames">The names of the progress units, listed from least granular to most granular.
/// The idea is that the progress should be lexicographically increasing (like [0,0], [0,10], [1,0], [1,15], [2,5] etc.).
/// As naming convention, <paramref name="unitNames"/> should be lower-cased and typically plural
/// (for example, iterations, clusters, examples). </param>
public ProgressHeader(string[] metricNames, string[] unitNames)
{
Contracts.CheckValueOrNull(unitNames);
Contracts.CheckValueOrNull(metricNames);
UnitNames = unitNames ?? new string[0];
MetricNames = metricNames ?? new string[0];
}
/// <summary>
/// A constructor for no metrics, just progress units. As naming convention, <paramref name="unitNames"/> should be lower-cased
/// and typically plural (for example, iterations, clusters, examples).
/// </summary>
public ProgressHeader(params string[] unitNames)
: this(null, unitNames)
{
}
}
/// <summary>
/// A metric/progress holder item.
/// </summary>
public interface IProgressEntry
{
/// <summary>
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
/// and the limit value for the progress becomes 'unknown'.
/// </summary>
void SetProgress(int index, Double value);
/// <summary>
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
/// and the limit value to <paramref name="lim"/>. If <paramref name="lim"/> is a NAN, it is set to null instead.
/// </summary>
void SetProgress(int index, Double value, Double lim);
/// <summary>
/// Sets the metric with index <paramref name="index"/> to <paramref name="value"/>.
/// </summary>
void SetMetric(int index, Double value);
}
}
/// <summary>
/// A metric/progress holder item.
/// </summary>
public interface IProgressEntry
{
/// <summary>
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
/// and the limit value for the progress becomes 'unknown'.
/// </summary>
void SetProgress(int index, Double value);
/// <summary>
/// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
/// and the limit value to <paramref name="lim"/>. If <paramref name="lim"/> is a NAN, it is set to null instead.
/// </summary>
void SetProgress(int index, Double value, Double lim);
/// <summary>
/// Sets the metric with index <paramref name="index"/> to <paramref name="value"/>.
/// </summary>
void SetMetric(int index, Double value);
}

Просмотреть файл

@ -5,47 +5,46 @@
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// This interface maps an input <see cref="DataViewRow"/> to an output <see cref="DataViewRow"/>. Typically, the output contains
/// both the input columns and new columns added by the implementing class, although some implementations may
/// return a subset of the input columns.
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
/// so to rebind, the same input column names must be used.
/// Implementations of this interface are typically created over defined input <see cref="DataViewSchema"/>.
/// </summary>
public interface IRowToRowMapper
{
/// <summary>
/// This interface maps an input <see cref="DataViewRow"/> to an output <see cref="DataViewRow"/>. Typically, the output contains
/// both the input columns and new columns added by the implementing class, although some implementations may
/// return a subset of the input columns.
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
/// so to rebind, the same input column names must be used.
/// Implementations of this interface are typically created over defined input <see cref="DataViewSchema"/>.
/// Mappers are defined as accepting inputs with this very specific schema.
/// </summary>
public interface IRowToRowMapper
{
/// <summary>
/// Mappers are defined as accepting inputs with this very specific schema.
/// </summary>
DataViewSchema InputSchema { get; }
DataViewSchema InputSchema { get; }
/// <summary>
/// Gets an instance of <see cref="DataViewSchema"/> which describes the columns' names and types in the output generated by this mapper.
/// </summary>
DataViewSchema OutputSchema { get; }
/// <summary>
/// Gets an instance of <see cref="DataViewSchema"/> which describes the columns' names and types in the output generated by this mapper.
/// </summary>
DataViewSchema OutputSchema { get; }
/// <summary>
/// Given a set of columns, return the input columns that are needed to generate those output columns.
/// </summary>
IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns);
/// <summary>
/// Given a set of columns, return the input columns that are needed to generate those output columns.
/// </summary>
IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns);
/// <summary>
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
/// Getting values on inactive columns of the returned row will throw.
///
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
///
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
/// </summary>
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
}
/// <summary>
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
/// Getting values on inactive columns of the returned row will throw.
///
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
///
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
/// </summary>
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
}

Просмотреть файл

@ -6,85 +6,84 @@ using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// A mapper that can be bound to a <see cref="RoleMappedSchema"/> (which encapsulates a <see cref="DataViewSchema"/> and has mappings from column kinds
/// to columns of that schema). Binding an <see cref="ISchemaBindableMapper"/> to a <see cref="RoleMappedSchema"/> produces an
/// <see cref="ISchemaBoundMapper"/>, which is an interface that has methods to return the names and indices of the input columns
/// needed by the mapper to compute its output. The <see cref="ISchemaBoundRowMapper"/> is an extention to this interface, that
/// can also produce an output <see cref="DataViewRow"/> given an input <see cref="DataViewRow"/>. The <see cref="DataViewRow"/> produced generally contains only the output columns of the mapper, and not
/// the input columns (but there is nothing preventing an <see cref="ISchemaBoundRowMapper"/> from mapping input columns directly to outputs).
/// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
/// features column. New predictors can implement <see cref="ISchemaBindableMapper"/> directly. Implementing <see cref="ISchemaBindableMapper"/>
/// includes implementing a corresponding <see cref="ISchemaBoundMapper"/> (or <see cref="ISchemaBoundRowMapper"/>) and a corresponding ISchema
/// for the output schema of the <see cref="ISchemaBoundMapper"/>. In case the <see cref="ISchemaBoundRowMapper"/> interface is implemented,
/// the SimpleRow class can be used in the <see cref="IRowToRowMapper.GetRow"/> method.
/// </summary>
[BestFriend]
internal interface ISchemaBindableMapper
{
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
}
/// <summary>
/// This interface is used to map a schema from input columns to output columns. The <see cref="ISchemaBoundMapper"/> should keep track
/// of the input columns that are needed for the mapping.
/// </summary>
[BestFriend]
internal interface ISchemaBoundMapper
{
/// <summary>
/// A mapper that can be bound to a <see cref="RoleMappedSchema"/> (which encapsulates a <see cref="DataViewSchema"/> and has mappings from column kinds
/// to columns of that schema). Binding an <see cref="ISchemaBindableMapper"/> to a <see cref="RoleMappedSchema"/> produces an
/// <see cref="ISchemaBoundMapper"/>, which is an interface that has methods to return the names and indices of the input columns
/// needed by the mapper to compute its output. The <see cref="ISchemaBoundRowMapper"/> is an extention to this interface, that
/// can also produce an output <see cref="DataViewRow"/> given an input <see cref="DataViewRow"/>. The <see cref="DataViewRow"/> produced generally contains only the output columns of the mapper, and not
/// the input columns (but there is nothing preventing an <see cref="ISchemaBoundRowMapper"/> from mapping input columns directly to outputs).
/// This interface is implemented by wrappers of IValueMapper based predictors, which are predictors that take a single
/// features column. New predictors can implement <see cref="ISchemaBindableMapper"/> directly. Implementing <see cref="ISchemaBindableMapper"/>
/// includes implementing a corresponding <see cref="ISchemaBoundMapper"/> (or <see cref="ISchemaBoundRowMapper"/>) and a corresponding ISchema
/// for the output schema of the <see cref="ISchemaBoundMapper"/>. In case the <see cref="ISchemaBoundRowMapper"/> interface is implemented,
/// the SimpleRow class can be used in the <see cref="IRowToRowMapper.GetRow"/> method.
/// The <see cref="RoleMappedSchema"/> that was passed to the <see cref="ISchemaBoundMapper"/> in the binding process.
/// </summary>
[BestFriend]
internal interface ISchemaBindableMapper
{
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
}
RoleMappedSchema InputRoleMappedSchema { get; }
/// <summary>
/// This interface is used to map a schema from input columns to output columns. The <see cref="ISchemaBoundMapper"/> should keep track
/// of the input columns that are needed for the mapping.
/// Gets schema of this mapper's output.
/// </summary>
[BestFriend]
internal interface ISchemaBoundMapper
{
/// <summary>
/// The <see cref="RoleMappedSchema"/> that was passed to the <see cref="ISchemaBoundMapper"/> in the binding process.
/// </summary>
RoleMappedSchema InputRoleMappedSchema { get; }
/// <summary>
/// Gets schema of this mapper's output.
/// </summary>
DataViewSchema OutputSchema { get; }
/// <summary>
/// A property to get back the <see cref="ISchemaBindableMapper"/> that produced this <see cref="ISchemaBoundMapper"/>.
/// </summary>
ISchemaBindableMapper Bindable { get; }
/// <summary>
/// This method returns the binding information: which input columns are used and in what roles.
/// </summary>
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles();
}
DataViewSchema OutputSchema { get; }
/// <summary>
/// This interface extends <see cref="ISchemaBoundMapper"/>.
/// A property to get back the <see cref="ISchemaBindableMapper"/> that produced this <see cref="ISchemaBoundMapper"/>.
/// </summary>
[BestFriend]
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
{
/// <summary>
/// Input schema accepted.
/// </summary>
DataViewSchema InputSchema { get; }
ISchemaBindableMapper Bindable { get; }
/// <summary>
/// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns.
/// </summary>
IEnumerable<DataViewSchema.Column> GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns);
/// <summary>
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
/// Getting values on inactive columns of the returned row will throw.
///
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
/// the returned value must have the same schema as <see cref="ISchemaBoundMapper.OutputSchema"/>.
///
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
/// </summary>
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
}
/// <summary>
/// This method returns the binding information: which input columns are used and in what roles.
/// </summary>
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles();
}
/// <summary>
/// This interface extends <see cref="ISchemaBoundMapper"/>.
/// </summary>
[BestFriend]
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
{
/// <summary>
/// Input schema accepted.
/// </summary>
DataViewSchema InputSchema { get; }
/// <summary>
/// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns.
/// </summary>
IEnumerable<DataViewSchema.Column> GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns);
/// <summary>
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
/// Getting values on inactive columns of the returned row will throw.
///
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
/// the returned value must have the same schema as <see cref="ISchemaBoundMapper.OutputSchema"/>.
///
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
/// </summary>
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
}

Просмотреть файл

@ -2,57 +2,56 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Delegate type to map/convert a value.
/// </summary>
[BestFriend]
internal delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);
/// <summary>
/// Delegate type to map/convert among three values, for example, one input with two
/// outputs, or two inputs with one output.
/// </summary>
[BestFriend]
internal delegate void ValueMapper<TVal1, TVal2, TVal3>(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
/// <summary>
/// Interface for mapping a single input value (of an indicated ColumnType) to
/// an output value (of an indicated ColumnType). This interface is commonly implemented
/// by predictors. Note that the input and output ColumnTypes determine the proper
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
[BestFriend]
internal interface IValueMapper
{
/// <summary>
/// Delegate type to map/convert a value.
/// </summary>
[BestFriend]
internal delegate void ValueMapper<TSrc, TDst>(in TSrc src, ref TDst dst);
DataViewType InputType { get; }
DataViewType OutputType { get; }
/// <summary>
/// Delegate type to map/convert among three values, for example, one input with two
/// outputs, or two inputs with one output.
/// Get a delegate used for mapping from input to output values. Note that the delegate
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
/// </summary>
[BestFriend]
internal delegate void ValueMapper<TVal1, TVal2, TVal3>(in TVal1 val1, ref TVal2 val2, ref TVal3 val3);
/// <summary>
/// Interface for mapping a single input value (of an indicated ColumnType) to
/// an output value (of an indicated ColumnType). This interface is commonly implemented
/// by predictors. Note that the input and output ColumnTypes determine the proper
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
[BestFriend]
internal interface IValueMapper
{
DataViewType InputType { get; }
DataViewType OutputType { get; }
/// <summary>
/// Get a delegate used for mapping from input to output values. Note that the delegate
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
/// </summary>
ValueMapper<TSrc, TDst> GetMapper<TSrc, TDst>();
}
/// <summary>
/// Interface for mapping a single input value (of an indicated ColumnType) to an output value
/// plus distribution value (of indicated ColumnTypes). This interface is commonly implemented
/// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
[BestFriend]
internal interface IValueMapperDist : IValueMapper
{
DataViewType DistType { get; }
/// <summary>
/// Get a delegate used for mapping from input to output values. Note that the delegate
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
/// </summary>
ValueMapper<TSrc, TDst, TDist> GetMapper<TSrc, TDst, TDist>();
}
ValueMapper<TSrc, TDst> GetMapper<TSrc, TDst>();
}
/// <summary>
/// Interface for mapping a single input value (of an indicated ColumnType) to an output value
/// plus distribution value (of indicated ColumnTypes). This interface is commonly implemented
/// by predictors. Note that the input, output, and distribution ColumnTypes determine the proper
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
[BestFriend]
internal interface IValueMapperDist : IValueMapper
{
DataViewType DistType { get; }
/// <summary>
/// Get a delegate used for mapping from input to output values. Note that the delegate
/// should only be used on a single thread - it should NOT be assumed to be safe for concurrency.
/// </summary>
ValueMapper<TSrc, TDst, TDist> GetMapper<TSrc, TDst, TDist>();
}

Просмотреть файл

@ -2,8 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Data
{
[BestFriend]
internal delegate bool InPredicate<T>(in T value);
}
namespace Microsoft.ML.Data;
[BestFriend]
internal delegate bool InPredicate<T>(in T value);

Просмотреть файл

@ -4,21 +4,20 @@
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Extension methods related to the <see cref="KeyDataViewType"/> class.
/// </summary>
[BestFriend]
internal static class KeyTypeExtensions
{
/// <summary>
/// Extension methods related to the <see cref="KeyDataViewType"/> class.
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
/// </summary>
[BestFriend]
internal static class KeyTypeExtensions
public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null)
{
/// <summary>
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
/// </summary>
public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null)
{
ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
return (int)key.Count;
}
ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
return (int)key.Count;
}
}

Просмотреть файл

@ -4,50 +4,49 @@
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Base class for a cursor has an input cursor, but still needs to do work on <see cref="DataViewRowCursor.MoveNext"/>.
/// </summary>
[BestFriend]
internal abstract class LinkedRootCursorBase : RootCursorBase
{
/// <summary>Gets the input cursor.</summary>
protected DataViewRowCursor Input { get; }
/// <summary>
/// Base class for a cursor has an input cursor, but still needs to do work on <see cref="DataViewRowCursor.MoveNext"/>.
/// Returns the root cursor of the input. It should be used to perform <see cref="DataViewRowCursor.MoveNext"/>
/// operations, but with the distinction, as compared to <see cref="SynchronizedCursorBase"/>, that this is not
/// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is
/// on filter cursor implementations, where how that input cursor is consumed is very implementation specific.
/// That is why this is <see langword="protected"/>, not <see langword="private"/>.
/// </summary>
[BestFriend]
internal abstract class LinkedRootCursorBase : RootCursorBase
protected DataViewRowCursor Root { get; }
private bool _disposed;
protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input)
: base(provider)
{
Ch.AssertValue(input, nameof(input));
/// <summary>Gets the input cursor.</summary>
protected DataViewRowCursor Input { get; }
Input = input;
Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input;
}
/// <summary>
/// Returns the root cursor of the input. It should be used to perform <see cref="DataViewRowCursor.MoveNext"/>
/// operations, but with the distinction, as compared to <see cref="SynchronizedCursorBase"/>, that this is not
/// a simple passthrough, but rather very implementation specific. For example, a common usage of this class is
/// on filter cursor implementations, where how that input cursor is consumed is very implementation specific.
/// That is why this is <see langword="protected"/>, not <see langword="private"/>.
/// </summary>
protected DataViewRowCursor Root { get; }
private bool _disposed;
protected LinkedRootCursorBase(IChannelProvider provider, DataViewRowCursor input)
: base(provider)
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
{
Ch.AssertValue(input, nameof(input));
Input.Dispose();
// The base class should set the state to done under these circumstances.
Input = input;
Root = Input is SynchronizedCursorBase snycInput ? snycInput.Root : input;
}
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
{
Input.Dispose();
// The base class should set the state to done under these circumstances.
}
_disposed = true;
base.Dispose(disposing);
}
_disposed = true;
base.Dispose(disposing);
}
}

Просмотреть файл

@ -4,40 +4,39 @@
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Base class for creating a cursor of rows that filters out some input rows.
/// </summary>
[BestFriend]
internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase
{
/// <summary>
/// Base class for creating a cursor of rows that filters out some input rows.
/// </summary>
[BestFriend]
internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase
public override long Batch => Input.Batch;
protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
: base(provider, input, schema, active)
{
public override long Batch => Input.Batch;
protected LinkedRowFilterCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
: base(provider, input, schema, active)
{
}
public override ValueGetter<DataViewRowId> GetIdGetter()
{
return Input.GetIdGetter();
}
protected override bool MoveNextCore()
{
while (Root.MoveNext())
{
if (Accept())
return true;
}
return false;
}
/// <summary>
/// Return whether the current input row should be returned by this cursor.
/// </summary>
protected abstract bool Accept();
}
public override ValueGetter<DataViewRowId> GetIdGetter()
{
return Input.GetIdGetter();
}
protected override bool MoveNextCore()
{
while (Root.MoveNext())
{
if (Accept())
return true;
}
return false;
}
/// <summary>
/// Return whether the current input row should be returned by this cursor.
/// </summary>
protected abstract bool Accept();
}

Просмотреть файл

@ -4,50 +4,49 @@
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// A base class for a <see cref="DataViewRowCursor"/> that has an input cursor, but still needs to do work on
/// <see cref="DataViewRowCursor.MoveNext"/>. Note that the default
/// <see cref="LinkedRowRootCursorBase.GetGetter{TValue}(DataViewSchema.Column)"/> assumes that each input column is exposed as an
/// output column with the same column index.
/// </summary>
[BestFriend]
internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase
{
/// <summary>
/// A base class for a <see cref="DataViewRowCursor"/> that has an input cursor, but still needs to do work on
/// <see cref="DataViewRowCursor.MoveNext"/>. Note that the default
/// <see cref="LinkedRowRootCursorBase.GetGetter{TValue}(DataViewSchema.Column)"/> assumes that each input column is exposed as an
/// output column with the same column index.
/// </summary>
[BestFriend]
internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase
private readonly bool[] _active;
/// <summary>Gets row's schema.</summary>
public sealed override DataViewSchema Schema { get; }
protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
: base(provider, input)
{
private readonly bool[] _active;
Ch.CheckValue(schema, nameof(schema));
Ch.Check(active == null || active.Length == schema.Count);
_active = active;
Schema = schema;
}
/// <summary>Gets row's schema.</summary>
public sealed override DataViewSchema Schema { get; }
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public sealed override bool IsColumnActive(DataViewSchema.Column column)
{
Ch.Check(column.Index < Schema.Count);
return _active == null || _active[column.Index];
}
protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active)
: base(provider, input)
{
Ch.CheckValue(schema, nameof(schema));
Ch.Check(active == null || active.Length == schema.Count);
_active = active;
Schema = schema;
}
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public sealed override bool IsColumnActive(DataViewSchema.Column column)
{
Ch.Check(column.Index < Schema.Count);
return _active == null || _active[column.Index];
}
/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
{
return Input.GetGetter<TValue>(column);
}
/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
{
return Input.GetGetter<TValue>(column);
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -8,183 +8,182 @@ using System.Text;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML
namespace Microsoft.ML;
/// <summary>
/// This is a convenience context object for loading models from a repository, for
/// implementors of ICanSaveModel. It is not mandated but designed to reduce the
/// amount of boiler plate code. It can also be used when loading from a single stream,
/// for implementors of ICanSaveInBinaryFormat.
/// </summary>
[BestFriend]
internal sealed partial class ModelLoadContext : IDisposable
{
/// <summary>
/// This is a convenience context object for loading models from a repository, for
/// implementors of ICanSaveModel. It is not mandated but designed to reduce the
/// amount of boiler plate code. It can also be used when loading from a single stream,
/// for implementors of ICanSaveInBinaryFormat.
/// When in repository mode, this is the repository we're reading from. It is null when
/// in single-stream mode.
/// </summary>
public readonly RepositoryReader Repository;
/// <summary>
/// When in repository mode, this is the directory we're reading from. Null means the root
/// of the repository. It is always null in single-stream mode.
/// </summary>
public readonly string Directory;
/// <summary>
/// The main stream reader.
/// </summary>
public readonly BinaryReader Reader;
/// <summary>
/// The strings loaded from the main stream's string table.
/// </summary>
public readonly string[] Strings;
/// <summary>
/// The name of the assembly that the loader lives in.
/// </summary>
/// <remarks>
/// This may be null or empty if one was never written to the model, or is an older model version.
/// </remarks>
public readonly string LoaderAssemblyName;
/// <summary>
/// The main stream's model header.
/// </summary>
[BestFriend]
internal sealed partial class ModelLoadContext : IDisposable
internal ModelHeader Header;
/// <summary>
/// The min file position of the main stream.
/// </summary>
public readonly long FpMin;
/// <summary>
/// Exception context provided by Repository (can be null).
/// </summary>
private readonly IExceptionContext _ectx;
/// <summary>
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
/// </summary>
public bool InRepository { get { return Repository != null; } }
/// <summary>
/// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel.
/// </summary>
internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
{
/// <summary>
/// When in repository mode, this is the repository we're reading from. It is null when
/// in single-stream mode.
/// </summary>
public readonly RepositoryReader Repository;
Contracts.CheckValue(rep, nameof(rep));
Repository = rep;
_ectx = rep.ExceptionContext;
/// <summary>
/// When in repository mode, this is the directory we're reading from. Null means the root
/// of the repository. It is always null in single-stream mode.
/// </summary>
public readonly string Directory;
_ectx.CheckValue(ent, nameof(ent));
_ectx.CheckValueOrNull(dir);
/// <summary>
/// The main stream reader.
/// </summary>
public readonly BinaryReader Reader;
Directory = dir;
/// <summary>
/// The strings loaded from the main stream's string table.
/// </summary>
public readonly string[] Strings;
/// <summary>
/// The name of the assembly that the loader lives in.
/// </summary>
/// <remarks>
/// This may be null or empty if one was never written to the model, or is an older model version.
/// </remarks>
public readonly string LoaderAssemblyName;
/// <summary>
/// The main stream's model header.
/// </summary>
[BestFriend]
internal ModelHeader Header;
/// <summary>
/// The min file position of the main stream.
/// </summary>
public readonly long FpMin;
/// <summary>
/// Exception context provided by Repository (can be null).
/// </summary>
private readonly IExceptionContext _ectx;
/// <summary>
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
/// </summary>
public bool InRepository { get { return Repository != null; } }
/// <summary>
/// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel.
/// </summary>
internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
try
{
Contracts.CheckValue(rep, nameof(rep));
Repository = rep;
_ectx = rep.ExceptionContext;
_ectx.CheckValue(ent, nameof(ent));
_ectx.CheckValueOrNull(dir);
Directory = dir;
Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
try
{
ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
}
catch
{
Reader.Dispose();
throw;
}
}
/// <summary>
/// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat.
/// </summary>
internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
{
Contracts.AssertValueOrNull(ectx);
_ectx = ectx;
_ectx.CheckValue(reader, nameof(reader));
Repository = null;
Directory = null;
Reader = reader;
ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
}
public void CheckAtModel()
catch
{
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
}
public void CheckAtModel(VersionInfo ver)
{
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
ModelHeader.CheckVersionInfo(ref Header, ver);
}
/// <summary>
/// Performs version checks.
/// </summary>
public void CheckVersionInfo(VersionInfo ver)
{
ModelHeader.CheckVersionInfo(ref Header, ver);
}
/// <summary>
/// Reads an integer from the load context's reader, and returns the associated string,
/// or null (encoded as -1).
/// </summary>
public string LoadStringOrNull()
{
int id = Reader.ReadInt32();
// Note that -1 means null. Empty strings are in the string table.
_ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings));
if (id >= 0)
return Strings[id];
return null;
}
/// <summary>
/// Reads an integer from the load context's reader, and returns the associated string.
/// </summary>
public string LoadString()
{
int id = Reader.ReadInt32();
Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings));
return Strings[id];
}
/// <summary>
/// Reads an integer from the load context's reader, and returns the associated string.
/// Throws if the string is empty or null.
/// </summary>
public string LoadNonEmptyString()
{
int id = Reader.ReadInt32();
_ectx.CheckDecode(0 <= id && id < Utils.Size(Strings));
var str = Strings[id];
_ectx.CheckDecode(str.Length > 0);
return str;
}
/// <summary>
/// Commit the load operation. This completes reading of the main stream. When in repository
/// mode, it disposes the Reader (but not the repository).
/// </summary>
public void Done()
{
ModelHeader.EndRead(FpMin, ref Header, Reader);
Dispose();
}
/// <summary>
/// When in repository mode, this disposes the Reader (but no the repository).
/// </summary>
public void Dispose()
{
// When in single-stream mode, we don't own the Reader.
if (InRepository)
Reader.Dispose();
Reader.Dispose();
throw;
}
}
/// <summary>
/// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat.
/// </summary>
internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
{
Contracts.AssertValueOrNull(ectx);
_ectx = ectx;
_ectx.CheckValue(reader, nameof(reader));
Repository = null;
Directory = null;
Reader = reader;
ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
}
public void CheckAtModel()
{
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
}
public void CheckAtModel(VersionInfo ver)
{
_ectx.Check(Reader.BaseStream.Position == FpMin + Header.FpModel);
ModelHeader.CheckVersionInfo(ref Header, ver);
}
/// <summary>
/// Performs version checks.
/// </summary>
public void CheckVersionInfo(VersionInfo ver)
{
ModelHeader.CheckVersionInfo(ref Header, ver);
}
/// <summary>
/// Reads an integer from the load context's reader, and returns the associated string,
/// or null (encoded as -1).
/// </summary>
public string LoadStringOrNull()
{
int id = Reader.ReadInt32();
// Note that -1 means null. Empty strings are in the string table.
_ectx.CheckDecode(-1 <= id && id < Utils.Size(Strings));
if (id >= 0)
return Strings[id];
return null;
}
/// <summary>
/// Reads an integer from the load context's reader, and returns the associated string.
/// </summary>
public string LoadString()
{
int id = Reader.ReadInt32();
Contracts.CheckDecode(0 <= id && id < Utils.Size(Strings));
return Strings[id];
}
/// <summary>
/// Reads an integer from the load context's reader, and returns the associated string.
/// Throws if the string is empty or null.
/// </summary>
public string LoadNonEmptyString()
{
int id = Reader.ReadInt32();
_ectx.CheckDecode(0 <= id && id < Utils.Size(Strings));
var str = Strings[id];
_ectx.CheckDecode(str.Length > 0);
return str;
}
/// <summary>
/// Commit the load operation. This completes reading of the main stream. When in repository
/// mode, it disposes the Reader (but not the repository).
/// </summary>
public void Done()
{
ModelHeader.EndRead(FpMin, ref Header, Reader);
Dispose();
}
/// <summary>
/// When in repository mode, this disposes the Reader (but no the repository).
/// </summary>
public void Dispose()
{
// When in single-stream mode, we don't own the Reader.
if (InRepository)
Reader.Dispose();
}
}

Просмотреть файл

@ -9,357 +9,356 @@ using System.Text;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML
namespace Microsoft.ML;
/// <summary>
/// Signature for a repository based model loader. This is the dual of <see cref="ICanSaveModel"/>.
/// </summary>
[BestFriend]
internal delegate void SignatureLoadModel(ModelLoadContext ctx);
internal sealed partial class ModelLoadContext : IDisposable
{
public const string ModelStreamName = "Model.key";
internal const string NameBinary = "Model.bin";
/// <summary>
/// Signature for a repository based model loader. This is the dual of <see cref="ICanSaveModel"/>.
/// Returns the new assembly name to maintain backward compatibility.
/// </summary>
[BestFriend]
internal delegate void SignatureLoadModel(ModelLoadContext ctx);
internal sealed partial class ModelLoadContext : IDisposable
private string ForwardedLoaderAssemblyName
{
public const string ModelStreamName = "Model.key";
internal const string NameBinary = "Model.bin";
/// <summary>
/// Returns the new assembly name to maintain backward compatibility.
/// </summary>
private string ForwardedLoaderAssemblyName
get
{
get
string[] nameDetails = LoaderAssemblyName.Split(',');
switch (nameDetails[0])
{
string[] nameDetails = LoaderAssemblyName.Split(',');
switch (nameDetails[0])
{
case "Microsoft.ML.HalLearners":
nameDetails[0] = "Microsoft.ML.Mkl.Components";
break;
case "Microsoft.ML.StandardLearners":
nameDetails[0] = "Microsoft.ML.StandardTrainers";
break;
default:
return LoaderAssemblyName;
}
return string.Join(",", nameDetails);
case "Microsoft.ML.HalLearners":
nameDetails[0] = "Microsoft.ML.Mkl.Components";
break;
case "Microsoft.ML.StandardLearners":
nameDetails[0] = "Microsoft.ML.StandardTrainers";
break;
default:
return LoaderAssemblyName;
}
return string.Join(",", nameDetails);
}
}
/// <summary>
/// Return whether this context contains a directory and stream for a sub-model with
/// the indicated name. This does not attempt to load the sub-model.
/// </summary>
public bool ContainsModel(string name)
{
if (!InRepository)
return false;
if (string.IsNullOrEmpty(name))
return false;
var dir = Path.Combine(Directory ?? "", name);
var ent = Repository.OpenEntryOrNull(dir, ModelStreamName);
if (ent != null)
{
ent.Dispose();
return true;
}
if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null)
{
ent.Dispose();
return true;
}
/// <summary>
/// Return whether this context contains a directory and stream for a sub-model with
/// the indicated name. This does not attempt to load the sub-model.
/// </summary>
public bool ContainsModel(string name)
{
if (!InRepository)
return false;
}
/// <summary>
/// Load an optional object from the repository directory.
/// Returns false iff no stream was found for the object, iff result is set to null.
/// Throws if loading fails for any other reason.
/// </summary>
public static bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
var ent = rep.OpenEntryOrNull(dir, ModelStreamName);
if (ent != null)
{
using (ent)
{
// Provide the repository, entry, and directory name to the loadable class ctor.
env.Assert(ent.Stream.Position == 0);
LoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra);
return true;
}
}
if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null)
{
using (ent)
{
env.Assert(ent.Stream.Position == 0);
LoadModel<TRes, TSig>(env, out result, ent.Stream, extra);
return true;
}
}
result = null;
if (string.IsNullOrEmpty(name))
return false;
var dir = Path.Combine(Directory ?? "", name);
var ent = Repository.OpenEntryOrNull(dir, ModelStreamName);
if (ent != null)
{
ent.Dispose();
return true;
}
/// <summary>
/// Load an object from the repository directory.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
where TRes : class
if ((ent = Repository.OpenEntryOrNull(dir, NameBinary)) != null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
if (!LoadModelOrNull<TRes, TSig>(env, out result, rep, dir, extra))
throw env.ExceptDecode("Corrupt model file");
env.AssertValue(result);
ent.Dispose();
return true;
}
/// <summary>
/// Load a sub model from the given sub directory if it exists. This requires InRepository to be true.
/// Returns false iff no stream was found for the object, iff result is set to null.
/// Throws if loading fails for any other reason.
/// </summary>
public bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
where TRes : class
return false;
}
/// <summary>
/// Load an optional object from the repository directory.
/// Returns false iff no stream was found for the object, iff result is set to null.
/// Throws if loading fails for any other reason.
/// </summary>
public static bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
var ent = rep.OpenEntryOrNull(dir, ModelStreamName);
if (ent != null)
{
_ectx.CheckValue(env, nameof(env));
_ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream");
return LoadModelOrNull<TRes, TSig>(env, out result, Repository, Path.Combine(Directory ?? "", name), extra);
}
/// <summary>
/// Load a sub model from the given sub directory. This requires InRepository to be true.
/// </summary>
public void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
where TRes : class
{
_ectx.CheckValue(env, nameof(env));
if (!LoadModelOrNull<TRes, TSig>(env, out result, name, extra))
throw _ectx.ExceptDecode("Corrupt model file");
_ectx.AssertValue(result);
}
/// <summary>
/// Try to load from the given repository entry using the default loader(s) specified in the header.
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
private static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
long fp = ent.Stream.Position;
using (var ctx = new ModelLoadContext(rep, ent, dir))
{
env.Assert(fp == ctx.FpMin);
if (ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra))
return true;
}
// TryLoadModelCore should rewind on failure.
Contracts.Assert(fp == ent.Stream.Position);
return false;
}
/// <summary>
/// Load from the given repository entry using the default loader(s) specified in the header.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
if (!TryLoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra))
throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
}
/// <summary>
/// Try to load from the given stream (non-Repository).
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true))
return TryLoadModel<TRes, TSig>(env, out result, reader, extra);
}
/// <summary>
/// Load from the given stream (non-Repository) using the default loader(s) specified in the header.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
if (!TryLoadModel<TRes, TSig>(env, out result, stream, extra))
throw Contracts.ExceptDecode("Couldn't load model");
}
/// <summary>
/// Try to load from the given reader (non-Repository).
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
long fp = reader.BaseStream.Position;
using (var ctx = new ModelLoadContext(reader))
{
Contracts.Assert(fp == ctx.FpMin);
return ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra);
}
}
/// <summary>
/// Load from the given reader (non-Repository) using the default loader(s) specified in the header.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
if (!TryLoadModel<TRes, TSig>(env, out result, reader, extra))
throw Contracts.ExceptDecode("Couldn't load model");
}
/// <summary>
/// Tries to load.
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
private bool TryLoadModelCore<TRes, TSig>(IHostEnvironment env, out TRes result, params object[] extra)
where TRes : class
{
_ectx.AssertValue(env, "env");
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
var args = ConcatArgsRev(extra, this);
EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog);
object tmp;
string sig = ModelHeader.GetLoaderSig(ref Header);
if (!string.IsNullOrWhiteSpace(sig) &&
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sig, "", args))
{
result = tmp as TRes;
if (result != null)
{
Done();
return true;
}
// REVIEW: Should this fall through?
}
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
string sigAlt = ModelHeader.GetLoaderSigAlt(ref Header);
if (!string.IsNullOrWhiteSpace(sigAlt) &&
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sigAlt, "", args))
{
result = tmp as TRes;
if (result != null)
{
Done();
return true;
}
// REVIEW: Should this fall through?
}
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
Reader.BaseStream.Position = FpMin;
result = null;
return false;
}
private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog)
{
if (!string.IsNullOrEmpty(LoaderAssemblyName))
{
var assembly = Assembly.Load(ForwardedLoaderAssemblyName);
catalog.RegisterAssembly(assembly);
}
}
private static object[] ConcatArgsRev(object[] args2, params object[] args1)
{
Contracts.AssertNonEmpty(args1);
return Utils.Concat(args1, args2);
}
/// <summary>
/// Try to load a sub model from the given sub directory. This requires InRepository to be true.
/// </summary>
public bool TryProcessSubModel(string dir, Action<ModelLoadContext> action)
{
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
_ectx.CheckNonEmpty(dir, nameof(dir));
_ectx.CheckValue(action, nameof(action));
string path = Path.Combine(Directory, dir);
var ent = Repository.OpenEntryOrNull(path, ModelStreamName);
if (ent == null)
return false;
using (ent)
{
// Provide the repository, entry, and directory name to the loadable class ctor.
_ectx.Assert(ent.Stream.Position == 0);
using (var ctx = new ModelLoadContext(Repository, ent, path))
action(ctx);
env.Assert(ent.Stream.Position == 0);
LoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra);
return true;
}
return true;
}
/// <summary>
/// Try to load a binary stream from the current directory. This requires InRepository to be true.
/// </summary>
public bool TryLoadBinaryStream(string name, Action<BinaryReader> action)
if ((ent = rep.OpenEntryOrNull(dir, NameBinary)) != null)
{
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
var ent = Repository.OpenEntryOrNull(Directory, name);
if (ent == null)
return false;
using (ent)
using (var reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true))
{
action(reader);
env.Assert(ent.Stream.Position == 0);
LoadModel<TRes, TSig>(env, out result, ent.Stream, extra);
return true;
}
return true;
}
/// <summary>
/// Try to load a text stream from the current directory. This requires InRepository to be true.
/// </summary>
public bool TryLoadTextStream(string name, Action<TextReader> action)
result = null;
return false;
}
/// <summary>
/// Load an object from the repository directory.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
if (!LoadModelOrNull<TRes, TSig>(env, out result, rep, dir, extra))
throw env.ExceptDecode("Corrupt model file");
env.AssertValue(result);
}
/// <summary>
/// Load a sub model from the given sub directory if it exists. This requires InRepository to be true.
/// Returns false iff no stream was found for the object, iff result is set to null.
/// Throws if loading fails for any other reason.
/// </summary>
public bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
where TRes : class
{
_ectx.CheckValue(env, nameof(env));
_ectx.Check(InRepository, "Can't load a sub-model when reading from a single stream");
return LoadModelOrNull<TRes, TSig>(env, out result, Repository, Path.Combine(Directory ?? "", name), extra);
}
/// <summary>
/// Load a sub model from the given sub directory. This requires InRepository to be true.
/// </summary>
public void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, string name, params object[] extra)
where TRes : class
{
_ectx.CheckValue(env, nameof(env));
if (!LoadModelOrNull<TRes, TSig>(env, out result, name, extra))
throw _ectx.ExceptDecode("Corrupt model file");
_ectx.AssertValue(result);
}
/// <summary>
/// Try to load from the given repository entry using the default loader(s) specified in the header.
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
private static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
long fp = ent.Stream.Position;
using (var ctx = new ModelLoadContext(rep, ent, dir))
{
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
env.Assert(fp == ctx.FpMin);
if (ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra))
return true;
}
var ent = Repository.OpenEntryOrNull(Directory, name);
if (ent == null)
return false;
// TryLoadModelCore should rewind on failure.
Contracts.Assert(fp == ent.Stream.Position);
using (ent)
using (var reader = new StreamReader(ent.Stream))
{
action(reader);
}
return true;
return false;
}
/// <summary>
/// Load from the given repository entry using the default loader(s) specified in the header.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(rep, nameof(rep));
if (!TryLoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra))
throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
}
/// <summary>
/// Try to load from the given stream (non-Repository).
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
using (var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true))
return TryLoadModel<TRes, TSig>(env, out result, reader, extra);
}
/// <summary>
/// Load from the given stream (non-Repository) using the default loader(s) specified in the header.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, Stream stream, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
if (!TryLoadModel<TRes, TSig>(env, out result, stream, extra))
throw Contracts.ExceptDecode("Couldn't load model");
}
/// <summary>
/// Try to load from the given reader (non-Repository).
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
public static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
long fp = reader.BaseStream.Position;
using (var ctx = new ModelLoadContext(reader))
{
Contracts.Assert(fp == ctx.FpMin);
return ctx.TryLoadModelCore<TRes, TSig>(env, out result, extra);
}
}
/// <summary>
/// Load from the given reader (non-Repository) using the default loader(s) specified in the header.
/// </summary>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, BinaryReader reader, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
if (!TryLoadModel<TRes, TSig>(env, out result, reader, extra))
throw Contracts.ExceptDecode("Couldn't load model");
}
/// <summary>
/// Tries to load.
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class.
/// </summary>
private bool TryLoadModelCore<TRes, TSig>(IHostEnvironment env, out TRes result, params object[] extra)
where TRes : class
{
_ectx.AssertValue(env, "env");
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
var args = ConcatArgsRev(extra, this);
EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog);
object tmp;
string sig = ModelHeader.GetLoaderSig(ref Header);
if (!string.IsNullOrWhiteSpace(sig) &&
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sig, "", args))
{
result = tmp as TRes;
if (result != null)
{
Done();
return true;
}
// REVIEW: Should this fall through?
}
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
string sigAlt = ModelHeader.GetLoaderSigAlt(ref Header);
if (!string.IsNullOrWhiteSpace(sigAlt) &&
ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sigAlt, "", args))
{
result = tmp as TRes;
if (result != null)
{
Done();
return true;
}
// REVIEW: Should this fall through?
}
_ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
Reader.BaseStream.Position = FpMin;
result = null;
return false;
}
private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog)
{
if (!string.IsNullOrEmpty(LoaderAssemblyName))
{
var assembly = Assembly.Load(ForwardedLoaderAssemblyName);
catalog.RegisterAssembly(assembly);
}
}
private static object[] ConcatArgsRev(object[] args2, params object[] args1)
{
Contracts.AssertNonEmpty(args1);
return Utils.Concat(args1, args2);
}
/// <summary>
/// Try to load a sub model from the given sub directory. This requires InRepository to be true.
/// </summary>
public bool TryProcessSubModel(string dir, Action<ModelLoadContext> action)
{
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
_ectx.CheckNonEmpty(dir, nameof(dir));
_ectx.CheckValue(action, nameof(action));
string path = Path.Combine(Directory, dir);
var ent = Repository.OpenEntryOrNull(path, ModelStreamName);
if (ent == null)
return false;
using (ent)
{
// Provide the repository, entry, and directory name to the loadable class ctor.
_ectx.Assert(ent.Stream.Position == 0);
using (var ctx = new ModelLoadContext(Repository, ent, path))
action(ctx);
}
return true;
}
/// <summary>
/// Try to load a binary stream from the current directory. This requires InRepository to be true.
/// </summary>
public bool TryLoadBinaryStream(string name, Action<BinaryReader> action)
{
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
var ent = Repository.OpenEntryOrNull(Directory, name);
if (ent == null)
return false;
using (ent)
using (var reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true))
{
action(reader);
}
return true;
}
/// <summary>
/// Try to load a text stream from the current directory. This requires InRepository to be true.
/// </summary>
public bool TryLoadTextStream(string name, Action<TextReader> action)
{
_ectx.Check(InRepository, "Can't Load a sub-model when reading from a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
var ent = Repository.OpenEntryOrNull(Directory, name);
if (ent == null)
return false;
using (ent)
using (var reader = new StreamReader(ent.Stream))
{
action(reader);
}
return true;
}
}

Просмотреть файл

@ -8,253 +8,252 @@ using System.Text;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML
namespace Microsoft.ML;
/// <summary>
/// Convenience context object for saving models to a repository, for
/// implementors of <see cref="ICanSaveModel"/>.
/// </summary>
/// <remarks>
/// This class reduces the amount of boiler plate code needed to implement <see cref="ICanSaveModel"/>.
/// It can also be used when saving to a single stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
/// </remarks>
public sealed partial class ModelSaveContext : IDisposable
{
/// <summary>
/// Convenience context object for saving models to a repository, for
/// implementors of <see cref="ICanSaveModel"/>.
/// When in repository mode, this is the repository we're writing to. It is null when
/// in single-stream mode.
/// </summary>
/// <remarks>
/// This class reduces the amount of boiler plate code needed to implement <see cref="ICanSaveModel"/>.
/// It can also be used when saving to a single stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
/// </remarks>
public sealed partial class ModelSaveContext : IDisposable
[BestFriend]
internal readonly RepositoryWriter Repository;
/// <summary>
/// When in repository mode, this is the directory we're reading from. Null means the root
/// of the repository. It is always null in single-stream mode.
/// </summary>
[BestFriend]
internal readonly string Directory;
/// <summary>
/// The main stream writer.
/// </summary>
[BestFriend]
internal readonly BinaryWriter Writer;
/// <summary>
/// The strings that will be saved in the main stream's string table.
/// </summary>
[BestFriend]
internal readonly NormStr.Pool Strings;
/// <summary>
/// The main stream's model header.
/// </summary>
[BestFriend]
internal ModelHeader Header;
/// <summary>
/// The min file position of the main stream.
/// </summary>
[BestFriend]
internal readonly long FpMin;
/// <summary>
/// The wrapped entry.
/// </summary>
private readonly Repository.Entry _ent;
/// <summary>
/// Exception context provided by Repository (can be null).
/// </summary>
private readonly IExceptionContext _ectx;
/// <summary>
/// The assembly name where the loader resides.
/// </summary>
private string _loaderAssemblyName;
/// <summary>
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
/// </summary>
[BestFriend]
internal bool InRepository => Repository != null;
/// <summary>
/// Create a <see cref="ModelSaveContext"/> supporting saving to a repository, for implementors of <see cref="ICanSaveModel"/>.
/// </summary>
internal ModelSaveContext(RepositoryWriter rep, string dir, string name)
{
/// <summary>
/// When in repository mode, this is the repository we're writing to. It is null when
/// in single-stream mode.
/// </summary>
[BestFriend]
internal readonly RepositoryWriter Repository;
Contracts.CheckValue(rep, nameof(rep));
Repository = rep;
_ectx = rep.ExceptionContext;
/// <summary>
/// When in repository mode, this is the directory we're reading from. Null means the root
/// of the repository. It is always null in single-stream mode.
/// </summary>
[BestFriend]
internal readonly string Directory;
_ectx.CheckValueOrNull(dir);
_ectx.CheckNonEmpty(name, nameof(name));
/// <summary>
/// The main stream writer.
/// </summary>
[BestFriend]
internal readonly BinaryWriter Writer;
Directory = dir;
Strings = new NormStr.Pool();
/// <summary>
/// The strings that will be saved in the main stream's string table.
/// </summary>
[BestFriend]
internal readonly NormStr.Pool Strings;
/// <summary>
/// The main stream's model header.
/// </summary>
[BestFriend]
internal ModelHeader Header;
/// <summary>
/// The min file position of the main stream.
/// </summary>
[BestFriend]
internal readonly long FpMin;
/// <summary>
/// The wrapped entry.
/// </summary>
private readonly Repository.Entry _ent;
/// <summary>
/// Exception context provided by Repository (can be null).
/// </summary>
private readonly IExceptionContext _ectx;
/// <summary>
/// The assembly name where the loader resides.
/// </summary>
private string _loaderAssemblyName;
/// <summary>
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
/// </summary>
[BestFriend]
internal bool InRepository => Repository != null;
/// <summary>
/// Create a <see cref="ModelSaveContext"/> supporting saving to a repository, for implementors of <see cref="ICanSaveModel"/>.
/// </summary>
internal ModelSaveContext(RepositoryWriter rep, string dir, string name)
_ent = rep.CreateEntry(dir, name);
try
{
Contracts.CheckValue(rep, nameof(rep));
Repository = rep;
_ectx = rep.ExceptionContext;
_ectx.CheckValueOrNull(dir);
_ectx.CheckNonEmpty(name, nameof(name));
Directory = dir;
Strings = new NormStr.Pool();
_ent = rep.CreateEntry(dir, name);
Writer = new BinaryWriter(_ent.Stream, Encoding.UTF8, leaveOpen: true);
try
{
Writer = new BinaryWriter(_ent.Stream, Encoding.UTF8, leaveOpen: true);
try
{
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
}
catch
{
Writer.Dispose();
throw;
}
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
}
catch
{
_ent.Dispose();
Writer.Dispose();
throw;
}
}
/// <summary>
/// Create a <see cref="ModelSaveContext"/> supporting saving to a single-stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
/// </summary>
internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null)
catch
{
Contracts.AssertValueOrNull(ectx);
_ectx = ectx;
_ectx.CheckValue(writer, nameof(writer));
Repository = null;
Directory = null;
_ent = null;
Strings = new NormStr.Pool();
Writer = writer;
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
_ent.Dispose();
throw;
}
}
[BestFriend]
internal void CheckAtModel()
/// <summary>
/// Create a <see cref="ModelSaveContext"/> supporting saving to a single-stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
/// </summary>
internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null)
{
Contracts.AssertValueOrNull(ectx);
_ectx = ectx;
_ectx.CheckValue(writer, nameof(writer));
Repository = null;
Directory = null;
_ent = null;
Strings = new NormStr.Pool();
Writer = writer;
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
}
[BestFriend]
internal void CheckAtModel()
{
_ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel);
}
/// <summary>
/// Set the version information in the main stream's header. This should be called before
/// <see cref="Done"/> is called.
/// </summary>
/// <param name="ver"></param>
[BestFriend]
internal void SetVersionInfo(VersionInfo ver)
{
ModelHeader.SetVersionInfo(ref Header, ver);
_loaderAssemblyName = ver.LoaderAssemblyName;
}
[BestFriend]
internal void SaveTextStream(string name, Action<TextWriter> action)
{
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
// that to set leaveOpen to true, we have to specify the buffer size....
using (var ent = Repository.CreateEntry(Directory, name))
using (var writer = Utils.OpenWriter(ent.Stream))
{
_ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel);
action(writer);
}
}
/// <summary>
/// Set the version information in the main stream's header. This should be called before
/// <see cref="Done"/> is called.
/// </summary>
/// <param name="ver"></param>
[BestFriend]
internal void SetVersionInfo(VersionInfo ver)
[BestFriend]
internal void SaveBinaryStream(string name, Action<BinaryWriter> action)
{
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
// that to set leaveOpen to true, we have to specify the buffer size....
using (var ent = Repository.CreateEntry(Directory, name))
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
{
ModelHeader.SetVersionInfo(ref Header, ver);
_loaderAssemblyName = ver.LoaderAssemblyName;
action(writer);
}
}
[BestFriend]
internal void SaveTextStream(string name, Action<TextWriter> action)
{
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
// that to set leaveOpen to true, we have to specify the buffer size....
using (var ent = Repository.CreateEntry(Directory, name))
using (var writer = Utils.OpenWriter(ent.Stream))
{
action(writer);
}
}
[BestFriend]
internal void SaveBinaryStream(string name, Action<BinaryWriter> action)
{
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
_ectx.CheckValue(action, nameof(action));
// I verified in the CLR source that the default buffer size is 1024. It's unfortunate
// that to set leaveOpen to true, we have to specify the buffer size....
using (var ent = Repository.CreateEntry(Directory, name))
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
{
action(writer);
}
}
/// <summary>
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream. If str is null, this writes -1 and doesn't add it to the pool.
/// </summary>
[BestFriend]
internal void SaveStringOrNull(string str)
{
if (str == null)
Writer.Write(-1);
else
Writer.Write(Strings.Add(str).Id);
}
/// <summary>
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream. Checks that str is not null.
/// </summary>
[BestFriend]
internal void SaveString(string str)
{
_ectx.CheckValue(str, nameof(str));
/// <summary>
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream. If str is null, this writes -1 and doesn't add it to the pool.
/// </summary>
[BestFriend]
internal void SaveStringOrNull(string str)
{
if (str == null)
Writer.Write(-1);
else
Writer.Write(Strings.Add(str).Id);
}
}
[BestFriend]
internal void SaveString(ReadOnlyMemory<char> str)
/// <summary>
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream. Checks that str is not null.
/// </summary>
[BestFriend]
internal void SaveString(string str)
{
_ectx.CheckValue(str, nameof(str));
Writer.Write(Strings.Add(str).Id);
}
[BestFriend]
internal void SaveString(ReadOnlyMemory<char> str)
{
Writer.Write(Strings.Add(str).Id);
}
/// <summary>
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream.
/// </summary>
[BestFriend]
internal void SaveNonEmptyString(string str)
{
_ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str));
Writer.Write(Strings.Add(str).Id);
}
[BestFriend]
internal void SaveNonEmptyString(ReadOnlyMemory<char> str)
{
Writer.Write(Strings.Add(str).Id);
}
/// <summary>
/// Commit the save operation. This completes writing of the main stream. When in repository
/// mode, it disposes <see cref="Writer"/> (but not <see cref="Repository"/>).
/// </summary>
[BestFriend]
internal void Done()
{
_ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName);
Dispose();
}
/// <summary>
/// When in repository mode, this disposes the Writer (but not the repository).
/// </summary>
public void Dispose()
{
_ectx.Assert((_ent == null) == !InRepository);
// When in single stream mode, we don't own the Writer.
if (InRepository)
{
Writer.Write(Strings.Add(str).Id);
}
/// <summary>
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream.
/// </summary>
[BestFriend]
internal void SaveNonEmptyString(string str)
{
_ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str));
Writer.Write(Strings.Add(str).Id);
}
[BestFriend]
internal void SaveNonEmptyString(ReadOnlyMemory<char> str)
{
Writer.Write(Strings.Add(str).Id);
}
/// <summary>
/// Commit the save operation. This completes writing of the main stream. When in repository
/// mode, it disposes <see cref="Writer"/> (but not <see cref="Repository"/>).
/// </summary>
[BestFriend]
internal void Done()
{
_ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName);
Dispose();
}
/// <summary>
/// When in repository mode, this disposes the Writer (but not the repository).
/// </summary>
public void Dispose()
{
_ectx.Assert((_ent == null) == !InRepository);
// When in single stream mode, we don't own the Writer.
if (InRepository)
{
Writer.Dispose();
_ent.Dispose();
}
Writer.Dispose();
_ent.Dispose();
}
}
}

Просмотреть файл

@ -7,85 +7,84 @@ using System.IO;
using System.Text;
using Microsoft.ML.Runtime;
namespace Microsoft.ML
namespace Microsoft.ML;
public sealed partial class ModelSaveContext : IDisposable
{
public sealed partial class ModelSaveContext : IDisposable
/// <summary>
/// Save a sub model to the given sub directory. This requires <see cref="InRepository"/> to be <see langword="true"/>.
/// </summary>
[BestFriend]
internal void SaveModel<T>(T value, string name)
where T : class
{
/// <summary>
/// Save a sub model to the given sub directory. This requires <see cref="InRepository"/> to be <see langword="true"/>.
/// </summary>
[BestFriend]
internal void SaveModel<T>(T value, string name)
where T : class
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
SaveModel(Repository, value, Path.Combine(Directory ?? "", name));
}
/// <summary>
/// Save the object by calling TrySaveModel then falling back to .net serialization.
/// </summary>
[BestFriend]
internal static void SaveModel<T>(RepositoryWriter rep, T value, string path)
where T : class
{
if (value == null)
return;
var sm = value as ICanSaveModel;
if (sm != null)
{
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
SaveModel(Repository, value, Path.Combine(Directory ?? "", name));
}
/// <summary>
/// Save the object by calling TrySaveModel then falling back to .net serialization.
/// </summary>
[BestFriend]
internal static void SaveModel<T>(RepositoryWriter rep, T value, string path)
where T : class
{
if (value == null)
return;
var sm = value as ICanSaveModel;
if (sm != null)
using (var ctx = new ModelSaveContext(rep, path, ModelLoadContext.ModelStreamName))
{
using (var ctx = new ModelSaveContext(rep, path, ModelLoadContext.ModelStreamName))
{
sm.Save(ctx);
ctx.Done();
}
return;
}
var sb = value as ICanSaveInBinaryFormat;
if (sb != null)
{
using (var ent = rep.CreateEntry(path, ModelLoadContext.NameBinary))
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
{
sb.SaveAsBinary(writer);
}
}
}
/// <summary>
/// Save to a single-stream by invoking the given action.
/// </summary>
[BestFriend]
internal static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
{
Contracts.CheckValue(writer, nameof(writer));
Contracts.CheckValue(fn, nameof(fn));
using (var ctx = new ModelSaveContext(writer))
{
fn(ctx);
sm.Save(ctx);
ctx.Done();
}
return;
}
/// <summary>
/// Save to the given sub directory by invoking the given action. This requires
/// <see cref="InRepository"/> to be <see langword="true"/>.
/// </summary>
[BestFriend]
internal void SaveSubModel(string dir, Action<ModelSaveContext> fn)
var sb = value as ICanSaveInBinaryFormat;
if (sb != null)
{
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
_ectx.CheckNonEmpty(dir, nameof(dir));
_ectx.CheckValue(fn, nameof(fn));
using (var ctx = new ModelSaveContext(Repository, Path.Combine(Directory ?? "", dir), ModelLoadContext.ModelStreamName))
using (var ent = rep.CreateEntry(path, ModelLoadContext.NameBinary))
using (var writer = new BinaryWriter(ent.Stream, Encoding.UTF8, leaveOpen: true))
{
fn(ctx);
ctx.Done();
sb.SaveAsBinary(writer);
}
}
}
/// <summary>
/// Save to a single-stream by invoking the given action.
/// </summary>
[BestFriend]
internal static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
{
Contracts.CheckValue(writer, nameof(writer));
Contracts.CheckValue(fn, nameof(fn));
using (var ctx = new ModelSaveContext(writer))
{
fn(ctx);
ctx.Done();
}
}
/// <summary>
/// Save to the given sub directory by invoking the given action. This requires
/// <see cref="InRepository"/> to be <see langword="true"/>.
/// </summary>
[BestFriend]
internal void SaveSubModel(string dir, Action<ModelSaveContext> fn)
{
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
_ectx.CheckNonEmpty(dir, nameof(dir));
_ectx.CheckValue(fn, nameof(fn));
using (var ctx = new ModelSaveContext(Repository, Path.Combine(Directory ?? "", dir), ModelLoadContext.ModelStreamName))
{
fn(ctx);
ctx.Done();
}
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -9,264 +9,263 @@ using System.Text;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
[BestFriend]
internal static class ReadOnlyMemoryUtils
{
[BestFriend]
internal static class ReadOnlyMemoryUtils
/// <summary>
/// Compare equality with the given system string value.
/// </summary>
public static bool EqualsStr(string s, ReadOnlyMemory<char> memory)
{
Contracts.CheckValueOrNull(s);
/// <summary>
/// Compare equality with the given system string value.
/// </summary>
public static bool EqualsStr(string s, ReadOnlyMemory<char> memory)
if (s == null)
return memory.Length == 0;
if (s.Length != memory.Length)
return false;
return memory.Span.SequenceEqual(s.AsSpan());
}
public static IEnumerable<ReadOnlyMemory<char>> Split(ReadOnlyMemory<char> memory, char[] separators)
{
Contracts.CheckValueOrNull(separators);
if (memory.IsEmpty)
yield break;
if (separators == null || separators.Length == 0)
{
Contracts.CheckValueOrNull(s);
if (s == null)
return memory.Length == 0;
if (s.Length != memory.Length)
return false;
return memory.Span.SequenceEqual(s.AsSpan());
yield return memory;
yield break;
}
public static IEnumerable<ReadOnlyMemory<char>> Split(ReadOnlyMemory<char> memory, char[] separators)
var span = memory.Span;
if (separators.Length == 1)
{
Contracts.CheckValueOrNull(separators);
if (memory.IsEmpty)
yield break;
if (separators == null || separators.Length == 0)
char chSep = separators[0];
for (int ichCur = 0; ;)
{
yield return memory;
yield break;
}
var span = memory.Span;
if (separators.Length == 1)
{
char chSep = separators[0];
for (int ichCur = 0; ;)
int nextSep = span.IndexOf(chSep);
if (nextSep == -1)
{
int nextSep = span.IndexOf(chSep);
if (nextSep == -1)
{
yield return memory.Slice(ichCur);
yield break;
}
yield return memory.Slice(ichCur, nextSep);
// Skip the separator.
ichCur += nextSep + 1;
span = memory.Slice(ichCur).Span;
}
}
else
{
for (int ichCur = 0; ;)
{
int nextSep = span.IndexOfAny(separators);
if (nextSep == -1)
{
yield return memory.Slice(ichCur);
yield break;
}
yield return memory.Slice(ichCur, nextSep);
// Skip the separator.
ichCur += nextSep + 1;
span = memory.Slice(ichCur).Span;
}
}
}
/// <summary>
/// Splits <paramref name="memory"/> on the left-most occurrence of separator and produces the left
/// and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain the separator character,
/// this returns false and sets <paramref name="left"/> to this instance and <paramref name="right"/>
/// to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
/// </summary>
public static bool SplitOne(ReadOnlyMemory<char> memory, char separator, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
{
if (memory.IsEmpty)
{
left = memory;
right = default;
return false;
}
int index = memory.Span.IndexOf(separator);
if (index == -1)
{
left = memory;
right = default;
return false;
}
left = memory.Slice(0, index);
right = memory.Slice(index + 1, memory.Length - index - 1);
return true;
}
/// <summary>
/// Splits <paramref name="memory"/> on the left-most occurrence of an element of separators character array and
/// produces the left and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain any of the
/// characters in separators, this return false and initializes <paramref name="left"/> to this instance
/// and <paramref name="right"/> to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
/// </summary>
public static bool SplitOne(ReadOnlyMemory<char> memory, char[] separators, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
{
Contracts.CheckValueOrNull(separators);
if (memory.IsEmpty || separators == null || separators.Length == 0)
{
left = memory;
right = default;
return false;
}
int index;
if (separators.Length == 1)
index = memory.Span.IndexOf(separators[0]);
else
index = memory.Span.IndexOfAny(separators);
if (index == -1)
{
left = memory;
right = default;
return false;
}
left = memory.Slice(0, index);
right = memory.Slice(index + 1, memory.Length - index - 1);
return true;
}
/// <summary>
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing spaces trimmed. Note that this
/// will remove only spaces, not any form of whitespace.
/// </summary>
public static ReadOnlyMemory<char> TrimSpaces(ReadOnlyMemory<char> memory)
{
if (memory.IsEmpty)
return memory;
int ichLim = memory.Length;
int ichMin = 0;
var span = memory.Span;
if (span[ichMin] != ' ' && span[ichLim - 1] != ' ')
return memory;
while (ichMin < ichLim && span[ichMin] == ' ')
ichMin++;
while (ichMin < ichLim && span[ichLim - 1] == ' ')
ichLim--;
return memory.Slice(ichMin, ichLim - ichMin);
}
/// <summary>
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing whitespace trimmed.
/// </summary>
public static ReadOnlyMemory<char> TrimWhiteSpace(ReadOnlyMemory<char> memory)
{
if (memory.IsEmpty)
return memory;
int ichMin = 0;
int ichLim = memory.Length;
var span = memory.Span;
if (!char.IsWhiteSpace(span[ichMin]) && !char.IsWhiteSpace(span[ichLim - 1]))
return memory;
while (ichMin < ichLim && char.IsWhiteSpace(span[ichMin]))
ichMin++;
while (ichMin < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
ichLim--;
return memory.Slice(ichMin, ichLim - ichMin);
}
/// <summary>
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with trailing whitespace trimmed.
/// </summary>
public static ReadOnlyMemory<char> TrimEndWhiteSpace(ReadOnlyMemory<char> memory)
{
if (memory.IsEmpty)
return memory;
int ichLim = memory.Length;
var span = memory.Span;
if (!char.IsWhiteSpace(span[ichLim - 1]))
return memory;
while (0 < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
ichLim--;
return memory.Slice(0, ichLim);
}
public static void AddLowerCaseToStringBuilder(ReadOnlySpan<char> span, StringBuilder sb)
{
Contracts.CheckValue(sb, nameof(sb));
if (!span.IsEmpty)
{
int min = 0;
int j;
for (j = min; j < span.Length; j++)
{
char ch = CharUtils.ToLowerInvariant(span[j]);
if (ch != span[j])
{
sb.AppendSpan(span.Slice(min, j - min)).Append(ch);
min = j + 1;
}
yield return memory.Slice(ichCur);
yield break;
}
Contracts.Assert(j == span.Length);
if (min != j)
sb.AppendSpan(span.Slice(min, j - min));
yield return memory.Slice(ichCur, nextSep);
// Skip the separator.
ichCur += nextSep + 1;
span = memory.Slice(ichCur).Span;
}
}
public static StringBuilder AppendMemory(this StringBuilder sb, ReadOnlyMemory<char> memory)
else
{
Contracts.CheckValue(sb, nameof(sb));
if (!memory.IsEmpty)
sb.AppendSpan(memory.Span);
return sb;
}
public static StringBuilder AppendSpan(this StringBuilder sb, ReadOnlySpan<char> span)
{
unsafe
for (int ichCur = 0; ;)
{
fixed (char* valueChars = &MemoryMarshal.GetReference(span))
int nextSep = span.IndexOfAny(separators);
if (nextSep == -1)
{
sb.Append(valueChars, span.Length);
yield return memory.Slice(ichCur);
yield break;
}
}
return sb;
}
yield return memory.Slice(ichCur, nextSep);
public sealed class ReadonlyMemoryCharComparer : IEqualityComparer<ReadOnlyMemory<char>>
{
public bool Equals(ReadOnlyMemory<char> x, ReadOnlyMemory<char> y)
{
return x.Span.SequenceEqual(y.Span);
}
public int GetHashCode(ReadOnlyMemory<char> obj)
{
return (int)Hashing.HashString(obj.Span);
// Skip the separator.
ichCur += nextSep + 1;
span = memory.Slice(ichCur).Span;
}
}
}
/// <summary>
/// Splits <paramref name="memory"/> on the left-most occurrence of separator and produces the left
/// and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain the separator character,
/// this returns false and sets <paramref name="left"/> to this instance and <paramref name="right"/>
/// to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
/// </summary>
public static bool SplitOne(ReadOnlyMemory<char> memory, char separator, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
{
if (memory.IsEmpty)
{
left = memory;
right = default;
return false;
}
int index = memory.Span.IndexOf(separator);
if (index == -1)
{
left = memory;
right = default;
return false;
}
left = memory.Slice(0, index);
right = memory.Slice(index + 1, memory.Length - index - 1);
return true;
}
/// <summary>
/// Splits <paramref name="memory"/> on the left-most occurrence of an element of separators character array and
/// produces the left and right <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> values. If <paramref name="memory"/> does not contain any of the
/// characters in separators, this return false and initializes <paramref name="left"/> to this instance
/// and <paramref name="right"/> to the default <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> value.
/// </summary>
public static bool SplitOne(ReadOnlyMemory<char> memory, char[] separators, out ReadOnlyMemory<char> left, out ReadOnlyMemory<char> right)
{
Contracts.CheckValueOrNull(separators);
if (memory.IsEmpty || separators == null || separators.Length == 0)
{
left = memory;
right = default;
return false;
}
int index;
if (separators.Length == 1)
index = memory.Span.IndexOf(separators[0]);
else
index = memory.Span.IndexOfAny(separators);
if (index == -1)
{
left = memory;
right = default;
return false;
}
left = memory.Slice(0, index);
right = memory.Slice(index + 1, memory.Length - index - 1);
return true;
}
/// <summary>
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing spaces trimmed. Note that this
/// will remove only spaces, not any form of whitespace.
/// </summary>
public static ReadOnlyMemory<char> TrimSpaces(ReadOnlyMemory<char> memory)
{
if (memory.IsEmpty)
return memory;
int ichLim = memory.Length;
int ichMin = 0;
var span = memory.Span;
if (span[ichMin] != ' ' && span[ichLim - 1] != ' ')
return memory;
while (ichMin < ichLim && span[ichMin] == ' ')
ichMin++;
while (ichMin < ichLim && span[ichLim - 1] == ' ')
ichLim--;
return memory.Slice(ichMin, ichLim - ichMin);
}
/// <summary>
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with leading and trailing whitespace trimmed.
/// </summary>
public static ReadOnlyMemory<char> TrimWhiteSpace(ReadOnlyMemory<char> memory)
{
if (memory.IsEmpty)
return memory;
int ichMin = 0;
int ichLim = memory.Length;
var span = memory.Span;
if (!char.IsWhiteSpace(span[ichMin]) && !char.IsWhiteSpace(span[ichLim - 1]))
return memory;
while (ichMin < ichLim && char.IsWhiteSpace(span[ichMin]))
ichMin++;
while (ichMin < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
ichLim--;
return memory.Slice(ichMin, ichLim - ichMin);
}
/// <summary>
/// Returns a <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/> with trailing whitespace trimmed.
/// </summary>
public static ReadOnlyMemory<char> TrimEndWhiteSpace(ReadOnlyMemory<char> memory)
{
if (memory.IsEmpty)
return memory;
int ichLim = memory.Length;
var span = memory.Span;
if (!char.IsWhiteSpace(span[ichLim - 1]))
return memory;
while (0 < ichLim && char.IsWhiteSpace(span[ichLim - 1]))
ichLim--;
return memory.Slice(0, ichLim);
}
public static void AddLowerCaseToStringBuilder(ReadOnlySpan<char> span, StringBuilder sb)
{
Contracts.CheckValue(sb, nameof(sb));
if (!span.IsEmpty)
{
int min = 0;
int j;
for (j = min; j < span.Length; j++)
{
char ch = CharUtils.ToLowerInvariant(span[j]);
if (ch != span[j])
{
sb.AppendSpan(span.Slice(min, j - min)).Append(ch);
min = j + 1;
}
}
Contracts.Assert(j == span.Length);
if (min != j)
sb.AppendSpan(span.Slice(min, j - min));
}
}
public static StringBuilder AppendMemory(this StringBuilder sb, ReadOnlyMemory<char> memory)
{
Contracts.CheckValue(sb, nameof(sb));
if (!memory.IsEmpty)
sb.AppendSpan(memory.Span);
return sb;
}
public static StringBuilder AppendSpan(this StringBuilder sb, ReadOnlySpan<char> span)
{
unsafe
{
fixed (char* valueChars = &MemoryMarshal.GetReference(span))
{
sb.Append(valueChars, span.Length);
}
}
return sb;
}
public sealed class ReadonlyMemoryCharComparer : IEqualityComparer<ReadOnlyMemory<char>>
{
public bool Equals(ReadOnlyMemory<char> x, ReadOnlyMemory<char> y)
{
return x.Span.SequenceEqual(y.Span);
}
public int GetHashCode(ReadOnlyMemory<char> obj)
{
return (int)Hashing.HashString(obj.Span);
}
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -6,484 +6,483 @@ using System.Collections.Generic;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Encapsulates a <see cref="DataViewSchema"/> plus column role mapping information. The purpose of role mappings is to
/// provide information on what the intended usage is for. That is: while a given data view may have a column named
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
/// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
/// multiple features columns to consume that information.
///
/// This class has convenience fields for several common column roles (for example, <see cref="Feature"/>, <see
/// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only
/// if there is a unique column with the corresponding role. When there are no such columns or more than one such
/// column, the field is <c>null</c>. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/>
/// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden
/// in this schema.
/// </summary>
/// <remarks>
/// Note that instances of this class are, like instances of <see cref="DataViewSchema"/>, immutable.
///
/// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For
/// that case, please use the <see cref="RoleMappedData"/> class.
///
/// Note that there is no need for components consuming a <see cref="RoleMappedData"/> or <see cref="RoleMappedSchema"/>
/// to make use of every defined mapping. Consuming components are also expected to ignore any <see cref="ColumnRole"/>
/// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s)
/// mapped from the role are not of the form they require.
/// </remarks>
/// <seealso cref="ColumnRole"/>
/// <seealso cref="RoleMappedData"/>
[BestFriend]
internal sealed class RoleMappedSchema
{
private const string FeatureString = "Feature";
private const string LabelString = "Label";
private const string GroupString = "Group";
private const string WeightString = "Weight";
private const string NameString = "Name";
private const string FeatureContributionsString = "FeatureContributions";
/// <summary>
/// Encapsulates a <see cref="DataViewSchema"/> plus column role mapping information. The purpose of role mappings is to
/// provide information on what the intended usage is for. That is: while a given data view may have a column named
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
/// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
/// multiple features columns to consume that information.
///
/// This class has convenience fields for several common column roles (for example, <see cref="Feature"/>, <see
/// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null if and only
/// if there is a unique column with the corresponding role. When there are no such columns or more than one such
/// column, the field is <c>null</c>. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/>
/// methods provide some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden
/// in this schema.
/// Instances of this are the keys of a <see cref="RoleMappedSchema"/>. This class also holds some important
/// commonly used pre-defined instances available (for example, <see cref="Label"/>, <see cref="Feature"/>) that should
/// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom
/// roles if approppriate for their task.
/// </summary>
/// <remarks>
/// Note that instances of this class are, like instances of <see cref="DataViewSchema"/>, immutable.
///
/// It is often the case that one wishes to bundle the actual data with the role mappings, not just the schema. For
/// that case, please use the <see cref="RoleMappedData"/> class.
///
/// Note that there is no need for components consuming a <see cref="RoleMappedData"/> or <see cref="RoleMappedSchema"/>
/// to make use of every defined mapping. Consuming components are also expected to ignore any <see cref="ColumnRole"/>
/// they do not handle. They may very well however complain if a mapping they wanted to see is not present, or the column(s)
/// mapped from the role are not of the form they require.
/// </remarks>
/// <seealso cref="ColumnRole"/>
/// <seealso cref="RoleMappedData"/>
[BestFriend]
internal sealed class RoleMappedSchema
public readonly struct ColumnRole
{
private const string FeatureString = "Feature";
private const string LabelString = "Label";
private const string GroupString = "Group";
private const string WeightString = "Weight";
private const string NameString = "Name";
private const string FeatureContributionsString = "FeatureContributions";
/// <summary>
/// Role for features. Commonly used as the independent variables given to trainers, and scorers.
/// </summary>
public static ColumnRole Feature => FeatureString;
/// <summary>
/// Instances of this are the keys of a <see cref="RoleMappedSchema"/>. This class also holds some important
/// commonly used pre-defined instances available (for example, <see cref="Label"/>, <see cref="Feature"/>) that should
/// be used when possible for consistency reasons. However, practitioners should not be afraid to declare custom
/// roles if approppriate for their task.
/// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators.
/// </summary>
public readonly struct ColumnRole
public static ColumnRole Label => LabelString;
/// <summary>
/// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or
/// sequence classification, for defining the boundaries of an utterance.
/// </summary>
public static ColumnRole Group => GroupString;
/// <summary>
/// Role for sample weights. Commonly used to point to a number to make trainers give more weight
/// to a particular example.
/// </summary>
public static ColumnRole Weight => WeightString;
/// <summary>
/// Role for sample names. Useful for informational and tracking purposes when scoring, but typically
/// without affecting results.
/// </summary>
public static ColumnRole Name => NameString;
// REVIEW: Does this really belong here?
/// <summary>
/// Role for feature contributions. Useful for specific diagnostic functionality.
/// </summary>
public static ColumnRole FeatureContributions => FeatureContributionsString;
/// <summary>
/// The string value for the role. Guaranteed to be non-empty.
/// </summary>
public readonly string Value;
/// <summary>
/// Constructor for the column role.
/// </summary>
/// <param name="value">The value for the role. Must be non-empty.</param>
public ColumnRole(string value)
{
/// <summary>
/// Role for features. Commonly used as the independent variables given to trainers, and scorers.
/// </summary>
public static ColumnRole Feature => FeatureString;
/// <summary>
/// Role for labels. Commonly used as the dependent variables given to trainers, and evaluators.
/// </summary>
public static ColumnRole Label => LabelString;
/// <summary>
/// Role for group ID. Commonly used in ranking applications, for defining query boundaries, or
/// sequence classification, for defining the boundaries of an utterance.
/// </summary>
public static ColumnRole Group => GroupString;
/// <summary>
/// Role for sample weights. Commonly used to point to a number to make trainers give more weight
/// to a particular example.
/// </summary>
public static ColumnRole Weight => WeightString;
/// <summary>
/// Role for sample names. Useful for informational and tracking purposes when scoring, but typically
/// without affecting results.
/// </summary>
public static ColumnRole Name => NameString;
// REVIEW: Does this really belong here?
/// <summary>
/// Role for feature contributions. Useful for specific diagnostic functionality.
/// </summary>
public static ColumnRole FeatureContributions => FeatureContributionsString;
/// <summary>
/// The string value for the role. Guaranteed to be non-empty.
/// </summary>
public readonly string Value;
/// <summary>
/// Constructor for the column role.
/// </summary>
/// <param name="value">The value for the role. Must be non-empty.</param>
public ColumnRole(string value)
{
Contracts.CheckNonEmpty(value, nameof(value));
Value = value;
}
public static implicit operator ColumnRole(string value)
=> new ColumnRole(value);
/// <summary>
/// Convenience method for creating a mapping pair from a role to a column name
/// for giving to constructors of <see cref="RoleMappedSchema"/> and <see cref="RoleMappedData"/>.
/// </summary>
/// <param name="name">The column name to map to. Can be <c>null</c>, in which case when used
/// to construct a role mapping structure this pair will be ignored</param>
/// <returns>A key-value pair with this instance as the key and <paramref name="name"/> as the value</returns>
public KeyValuePair<ColumnRole, string> Bind(string name)
=> new KeyValuePair<ColumnRole, string>(this, name);
Contracts.CheckNonEmpty(value, nameof(value));
Value = value;
}
public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, string name)
=> new KeyValuePair<ColumnRole, string>(role, name);
public static implicit operator ColumnRole(string value)
=> new ColumnRole(value);
/// <summary>
/// The source <see cref="Schema"/>.
/// Convenience method for creating a mapping pair from a role to a column name
/// for giving to constructors of <see cref="RoleMappedSchema"/> and <see cref="RoleMappedData"/>.
/// </summary>
public DataViewSchema Schema { get; }
/// <param name="name">The column name to map to. Can be <c>null</c>, in which case when used
/// to construct a role mapping structure this pair will be ignored</param>
/// <returns>A key-value pair with this instance as the key and <paramref name="name"/> as the value</returns>
public KeyValuePair<ColumnRole, string> Bind(string name)
=> new KeyValuePair<ColumnRole, string>(this, name);
}
/// <summary>
/// The <see cref="ColumnRole.Feature"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Feature { get; }
public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, string name)
=> new KeyValuePair<ColumnRole, string>(role, name);
/// <summary>
/// The <see cref="ColumnRole.Label"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Label { get; }
/// <summary>
/// The source <see cref="Schema"/>.
/// </summary>
public DataViewSchema Schema { get; }
/// <summary>
/// The <see cref="ColumnRole.Group"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Group { get; }
/// <summary>
/// The <see cref="ColumnRole.Feature"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Feature { get; }
/// <summary>
/// The <see cref="ColumnRole.Weight"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Weight { get; }
/// <summary>
/// The <see cref="ColumnRole.Label"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Label { get; }
/// <summary>
/// The <see cref="ColumnRole.Name"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Name { get; }
/// <summary>
/// The <see cref="ColumnRole.Group"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Group { get; }
// Maps from role to the associated column infos.
private readonly Dictionary<string, IReadOnlyList<DataViewSchema.Column>> _map;
/// <summary>
/// The <see cref="ColumnRole.Weight"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Weight { get; }
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, IReadOnlyList<DataViewSchema.Column>> map)
/// <summary>
/// The <see cref="ColumnRole.Name"/> column, when there is exactly one (null otherwise).
/// </summary>
public DataViewSchema.Column? Name { get; }
// Maps from role to the associated column infos.
private readonly Dictionary<string, IReadOnlyList<DataViewSchema.Column>> _map;
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, IReadOnlyList<DataViewSchema.Column>> map)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(map);
Schema = schema;
_map = map;
foreach (var kvp in _map)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(map);
Schema = schema;
_map = map;
foreach (var kvp in _map)
{
Contracts.Assert(Utils.Size(kvp.Value) > 0);
var cols = kvp.Value;
Contracts.Assert(Utils.Size(kvp.Value) > 0);
var cols = kvp.Value;
#if DEBUG
foreach (var info in cols)
Contracts.Assert(!schema[info.Index].IsHidden, "How did a hidden column sneak in?");
foreach (var info in cols)
Contracts.Assert(!schema[info.Index].IsHidden, "How did a hidden column sneak in?");
#endif
if (cols.Count == 1)
if (cols.Count == 1)
{
switch (kvp.Key)
{
switch (kvp.Key)
{
case FeatureString:
Feature = cols[0];
break;
case LabelString:
Label = cols[0];
break;
case GroupString:
Group = cols[0];
break;
case WeightString:
Weight = cols[0];
break;
case NameString:
Name = cols[0];
break;
}
case FeatureString:
Feature = cols[0];
break;
case LabelString:
Label = cols[0];
break;
case GroupString:
Group = cols[0];
break;
case WeightString:
Weight = cols[0];
break;
case NameString:
Name = cols[0];
break;
}
}
}
}
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, List<DataViewSchema.Column>> map)
: this(schema, Copy(map))
private RoleMappedSchema(DataViewSchema schema, Dictionary<string, List<DataViewSchema.Column>> map)
: this(schema, Copy(map))
{
}
private static void Add(Dictionary<string, List<DataViewSchema.Column>> map, ColumnRole role, DataViewSchema.Column column)
{
Contracts.AssertValue(map);
Contracts.AssertNonEmpty(role.Value);
if (!map.TryGetValue(role.Value, out var list))
{
list = new List<DataViewSchema.Column>();
map.Add(role.Value, list);
}
list.Add(column);
}
private static void Add(Dictionary<string, List<DataViewSchema.Column>> map, ColumnRole role, DataViewSchema.Column column)
private static Dictionary<string, List<DataViewSchema.Column>> MapFromNames(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(roles);
var map = new Dictionary<string, List<DataViewSchema.Column>>();
foreach (var kvp in roles)
{
Contracts.AssertValue(map);
Contracts.AssertNonEmpty(role.Value);
if (!map.TryGetValue(role.Value, out var list))
{
list = new List<DataViewSchema.Column>();
map.Add(role.Value, list);
}
list.Add(column);
Contracts.AssertNonEmpty(kvp.Key.Value);
if (string.IsNullOrEmpty(kvp.Value))
continue;
var info = schema.GetColumnOrNull(kvp.Value);
if (info.HasValue)
Add(map, kvp.Key.Value, info.Value);
else if (!opt)
throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found");
}
return map;
}
private static Dictionary<string, List<DataViewSchema.Column>> MapFromNames(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
/// <summary>
/// Returns whether there are any columns with the given column role.
/// </summary>
public bool Has(ColumnRole role)
=> _map.ContainsKey(role.Value);
/// <summary>
/// Returns whether there is exactly one column of the given role.
/// </summary>
public bool HasUnique(ColumnRole role)
=> _map.TryGetValue(role.Value, out var cols) && cols.Count == 1;
/// <summary>
/// Returns whether there are two or more columns of the given role.
/// </summary>
public bool HasMultiple(ColumnRole role)
=> _map.TryGetValue(role.Value, out var cols) && cols.Count > 1;
/// <summary>
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
/// it returns null.
/// </summary>
public IReadOnlyList<DataViewSchema.Column> GetColumns(ColumnRole role)
=> _map.TryGetValue(role.Value, out var list) ? list : null;
/// <summary>
/// An enumerable over all role-column associations within this object.
/// </summary>
public IEnumerable<KeyValuePair<ColumnRole, DataViewSchema.Column>> GetColumnRoles()
{
foreach (var roleAndList in _map)
{
Contracts.AssertValue(schema);
Contracts.AssertValue(roles);
var map = new Dictionary<string, List<DataViewSchema.Column>>();
foreach (var kvp in roles)
{
Contracts.AssertNonEmpty(kvp.Key.Value);
if (string.IsNullOrEmpty(kvp.Value))
continue;
var info = schema.GetColumnOrNull(kvp.Value);
if (info.HasValue)
Add(map, kvp.Key.Value, info.Value);
else if (!opt)
throw Contracts.ExceptParam(nameof(schema), $"{kvp.Value} column '{kvp.Key.Value}' not found");
}
return map;
}
/// <summary>
/// Returns whether there are any columns with the given column role.
/// </summary>
public bool Has(ColumnRole role)
=> _map.ContainsKey(role.Value);
/// <summary>
/// Returns whether there is exactly one column of the given role.
/// </summary>
public bool HasUnique(ColumnRole role)
=> _map.TryGetValue(role.Value, out var cols) && cols.Count == 1;
/// <summary>
/// Returns whether there are two or more columns of the given role.
/// </summary>
public bool HasMultiple(ColumnRole role)
=> _map.TryGetValue(role.Value, out var cols) && cols.Count > 1;
/// <summary>
/// If there are columns of the given role, this returns the infos as a readonly list. Otherwise,
/// it returns null.
/// </summary>
public IReadOnlyList<DataViewSchema.Column> GetColumns(ColumnRole role)
=> _map.TryGetValue(role.Value, out var list) ? list : null;
/// <summary>
/// An enumerable over all role-column associations within this object.
/// </summary>
public IEnumerable<KeyValuePair<ColumnRole, DataViewSchema.Column>> GetColumnRoles()
{
foreach (var roleAndList in _map)
{
foreach (var info in roleAndList.Value)
yield return new KeyValuePair<ColumnRole, DataViewSchema.Column>(roleAndList.Key, info);
}
}
/// <summary>
/// An enumerable over all role-column associations within this object.
/// </summary>
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames()
{
foreach (var roleAndList in _map)
{
foreach (var info in roleAndList.Value)
yield return new KeyValuePair<ColumnRole, string>(roleAndList.Key, info.Name);
}
}
/// <summary>
/// An enumerable over all role-column associations for the given role. This is a helper function
/// for implementing the <see cref="ISchemaBoundMapper.GetInputColumnRoles"/> method.
/// </summary>
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRole role)
{
if (_map.TryGetValue(role.Value, out var list))
{
foreach (var info in list)
yield return new KeyValuePair<ColumnRole, string>(role, info.Name);
}
}
/// <summary>
/// Returns the <see cref="DataViewSchema.Column"/> corresponding to <paramref name="role"/> if there is
/// exactly one such mapping, and otherwise throws an exception.
/// </summary>
/// <param name="role">The role to look up</param>
/// <returns>The column corresponding to that role, assuming there was only one column
/// mapped to that</returns>
public DataViewSchema.Column GetUniqueColumn(ColumnRole role)
{
var infos = GetColumns(role);
if (Utils.Size(infos) != 1)
throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos));
return infos[0];
}
private static Dictionary<string, IReadOnlyList<DataViewSchema.Column>> Copy(Dictionary<string, List<DataViewSchema.Column>> map)
{
var copy = new Dictionary<string, IReadOnlyList<DataViewSchema.Column>>(map.Count);
foreach (var kvp in map)
{
Contracts.Assert(Utils.Size(kvp.Value) > 0);
var cols = kvp.Value.ToArray();
copy.Add(kvp.Key, cols);
}
return copy;
}
/// <summary>
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
/// This skips null or empty column-names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="schema">The schema over which roles are defined</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
/// <param name="roles">The column role to column name mappings</param>
public RoleMappedSchema(DataViewSchema schema, bool opt = false, params KeyValuePair<ColumnRole, string>[] roles)
: this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt)
{
}
/// <summary>
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
/// This skips null or empty column names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="schema">The schema over which roles are defined</param>
/// <param name="roles">The column role to column name mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedSchema(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
: this(Contracts.CheckRef(schema, nameof(schema)),
MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
}
private static IEnumerable<KeyValuePair<ColumnRole, string>> PredefinedRolesHelper(
string label, string feature, string group, string weight, string name,
IEnumerable<KeyValuePair<ColumnRole, string>> custom = null)
{
if (!string.IsNullOrWhiteSpace(label))
yield return ColumnRole.Label.Bind(label);
if (!string.IsNullOrWhiteSpace(feature))
yield return ColumnRole.Feature.Bind(feature);
if (!string.IsNullOrWhiteSpace(group))
yield return ColumnRole.Group.Bind(group);
if (!string.IsNullOrWhiteSpace(weight))
yield return ColumnRole.Weight.Bind(weight);
if (!string.IsNullOrWhiteSpace(name))
yield return ColumnRole.Name.Bind(name);
if (custom != null)
{
foreach (var role in custom)
yield return role;
}
}
/// <summary>
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
/// is <c>null</c> or whitespace, it is ignored.
/// </summary>
/// <param name="schema">The schema over which roles are defined</param>
/// <param name="label">The column name that will be mapped to the <see cref="ColumnRole.Label"/> role</param>
/// <param name="feature">The column name that will be mapped to the <see cref="ColumnRole.Feature"/> role</param>
/// <param name="group">The column name that will be mapped to the <see cref="ColumnRole.Group"/> role</param>
/// <param name="weight">The column name that will be mapped to the <see cref="ColumnRole.Weight"/> role</param>
/// <param name="name">The column name that will be mapped to the <see cref="ColumnRole.Name"/> role</param>
/// <param name="custom">Any additional desired custom column role mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedSchema(DataViewSchema schema, string label, string feature,
string group = null, string weight = null, string name = null,
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
: this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt)
{
Contracts.CheckValueOrNull(label);
Contracts.CheckValueOrNull(feature);
Contracts.CheckValueOrNull(group);
Contracts.CheckValueOrNull(weight);
Contracts.CheckValueOrNull(name);
Contracts.CheckValueOrNull(custom);
foreach (var info in roleAndList.Value)
yield return new KeyValuePair<ColumnRole, DataViewSchema.Column>(roleAndList.Key, info);
}
}
/// <summary>
/// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
/// Note that the schema of <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
/// guaranteed to equal the the <see cref="IDataView.Schema"/> of <see cref="Data"/>.
/// An enumerable over all role-column associations within this object.
/// </summary>
[BestFriend]
internal sealed class RoleMappedData
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames()
{
/// <summary>
/// The data.
/// </summary>
public IDataView Data { get; }
/// <summary>
/// The role mapped schema. Note that <see cref="Schema"/>'s <see cref="RoleMappedSchema.Schema"/> is
/// guaranteed to be the same as <see cref="Data"/>'s <see cref="IDataView.Schema"/>.
/// </summary>
public RoleMappedSchema Schema { get; }
private RoleMappedData(IDataView data, RoleMappedSchema schema)
foreach (var roleAndList in _map)
{
Contracts.AssertValue(data);
Contracts.AssertValue(schema);
Contracts.Assert(schema.Schema == data.Schema);
Data = data;
Schema = schema;
}
/// <summary>
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
/// This skips null or empty column-names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="data">The data over which roles are defined</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
/// <param name="roles">The column role to column name mappings</param>
public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair<RoleMappedSchema.ColumnRole, string>[] roles)
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
}
/// <summary>
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
/// This skips null or empty column-names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="data">The schema over which roles are defined</param>
/// <param name="roles">The column role to column name mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedData(IDataView data, IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles, bool opt = false)
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
}
/// <summary>
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
/// is <c>null</c> or whitespace, it is ignored.
/// </summary>
/// <param name="data">The data over which roles are defined</param>
/// <param name="label">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Label"/> role</param>
/// <param name="feature">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Feature"/> role</param>
/// <param name="group">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Group"/> role</param>
/// <param name="weight">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Weight"/> role</param>
/// <param name="name">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Name"/> role</param>
/// <param name="custom">Any additional desired custom column role mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedData(IDataView data, string label, string feature,
string group = null, string weight = null, string name = null,
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
: this(Contracts.CheckRef(data, nameof(data)),
new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt))
{
Contracts.CheckValueOrNull(label);
Contracts.CheckValueOrNull(feature);
Contracts.CheckValueOrNull(group);
Contracts.CheckValueOrNull(weight);
Contracts.CheckValueOrNull(name);
Contracts.CheckValueOrNull(custom);
foreach (var info in roleAndList.Value)
yield return new KeyValuePair<ColumnRole, string>(roleAndList.Key, info.Name);
}
}
/// <summary>
/// An enumerable over all role-column associations for the given role. This is a helper function
/// for implementing the <see cref="ISchemaBoundMapper.GetInputColumnRoles"/> method.
/// </summary>
public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRole role)
{
if (_map.TryGetValue(role.Value, out var list))
{
foreach (var info in list)
yield return new KeyValuePair<ColumnRole, string>(role, info.Name);
}
}
/// <summary>
/// Returns the <see cref="DataViewSchema.Column"/> corresponding to <paramref name="role"/> if there is
/// exactly one such mapping, and otherwise throws an exception.
/// </summary>
/// <param name="role">The role to look up</param>
/// <returns>The column corresponding to that role, assuming there was only one column
/// mapped to that</returns>
public DataViewSchema.Column GetUniqueColumn(ColumnRole role)
{
var infos = GetColumns(role);
if (Utils.Size(infos) != 1)
throw Contracts.Except("Expected exactly one column with role '{0}', but found {1}.", role.Value, Utils.Size(infos));
return infos[0];
}
private static Dictionary<string, IReadOnlyList<DataViewSchema.Column>> Copy(Dictionary<string, List<DataViewSchema.Column>> map)
{
var copy = new Dictionary<string, IReadOnlyList<DataViewSchema.Column>>(map.Count);
foreach (var kvp in map)
{
Contracts.Assert(Utils.Size(kvp.Value) > 0);
var cols = kvp.Value.ToArray();
copy.Add(kvp.Key, cols);
}
return copy;
}
/// <summary>
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
/// This skips null or empty column-names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="schema">The schema over which roles are defined</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
/// <param name="roles">The column role to column name mappings</param>
public RoleMappedSchema(DataViewSchema schema, bool opt = false, params KeyValuePair<ColumnRole, string>[] roles)
: this(Contracts.CheckRef(schema, nameof(schema)), Contracts.CheckRef(roles, nameof(roles)), opt)
{
}
/// <summary>
/// Constructor given a schema, and mapping pairs of roles to columns in the schema.
/// This skips null or empty column names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="schema">The schema over which roles are defined</param>
/// <param name="roles">The column role to column name mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedSchema(DataViewSchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles, bool opt = false)
: this(Contracts.CheckRef(schema, nameof(schema)),
MapFromNames(schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
}
private static IEnumerable<KeyValuePair<ColumnRole, string>> PredefinedRolesHelper(
string label, string feature, string group, string weight, string name,
IEnumerable<KeyValuePair<ColumnRole, string>> custom = null)
{
if (!string.IsNullOrWhiteSpace(label))
yield return ColumnRole.Label.Bind(label);
if (!string.IsNullOrWhiteSpace(feature))
yield return ColumnRole.Feature.Bind(feature);
if (!string.IsNullOrWhiteSpace(group))
yield return ColumnRole.Group.Bind(group);
if (!string.IsNullOrWhiteSpace(weight))
yield return ColumnRole.Weight.Bind(weight);
if (!string.IsNullOrWhiteSpace(name))
yield return ColumnRole.Name.Bind(name);
if (custom != null)
{
foreach (var role in custom)
yield return role;
}
}
/// <summary>
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
/// is <c>null</c> or whitespace, it is ignored.
/// </summary>
/// <param name="schema">The schema over which roles are defined</param>
/// <param name="label">The column name that will be mapped to the <see cref="ColumnRole.Label"/> role</param>
/// <param name="feature">The column name that will be mapped to the <see cref="ColumnRole.Feature"/> role</param>
/// <param name="group">The column name that will be mapped to the <see cref="ColumnRole.Group"/> role</param>
/// <param name="weight">The column name that will be mapped to the <see cref="ColumnRole.Weight"/> role</param>
/// <param name="name">The column name that will be mapped to the <see cref="ColumnRole.Name"/> role</param>
/// <param name="custom">Any additional desired custom column role mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="schema"/> will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedSchema(DataViewSchema schema, string label, string feature,
string group = null, string weight = null, string name = null,
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
: this(Contracts.CheckRef(schema, nameof(schema)), PredefinedRolesHelper(label, feature, group, weight, name, custom), opt)
{
Contracts.CheckValueOrNull(label);
Contracts.CheckValueOrNull(feature);
Contracts.CheckValueOrNull(group);
Contracts.CheckValueOrNull(weight);
Contracts.CheckValueOrNull(name);
Contracts.CheckValueOrNull(custom);
}
}
/// <summary>
/// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
/// Note that the schema of <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
/// guaranteed to equal the the <see cref="IDataView.Schema"/> of <see cref="Data"/>.
/// </summary>
[BestFriend]
internal sealed class RoleMappedData
{
/// <summary>
/// The data.
/// </summary>
public IDataView Data { get; }
/// <summary>
/// The role mapped schema. Note that <see cref="Schema"/>'s <see cref="RoleMappedSchema.Schema"/> is
/// guaranteed to be the same as <see cref="Data"/>'s <see cref="IDataView.Schema"/>.
/// </summary>
public RoleMappedSchema Schema { get; }
private RoleMappedData(IDataView data, RoleMappedSchema schema)
{
Contracts.AssertValue(data);
Contracts.AssertValue(schema);
Contracts.Assert(schema.Schema == data.Schema);
Data = data;
Schema = schema;
}
/// <summary>
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
/// This skips null or empty column-names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="data">The data over which roles are defined</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
/// <param name="roles">The column role to column name mappings</param>
public RoleMappedData(IDataView data, bool opt = false, params KeyValuePair<RoleMappedSchema.ColumnRole, string>[] roles)
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
}
/// <summary>
/// Constructor given a data view, and mapping pairs of roles to columns in the data view's schema.
/// This skips null or empty column-names. It will also skip column-names that are not
/// found in the schema if <paramref name="opt"/> is true.
/// </summary>
/// <param name="data">The schema over which roles are defined</param>
/// <param name="roles">The column role to column name mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedData(IDataView data, IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles, bool opt = false)
: this(Contracts.CheckRef(data, nameof(data)), new RoleMappedSchema(data.Schema, Contracts.CheckRef(roles, nameof(roles)), opt))
{
}
/// <summary>
/// Convenience constructor for role-mappings over the commonly used roles. Note that if any column name specified
/// is <c>null</c> or whitespace, it is ignored.
/// </summary>
/// <param name="data">The data over which roles are defined</param>
/// <param name="label">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Label"/> role</param>
/// <param name="feature">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Feature"/> role</param>
/// <param name="group">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Group"/> role</param>
/// <param name="weight">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Weight"/> role</param>
/// <param name="name">The column name that will be mapped to the <see cref="RoleMappedSchema.ColumnRole.Name"/> role</param>
/// <param name="custom">Any additional desired custom column role mappings</param>
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedData(IDataView data, string label, string feature,
string group = null, string weight = null, string name = null,
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
: this(Contracts.CheckRef(data, nameof(data)),
new RoleMappedSchema(data.Schema, label, feature, group, weight, name, custom, opt))
{
Contracts.CheckValueOrNull(label);
Contracts.CheckValueOrNull(feature);
Contracts.CheckValueOrNull(group);
Contracts.CheckValueOrNull(weight);
Contracts.CheckValueOrNull(name);
Contracts.CheckValueOrNull(custom);
}
}

Просмотреть файл

@ -4,78 +4,77 @@
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
// REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes
// ownership of the channel so the derived classes don't have to.
/// <summary>
/// Base class for creating a cursor with default tracking of <see cref="Position"/>. All calls to <see cref="MoveNext"/>
/// will be seen by subclasses of this cursor. For a cursor that has an input cursor and does not need notification on
/// <see cref="MoveNext"/>, use <see cref="SynchronizedCursorBase"/> instead.
/// </summary>
[BestFriend]
internal abstract class RootCursorBase : DataViewRowCursor
{
// REVIEW: Since each cursor will create a channel, it would be great that the RootCursorBase takes
// ownership of the channel so the derived classes don't have to.
protected readonly IChannel Ch;
private long _position;
private bool _disposed;
/// <summary>
/// Base class for creating a cursor with default tracking of <see cref="Position"/>. All calls to <see cref="MoveNext"/>
/// will be seen by subclasses of this cursor. For a cursor that has an input cursor and does not need notification on
/// <see cref="MoveNext"/>, use <see cref="SynchronizedCursorBase"/> instead.
/// Zero-based position of the cursor.
/// </summary>
[BestFriend]
internal abstract class RootCursorBase : DataViewRowCursor
public sealed override long Position => _position;
/// <summary>
/// Convenience property for checking whether the current state of the cursor is one where data can be fetched.
/// </summary>
protected bool IsGood => _position >= 0;
/// <summary>
/// Creates an instance of the <see cref="RootCursorBase"/> class
/// </summary>
/// <param name="provider">Channel provider</param>
protected RootCursorBase(IChannelProvider provider)
{
protected readonly IChannel Ch;
private long _position;
private bool _disposed;
Contracts.CheckValue(provider, nameof(provider));
Ch = provider.Start("Cursor");
/// <summary>
/// Zero-based position of the cursor.
/// </summary>
public sealed override long Position => _position;
_position = -1;
}
/// <summary>
/// Convenience property for checking whether the current state of the cursor is one where data can be fetched.
/// </summary>
protected bool IsGood => _position >= 0;
/// <summary>
/// Creates an instance of the <see cref="RootCursorBase"/> class
/// </summary>
/// <param name="provider">Channel provider</param>
protected RootCursorBase(IChannelProvider provider)
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
{
Contracts.CheckValue(provider, nameof(provider));
Ch = provider.Start("Cursor");
Ch.Dispose();
_position = -1;
}
_disposed = true;
base.Dispose(disposing);
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
{
Ch.Dispose();
_position = -1;
}
_disposed = true;
base.Dispose(disposing);
}
public sealed override bool MoveNext()
{
if (_disposed)
return false;
if (MoveNextCore())
{
_position++;
return true;
}
Dispose();
return false;
}
/// <summary>
/// Core implementation of <see cref="MoveNext"/>, called if no prior call to this method
/// has returned <see langword="false"/>.
/// </summary>
protected abstract bool MoveNextCore();
}
public sealed override bool MoveNext()
{
if (_disposed)
return false;
if (MoveNextCore())
{
_position++;
return true;
}
Dispose();
return false;
}
/// <summary>
/// Core implementation of <see cref="MoveNext"/>, called if no prior call to this method
/// has returned <see langword="false"/>.
/// </summary>
protected abstract bool MoveNextCore();
}

Просмотреть файл

@ -4,26 +4,25 @@
using System.Collections.Generic;
namespace Microsoft.ML.Data
{
[BestFriend]
internal static class SchemaExtensions
{
public static DataViewSchema MakeSchema(IEnumerable<DataViewSchema.DetachedColumn> columns)
{
var builder = new DataViewSchema.Builder();
builder.AddColumns(columns);
return builder.ToSchema();
}
namespace Microsoft.ML.Data;
/// <summary>
/// Legacy method to get the column index.
/// DO NOT USE: use <see cref="DataViewSchema.GetColumnOrNull"/> instead.
/// </summary>
public static bool TryGetColumnIndex(this DataViewSchema schema, string name, out int col)
{
col = schema.GetColumnOrNull(name)?.Index ?? -1;
return col >= 0;
}
[BestFriend]
internal static class SchemaExtensions
{
public static DataViewSchema MakeSchema(IEnumerable<DataViewSchema.DetachedColumn> columns)
{
var builder = new DataViewSchema.Builder();
builder.AddColumns(columns);
return builder.ToSchema();
}
/// <summary>
/// Legacy method to get the column index.
/// DO NOT USE: use <see cref="DataViewSchema.GetColumnOrNull"/> instead.
/// </summary>
public static bool TryGetColumnIndex(this DataViewSchema schema, string name, out int col)
{
col = schema.GetColumnOrNull(name)?.Index ?? -1;
return col >= 0;
}
}

Просмотреть файл

@ -6,266 +6,265 @@ using System;
using System.Collections.Generic;
using Microsoft.ML.EntryPoints;
namespace Microsoft.ML.Runtime
namespace Microsoft.ML.Runtime;
/// <summary>
/// Instances of this class are used to set up a bundle of named delegates. These
/// delegates are registered through <see cref="Register{TRet}"/> and its overloads.
/// Once all registrations are done, <see cref="Publish"/> is called and a message
/// of type <see cref="Bundle"/> is sent through the input channel
/// provider. The intended use case is that any information surfaced through these
/// delegates will be published in some fashion, with the target scenario being
/// that the library will publish some sort of restful API.
/// </summary>
[BestFriend]
internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
{
// See ServerChannel.md for a more elaborate discussion of high level usage and design.
private readonly IChannelProvider _chp;
private readonly string _identifier;
// This holds the running collection of named delegates, if any. The dictionary itself
// is lazily initialized only when a listener
private Dictionary<string, Delegate> _toPublish;
private Action<Bundle> _onPublish;
private Bundle _published;
private bool _disposed;
/// <summary>
/// Instances of this class are used to set up a bundle of named delegates. These
/// delegates are registered through <see cref="Register{TRet}"/> and its overloads.
/// Once all registrations are done, <see cref="Publish"/> is called and a message
/// of type <see cref="Bundle"/> is sent through the input channel
/// provider. The intended use case is that any information surfaced through these
/// delegates will be published in some fashion, with the target scenario being
/// that the library will publish some sort of restful API.
/// Returns either this object, or <c>null</c> if there are no listeners on this server
/// channel. This can be used in conjunction with the <c>?.</c> operator to have more
/// performant though more robust calls to <see cref="Register{TRet}"/> and
/// <see cref="Publish"/>.
/// </summary>
[BestFriend]
internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
private ServerChannel ThisIfActiveOrNull => _toPublish == null ? null : this;
private ServerChannel(IChannelProvider provider, string idenfier)
{
// See ServerChannel.md for a more elaborate discussion of high level usage and design.
private readonly IChannelProvider _chp;
private readonly string _identifier;
Contracts.AssertValue(provider);
_chp = provider;
_chp.AssertNonWhiteSpace(idenfier);
_identifier = idenfier;
}
// This holds the running collection of named delegates, if any. The dictionary itself
// is lazily initialized only when a listener
private Dictionary<string, Delegate> _toPublish;
private Action<Bundle> _onPublish;
private Bundle _published;
private bool _disposed;
/// <summary>
/// Returns either this object, or <c>null</c> if there are no listeners on this server
/// channel. This can be used in conjunction with the <c>?.</c> operator to have more
/// performant though more robust calls to <see cref="Register{TRet}"/> and
/// <see cref="Publish"/>.
/// </summary>
private ServerChannel ThisIfActiveOrNull => _toPublish == null ? null : this;
private ServerChannel(IChannelProvider provider, string idenfier)
/// <summary>
/// Starts a new server channel.
/// </summary>
/// <param name="provider">The channel provider, on which to send
/// the notification that a server is being constructed</param>
/// <param name="identifier">A semi-unique identifier for this
/// "bundle" that is being constructed</param>
/// <returns>The constructed server channel, or <c>null</c> if there
/// was no listeners for server channels registered on <paramref name="provider"/></returns>
public static ServerChannel Start(IChannelProvider provider, string identifier)
{
Contracts.CheckValue(provider, nameof(provider));
provider.CheckNonWhiteSpace(identifier, nameof(identifier));
using (var pipe = provider.StartPipe<IPendingBundleNotification>("Server"))
{
Contracts.AssertValue(provider);
_chp = provider;
_chp.AssertNonWhiteSpace(idenfier);
_identifier = idenfier;
}
/// <summary>
/// Starts a new server channel.
/// </summary>
/// <param name="provider">The channel provider, on which to send
/// the notification that a server is being constructed</param>
/// <param name="identifier">A semi-unique identifier for this
/// "bundle" that is being constructed</param>
/// <returns>The constructed server channel, or <c>null</c> if there
/// was no listeners for server channels registered on <paramref name="provider"/></returns>
public static ServerChannel Start(IChannelProvider provider, string identifier)
{
Contracts.CheckValue(provider, nameof(provider));
provider.CheckNonWhiteSpace(identifier, nameof(identifier));
using (var pipe = provider.StartPipe<IPendingBundleNotification>("Server"))
{
var sc = new ServerChannel(provider, identifier);
pipe.Send(sc);
return sc.ThisIfActiveOrNull;
}
}
public void Dispose()
{
if (!_disposed)
{
_disposed = true;
_published?.Done();
}
}
private void RegisterCore(string name, Delegate func)
{
_chp.CheckNonEmpty(name, nameof(name));
_chp.CheckValue(func, nameof(func));
_chp.Check(_published == null, "Cannot expose more interfaces once a server channel has been published");
_chp.AssertValue(_toPublish);
_toPublish.Add(name, func);
}
public void Register<TRet>(string name, Func<TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
public void Register<T1, TRet>(string name, Func<T1, TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
public void Register<T1, T2, TRet>(string name, Func<T1, T2, TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
public void Register<T1, T2, T3, TRet>(string name, Func<T1, T2, T3, TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
/// <summary>
/// Finalizes all registrations of delegates, and pipes the bundle of objects
/// in a <see cref="Bundle"/> up through the pipe to be consumed by any
/// listeners.
/// </summary>
public void Publish()
{
_chp.Assert((_toPublish == null) == (_onPublish == null));
if (_toPublish == null)
return;
_chp.Check(_published == null, "Cannot republish once a server channel has been published");
_published = new Bundle(this);
_onPublish(_published);
}
public void Acknowledge(Action<Bundle> toDo)
{
_chp.CheckValue(toDo, nameof(toDo));
_chp.Assert((_onPublish == null) == (_toPublish == null));
if (_toPublish == null)
_toPublish = new Dictionary<string, Delegate>();
_onPublish += toDo;
_chp.AssertValue(_onPublish);
}
/// <summary>
/// Entry point factory for creating <see cref="IServer"/> instances.
/// </summary>
[TlcModule.ComponentKind("Server")]
public interface IServerFactory : IComponentFactory<IChannel, IServer>
{
new IServer CreateComponent(IHostEnvironment env, IChannel ch);
}
/// <summary>
/// Classes that want to publish the bundles from server channels in some fashion should implement
/// this interface. The intended simple use case is that this will be some form of in-process web
/// server, and then when disposed, they should stop themselves.
///
/// Note that the primary communication with the server from the client code's perspective is not
/// through method calls on this interface, but rather communication through an
/// <see cref="IPipe{IPendingBundleNotification}"/> that the server will listen to throughout its
/// lifetime.
/// </summary>
public interface IServer : IDisposable
{
/// <summary>
/// This should return the base address where the server is. If this server is not actually
/// serving content at any URL, this property should be null.
/// </summary>
Uri BaseAddress { get; }
}
/// <summary>
/// Creates what might be considered a good "default" server factory, if possible,
/// or <c>null</c> if no good default was possible. A <c>null</c> value could be returned,
/// for example, if a user opted to remove all implementations of <see cref="IServer"/> and
/// the associated <see cref="IServerFactory"/> for security reasons.
/// </summary>
public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env)
{
Contracts.CheckValue(env, nameof(env));
// REVIEW: There should be a better way. There currently isn't,
// but there should be. This is pretty horrifying, but it is preferable to
// the alternative of having core components depend on an actual server
// implementation, since we want those to be removable because of security
// concerns in certain environments (since not everyone will be wild about
// web servers popping up everywhere).
ComponentCatalog.ComponentInfo component;
if (!env.ComponentCatalog.TryFindComponent(typeof(IServerFactory), "mini", out component))
return null;
IServerFactory factory = (IServerFactory)Activator.CreateInstance(component.ArgumentType);
var field = factory.GetType().GetField("Port");
if (field?.FieldType != typeof(int))
return null;
field.SetValue(factory, 12345);
return factory;
}
/// <summary>
/// When a <see cref="ServerChannel"/> is created, the creation method will send an implementation
/// is a notification sent through an <see cref="IPipe{IPendingBundleNotification}"/>, to indicate that
/// a <see cref="Bundle"/> may be pending soon. Listeners that want to receive the bundle to
/// expose it, for example, a web service, should register this interest by passing in an action to be called.
/// If no listener registers interest, the server channel that sent the notification will act
/// differently by, say, acting as a no-op w.r.t. client calls to it.
/// </summary>
public interface IPendingBundleNotification
{
/// <summary>
/// Any publisher of the named delegates will call this method, upon receiving an instance
/// of this object through the pipe. This method serves two purposes: firstly it detects
/// whether anyone is even interested in publishing anything at all, so that we can just
/// ignore any input delegates in the case where no one is listening (which, we must expect,
/// is the majority of scenarios). The second is that it provides an action to call, once
/// all publishing is complete, and <see cref="Publish"/> has been called by the client code.
/// </summary>
/// <param name="toDo">The callback to perform when all named delegates have been registered,
/// and <see cref="Publish"/> is called.</param>
void Acknowledge(Action<Bundle> toDo);
}
/// <summary>
/// The final bundle of published named delegates that a listener can serve.
/// </summary>
public sealed class Bundle
{
/// <summary>
/// This contains a name to delegate mappings. The delegates contained herein are gauranteed to be
/// some variety of <see cref="Func{TResult}"/>, <see cref="Func{T1, TResult}"/>,
/// <see cref="Func{T1, T2, TResult}"/>, etc.
/// </summary>
public readonly IReadOnlyDictionary<string, Delegate> NameToFuncs;
/// <summary>
/// This should be a more-or-less unique identifier for the type of API this bundle is producing.
/// Its intended use is that it will form part of the URL for the RESTful API, so to the extent that
/// it contains multiple tokens they must be slash delimited.
/// </summary>
public readonly string Identifier;
internal Action Done;
internal Bundle(ServerChannel sch)
{
Contracts.AssertValue(sch);
NameToFuncs = sch._toPublish;
Identifier = sch._identifier;
}
public void AddDoneAction(Action onDone)
{
Done += onDone;
}
var sc = new ServerChannel(provider, identifier);
pipe.Send(sc);
return sc.ThisIfActiveOrNull;
}
}
[BestFriend]
internal static class ServerChannelUtilities
public void Dispose()
{
if (!_disposed)
{
_disposed = true;
_published?.Done();
}
}
private void RegisterCore(string name, Delegate func)
{
_chp.CheckNonEmpty(name, nameof(name));
_chp.CheckValue(func, nameof(func));
_chp.Check(_published == null, "Cannot expose more interfaces once a server channel has been published");
_chp.AssertValue(_toPublish);
_toPublish.Add(name, func);
}
public void Register<TRet>(string name, Func<TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
public void Register<T1, TRet>(string name, Func<T1, TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
public void Register<T1, T2, TRet>(string name, Func<T1, T2, TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
public void Register<T1, T2, T3, TRet>(string name, Func<T1, T2, T3, TRet> func)
{
if (_toPublish != null)
RegisterCore(name, func);
}
/// <summary>
/// Finalizes all registrations of delegates, and pipes the bundle of objects
/// in a <see cref="Bundle"/> up through the pipe to be consumed by any
/// listeners.
/// </summary>
public void Publish()
{
_chp.Assert((_toPublish == null) == (_onPublish == null));
if (_toPublish == null)
return;
_chp.Check(_published == null, "Cannot republish once a server channel has been published");
_published = new Bundle(this);
_onPublish(_published);
}
public void Acknowledge(Action<Bundle> toDo)
{
_chp.CheckValue(toDo, nameof(toDo));
_chp.Assert((_onPublish == null) == (_toPublish == null));
if (_toPublish == null)
_toPublish = new Dictionary<string, Delegate>();
_onPublish += toDo;
_chp.AssertValue(_onPublish);
}
/// <summary>
/// Entry point factory for creating <see cref="IServer"/> instances.
/// </summary>
[TlcModule.ComponentKind("Server")]
public interface IServerFactory : IComponentFactory<IChannel, IServer>
{
new IServer CreateComponent(IHostEnvironment env, IChannel ch);
}
/// <summary>
/// Classes that want to publish the bundles from server channels in some fashion should implement
/// this interface. The intended simple use case is that this will be some form of in-process web
/// server, and then when disposed, they should stop themselves.
///
/// Note that the primary communication with the server from the client code's perspective is not
/// through method calls on this interface, but rather communication through an
/// <see cref="IPipe{IPendingBundleNotification}"/> that the server will listen to throughout its
/// lifetime.
/// </summary>
public interface IServer : IDisposable
{
/// <summary>
/// Convenience method for <see cref="ServerChannel.Start"/> that looks more idiomatic to typical
/// channel creation methods on <see cref="IChannelProvider"/>.
/// This should return the base address where the server is. If this server is not actually
/// serving content at any URL, this property should be null.
/// </summary>
/// <param name="provider">The channel provider.</param>
/// <param name="identifier">This is an identifier of the "type" of bundle that is being published,
/// and should form a path with forward-slash '/' delimiters.</param>
/// <returns>The newly created server channel, or <c>null</c> if there was no listener for
/// server channels on <paramref name="provider"/>.</returns>
public static ServerChannel StartServerChannel(this IChannelProvider provider, string identifier)
Uri BaseAddress { get; }
}
/// <summary>
/// Creates what might be considered a good "default" server factory, if possible,
/// or <c>null</c> if no good default was possible. A <c>null</c> value could be returned,
/// for example, if a user opted to remove all implementations of <see cref="IServer"/> and
/// the associated <see cref="IServerFactory"/> for security reasons.
/// </summary>
public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env)
{
Contracts.CheckValue(env, nameof(env));
// REVIEW: There should be a better way. There currently isn't,
// but there should be. This is pretty horrifying, but it is preferable to
// the alternative of having core components depend on an actual server
// implementation, since we want those to be removable because of security
// concerns in certain environments (since not everyone will be wild about
// web servers popping up everywhere).
ComponentCatalog.ComponentInfo component;
if (!env.ComponentCatalog.TryFindComponent(typeof(IServerFactory), "mini", out component))
return null;
IServerFactory factory = (IServerFactory)Activator.CreateInstance(component.ArgumentType);
var field = factory.GetType().GetField("Port");
if (field?.FieldType != typeof(int))
return null;
field.SetValue(factory, 12345);
return factory;
}
/// <summary>
/// When a <see cref="ServerChannel"/> is created, the creation method will send an implementation
/// is a notification sent through an <see cref="IPipe{IPendingBundleNotification}"/>, to indicate that
/// a <see cref="Bundle"/> may be pending soon. Listeners that want to receive the bundle to
/// expose it, for example, a web service, should register this interest by passing in an action to be called.
/// If no listener registers interest, the server channel that sent the notification will act
/// differently by, say, acting as a no-op w.r.t. client calls to it.
/// </summary>
public interface IPendingBundleNotification
{
/// <summary>
/// Any publisher of the named delegates will call this method, upon receiving an instance
/// of this object through the pipe. This method serves two purposes: firstly it detects
/// whether anyone is even interested in publishing anything at all, so that we can just
/// ignore any input delegates in the case where no one is listening (which, we must expect,
/// is the majority of scenarios). The second is that it provides an action to call, once
/// all publishing is complete, and <see cref="Publish"/> has been called by the client code.
/// </summary>
/// <param name="toDo">The callback to perform when all named delegates have been registered,
/// and <see cref="Publish"/> is called.</param>
void Acknowledge(Action<Bundle> toDo);
}
/// <summary>
/// The final bundle of published named delegates that a listener can serve.
/// </summary>
public sealed class Bundle
{
/// <summary>
/// This contains a name to delegate mappings. The delegates contained herein are gauranteed to be
/// some variety of <see cref="Func{TResult}"/>, <see cref="Func{T1, TResult}"/>,
/// <see cref="Func{T1, T2, TResult}"/>, etc.
/// </summary>
public readonly IReadOnlyDictionary<string, Delegate> NameToFuncs;
/// <summary>
/// This should be a more-or-less unique identifier for the type of API this bundle is producing.
/// Its intended use is that it will form part of the URL for the RESTful API, so to the extent that
/// it contains multiple tokens they must be slash delimited.
/// </summary>
public readonly string Identifier;
internal Action Done;
internal Bundle(ServerChannel sch)
{
Contracts.CheckValue(provider, nameof(provider));
Contracts.CheckNonWhiteSpace(identifier, nameof(identifier));
return ServerChannel.Start(provider, identifier);
Contracts.AssertValue(sch);
NameToFuncs = sch._toPublish;
Identifier = sch._identifier;
}
public void AddDoneAction(Action onDone)
{
Done += onDone;
}
}
}
[BestFriend]
internal static class ServerChannelUtilities
{
/// <summary>
/// Convenience method for <see cref="ServerChannel.Start"/> that looks more idiomatic to typical
/// channel creation methods on <see cref="IChannelProvider"/>.
/// </summary>
/// <param name="provider">The channel provider.</param>
/// <param name="identifier">This is an identifier of the "type" of bundle that is being published,
/// and should form a path with forward-slash '/' delimiters.</param>
/// <returns>The newly created server channel, or <c>null</c> if there was no listener for
/// server channels on <paramref name="provider"/>.</returns>
public static ServerChannel StartServerChannel(this IChannelProvider provider, string identifier)
{
Contracts.CheckValue(provider, nameof(provider));
Contracts.CheckNonWhiteSpace(identifier, nameof(identifier));
return ServerChannel.Start(provider, identifier);
}
}

Просмотреть файл

@ -4,69 +4,68 @@
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Base class for creating a cursor on top of another cursor that does not add or remove rows.
/// It forces one-to-one correspondence between items in the input cursor and this cursor.
/// It delegates all <see cref="DataViewRowCursor"/> functionality except Dispose() to the root cursor.
/// Dispose is virtual with the default implementation delegating to the input cursor.
/// </summary>
[BestFriend]
internal abstract class SynchronizedCursorBase : DataViewRowCursor
{
protected readonly IChannel Ch;
/// <summary>
/// Base class for creating a cursor on top of another cursor that does not add or remove rows.
/// It forces one-to-one correspondence between items in the input cursor and this cursor.
/// It delegates all <see cref="DataViewRowCursor"/> functionality except Dispose() to the root cursor.
/// Dispose is virtual with the default implementation delegating to the input cursor.
/// The synchronized cursor base, as it merely passes through requests for all "positional" calls (including
/// <see cref="MoveNext"/>, <see cref="Position"/>, <see cref="Batch"/>, and so forth), offers an opportunity
/// for optimization for "wrapping" cursors (which are themselves often <see cref="SynchronizedCursorBase"/>
/// implementors) to get this root cursor. But, this can only be done by exposing this root cursor, as we do here.
/// Internal code should be quite careful in using this as the potential for misuse is quite high.
/// </summary>
[BestFriend]
internal abstract class SynchronizedCursorBase : DataViewRowCursor
internal readonly DataViewRowCursor Root;
private bool _disposed;
protected DataViewRowCursor Input { get; }
public sealed override long Position => Root.Position;
public sealed override long Batch => Root.Batch;
/// <summary>
/// Convenience property for checking whether the cursor is in a good state where values
/// can be retrieved, that is, whenever <see cref="Position"/> is non-negative.
/// </summary>
protected bool IsGood => Position >= 0;
protected SynchronizedCursorBase(IChannelProvider provider, DataViewRowCursor input)
{
protected readonly IChannel Ch;
Contracts.AssertValue(provider);
Ch = provider.Start("Cursor");
/// <summary>
/// The synchronized cursor base, as it merely passes through requests for all "positional" calls (including
/// <see cref="MoveNext"/>, <see cref="Position"/>, <see cref="Batch"/>, and so forth), offers an opportunity
/// for optimization for "wrapping" cursors (which are themselves often <see cref="SynchronizedCursorBase"/>
/// implementors) to get this root cursor. But, this can only be done by exposing this root cursor, as we do here.
/// Internal code should be quite careful in using this as the potential for misuse is quite high.
/// </summary>
internal readonly DataViewRowCursor Root;
private bool _disposed;
protected DataViewRowCursor Input { get; }
public sealed override long Position => Root.Position;
public sealed override long Batch => Root.Batch;
/// <summary>
/// Convenience property for checking whether the cursor is in a good state where values
/// can be retrieved, that is, whenever <see cref="Position"/> is non-negative.
/// </summary>
protected bool IsGood => Position >= 0;
protected SynchronizedCursorBase(IChannelProvider provider, DataViewRowCursor input)
{
Contracts.AssertValue(provider);
Ch = provider.Start("Cursor");
Ch.AssertValue(input);
Input = input;
// If this thing happens to be itself an instance of this class (which, practically, it will
// be in the majority of situations), we can treat the input as likewise being a passthrough,
// thereby saving lots of "nested" calls on the stack when doing common operations like movement.
Root = Input is SynchronizedCursorBase syncInput ? syncInput.Root : input;
}
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
{
Input.Dispose();
Ch.Dispose();
}
base.Dispose(disposing);
_disposed = true;
}
public sealed override bool MoveNext() => Root.MoveNext();
public sealed override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
Ch.AssertValue(input);
Input = input;
// If this thing happens to be itself an instance of this class (which, practically, it will
// be in the majority of situations), we can treat the input as likewise being a passthrough,
// thereby saving lots of "nested" calls on the stack when doing common operations like movement.
Root = Input is SynchronizedCursorBase syncInput ? syncInput.Root : input;
}
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
{
Input.Dispose();
Ch.Dispose();
}
base.Dispose(disposing);
_disposed = true;
}
public sealed override bool MoveNext() => Root.MoveNext();
public sealed override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
}

Просмотреть файл

@ -5,61 +5,60 @@
using System;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
namespace Microsoft.ML.Data;
/// <summary>
/// Convenient base class for <see cref="DataViewRow"/> implementors that wrap a single <see cref="DataViewRow"/>
/// as their input. The <see cref="DataViewRow.Position"/>, <see cref="DataViewRow.Batch"/>, and <see cref="DataViewRow.GetIdGetter"/>
/// are taken from this <see cref="Input"/>.
/// </summary>
[BestFriend]
internal abstract class WrappingRow : DataViewRow
{
private bool _disposed;
/// <summary>
/// Convenient base class for <see cref="DataViewRow"/> implementors that wrap a single <see cref="DataViewRow"/>
/// as their input. The <see cref="DataViewRow.Position"/>, <see cref="DataViewRow.Batch"/>, and <see cref="DataViewRow.GetIdGetter"/>
/// are taken from this <see cref="Input"/>.
/// The wrapped input row.
/// </summary>
protected DataViewRow Input { get; }
public sealed override long Batch => Input.Batch;
public sealed override long Position => Input.Position;
public override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
[BestFriend]
internal abstract class WrappingRow : DataViewRow
private protected WrappingRow(DataViewRow input)
{
private bool _disposed;
Contracts.AssertValue(input);
Input = input;
}
/// <summary>
/// The wrapped input row.
/// </summary>
protected DataViewRow Input { get; }
/// <summary>
/// This override of the dispose method by default only calls <see cref="Input"/>'s
/// <see cref="IDisposable.Dispose"/> method, but subclasses can enable additional functionality
/// via the <see cref="DisposeCore(bool)"/> functionality.
/// </summary>
/// <param name="disposing"></param>
protected sealed override void Dispose(bool disposing)
{
if (_disposed)
return;
// Since the input was created first, and this instance may depend on it, we should
// dispose local resources first before potentially disposing the input row resources.
DisposeCore(disposing);
if (disposing)
Input.Dispose();
_disposed = true;
}
public sealed override long Batch => Input.Batch;
public sealed override long Position => Input.Position;
public override ValueGetter<DataViewRowId> GetIdGetter() => Input.GetIdGetter();
[BestFriend]
private protected WrappingRow(DataViewRow input)
{
Contracts.AssertValue(input);
Input = input;
}
/// <summary>
/// This override of the dispose method by default only calls <see cref="Input"/>'s
/// <see cref="IDisposable.Dispose"/> method, but subclasses can enable additional functionality
/// via the <see cref="DisposeCore(bool)"/> functionality.
/// </summary>
/// <param name="disposing"></param>
protected sealed override void Dispose(bool disposing)
{
if (_disposed)
return;
// Since the input was created first, and this instance may depend on it, we should
// dispose local resources first before potentially disposing the input row resources.
DisposeCore(disposing);
if (disposing)
Input.Dispose();
_disposed = true;
}
/// <summary>
/// Called from <see cref="Dispose(bool)"/> with <see langword="true"/> in the case where
/// that method has never been called before, and right after <see cref="Input"/> has been
/// disposed. The default implementation does nothing.
/// </summary>
/// <param name="disposing">Whether this was called through the dispose path, as opposed
/// to the finalizer path.</param>
protected virtual void DisposeCore(bool disposing)
{
}
/// <summary>
/// Called from <see cref="Dispose(bool)"/> with <see langword="true"/> in the case where
/// that method has never been called before, and right after <see cref="Input"/> has been
/// disposed. The default implementation does nothing.
/// </summary>
/// <param name="disposing">Whether this was called through the dispose path, as opposed
/// to the finalizer path.</param>
protected virtual void DisposeCore(bool disposing)
{
}
}