428 строки
21 KiB
C#
428 строки
21 KiB
C#
// Licensed to the .NET Foundation under one or more agreements.
|
|
// The .NET Foundation licenses this file to you under the MIT license.
|
|
// See the LICENSE file in the project root for more information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using System.Linq;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.ML.Data;
|
|
using Microsoft.ML.TestFramework;
|
|
using Xunit;
|
|
using Xunit.Abstractions;
|
|
|
|
namespace Microsoft.ML.AutoML.Test
|
|
{
|
|
|
|
public class UserInputValidationTests : BaseTestClass
|
|
{
|
|
private static readonly IDataView _data = DatasetUtil.GetUciAdultDataView();
|
|
|
|
public UserInputValidationTests(ITestOutputHelper output) : base(output)
|
|
{
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteNullTrainData()
|
|
{
|
|
var ex = Assert.Throws<ArgumentNullException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null, TaskKind.Regression));
|
|
Assert.StartsWith("Training data cannot be null", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteNullLabel()
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
|
|
new ColumnInformation() { LabelColumnName = null }, null, TaskKind.Regression));
|
|
|
|
Assert.Equal("Provided label column cannot be null", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteLabelNotInTrain()
|
|
{
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
const string columnName = "ReallyLongNonExistingColumnName";
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
|
|
new ColumnInformation() { LabelColumnName = columnName }, null, task));
|
|
|
|
Assert.Equal($"Provided label column '{columnName}' not found in training data.", ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteLabelNotInTrainMistyped()
|
|
{
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var originalColumnName = _data.Schema.First().Name;
|
|
var mistypedColumnName = originalColumnName + "a";
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
|
|
new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task));
|
|
|
|
Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data. Did you mean '{originalColumnName}'.",
|
|
ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteNumericColNotInTrain()
|
|
{
|
|
var columnInfo = new ColumnInformation();
|
|
columnInfo.NumericColumnNames.Add("N");
|
|
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, columnInfo, null, task));
|
|
Assert.Equal("Provided label column 'Label' was of type Boolean, but only type Single is allowed.", ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteNullNumericCol()
|
|
{
|
|
var columnInfo = new ColumnInformation();
|
|
columnInfo.NumericColumnNames.Add(null);
|
|
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, columnInfo, null, TaskKind.Regression));
|
|
Assert.Equal("Null column string was specified as numeric in column information", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteDuplicateCol()
|
|
{
|
|
var columnInfo = new ColumnInformation();
|
|
columnInfo.NumericColumnNames.Add(DefaultColumnNames.Label);
|
|
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, columnInfo, null, TaskKind.Regression));
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteArgsTrainValidColCountMismatch()
|
|
{
|
|
var context = new MLContext(1);
|
|
|
|
var trainDataBuilder = new ArrayDataViewBuilder(context);
|
|
trainDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
|
|
trainDataBuilder.AddColumn("1", new string[] { "1" });
|
|
var trainData = trainDataBuilder.GetDataView();
|
|
|
|
var validDataBuilder = new ArrayDataViewBuilder(context);
|
|
validDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
|
|
var validData = validDataBuilder.GetDataView();
|
|
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
|
|
new ColumnInformation() { LabelColumnName = "0" }, validData, task));
|
|
Assert.StartsWith("Training data and validation data schemas do not match. Train data has '2' columns,and validation data has '1' columns.", ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteArgsTrainValidColNamesMismatch()
|
|
{
|
|
var context = new MLContext(1);
|
|
|
|
var trainDataBuilder = new ArrayDataViewBuilder(context);
|
|
trainDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
|
|
trainDataBuilder.AddColumn("1", new string[] { "1" });
|
|
var trainData = trainDataBuilder.GetDataView();
|
|
|
|
var validDataBuilder = new ArrayDataViewBuilder(context);
|
|
validDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
|
|
validDataBuilder.AddColumn("2", new string[] { "2" });
|
|
var validData = validDataBuilder.GetDataView();
|
|
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
|
|
new ColumnInformation() { LabelColumnName = "0" }, validData, task));
|
|
Assert.StartsWith("Training data and validation data schemas do not match. Column '1' exists in train data, but not in validation data.", ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateExperimentExecuteArgsTrainValidColTypeMismatch()
|
|
{
|
|
var context = new MLContext(1);
|
|
|
|
var trainDataBuilder = new ArrayDataViewBuilder(context);
|
|
trainDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
|
|
trainDataBuilder.AddColumn("1", new string[] { "1" });
|
|
var trainData = trainDataBuilder.GetDataView();
|
|
|
|
var validDataBuilder = new ArrayDataViewBuilder(context);
|
|
validDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
|
|
validDataBuilder.AddColumn("1", NumberDataViewType.Single, new float[] { 1 });
|
|
var validData = validDataBuilder.GetDataView();
|
|
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
|
|
new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression));
|
|
Assert.StartsWith("Training data and validation data schemas do not match. Column '1' is of type String in train data, and type Single in validation data.", ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateInferColumnsArgsNullPath()
|
|
{
|
|
var ex = Assert.Throws<ArgumentNullException>(() => UserInputValidationUtil.ValidateInferColumnsArgs(null, "Label"));
|
|
Assert.StartsWith("Provided path cannot be null", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateInferColumnsArgsPathNotExist()
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateInferColumnsArgs("idontexist", "Label"));
|
|
Assert.StartsWith("File 'idontexist' does not exist", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateInferColumnsArgsEmptyFile()
|
|
{
|
|
const string emptyFilePath = "empty";
|
|
File.Create(emptyFilePath).Dispose();
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateInferColumnsArgs(emptyFilePath, "Label"));
|
|
Assert.StartsWith("File at path 'empty' cannot be empty", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateInferColsPath()
|
|
{
|
|
UserInputValidationUtil.ValidateInferColumnsArgs(DatasetUtil.GetUciAdultDataset());
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateFeaturesColInvalidType()
|
|
{
|
|
var schemaBuilder = new DataViewSchema.Builder();
|
|
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Double);
|
|
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
|
|
var schema = schemaBuilder.ToSchema();
|
|
var dataView = DataViewTestFixture.BuildDummyDataView(schema);
|
|
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, TaskKind.Regression));
|
|
Assert.StartsWith("Features column must be of data type Single", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateTextColumnNotText()
|
|
{
|
|
const string textPurposeColName = "TextColumn";
|
|
var schemaBuilder = new DataViewSchema.Builder();
|
|
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
|
|
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
|
|
schemaBuilder.AddColumn(textPurposeColName, NumberDataViewType.Single);
|
|
var schema = schemaBuilder.ToSchema();
|
|
var dataView = DataViewTestFixture.BuildDummyDataView(schema);
|
|
|
|
var columnInfo = new ColumnInformation();
|
|
columnInfo.TextColumnNames.Add(textPurposeColName);
|
|
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null, task));
|
|
Assert.Equal("Provided text column 'TextColumn' was of type Single, but only type String is allowed.", ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateRegressionLabelTypes()
|
|
{
|
|
ValidateLabelTypeTestCore<float>(TaskKind.Regression, NumberDataViewType.Single, true);
|
|
ValidateLabelTypeTestCore<bool>(TaskKind.Regression, BooleanDataViewType.Instance, false);
|
|
ValidateLabelTypeTestCore<double>(TaskKind.Regression, NumberDataViewType.Double, false);
|
|
ValidateLabelTypeTestCore<string>(TaskKind.Regression, TextDataViewType.Instance, false);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateRecommendationLabelTypes()
|
|
{
|
|
ValidateLabelTypeTestCore<float>(TaskKind.Recommendation, NumberDataViewType.Single, true);
|
|
ValidateLabelTypeTestCore<bool>(TaskKind.Recommendation, BooleanDataViewType.Instance, false);
|
|
ValidateLabelTypeTestCore<double>(TaskKind.Recommendation, NumberDataViewType.Double, false);
|
|
ValidateLabelTypeTestCore<string>(TaskKind.Recommendation, TextDataViewType.Instance, false);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateBinaryClassificationLabelTypes()
|
|
{
|
|
ValidateLabelTypeTestCore<float>(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
|
|
ValidateLabelTypeTestCore<bool>(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateMulticlassLabelTypes()
|
|
{
|
|
ValidateLabelTypeTestCore<float>(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
|
|
ValidateLabelTypeTestCore<bool>(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
|
|
ValidateLabelTypeTestCore<double>(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
|
|
ValidateLabelTypeTestCore<string>(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateRankingLabelTypes()
|
|
{
|
|
ValidateLabelTypeTestCore<float>(TaskKind.Ranking, NumberDataViewType.Single, true);
|
|
ValidateLabelTypeTestCore<bool>(TaskKind.Ranking, BooleanDataViewType.Instance, false);
|
|
ValidateLabelTypeTestCore<double>(TaskKind.Ranking, NumberDataViewType.Double, false);
|
|
ValidateLabelTypeTestCore<string>(TaskKind.Ranking, TextDataViewType.Instance, false);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateAllowedFeatureColumnTypes()
|
|
{
|
|
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext(1));
|
|
dataViewBuilder.AddColumn("Boolean", BooleanDataViewType.Instance, false);
|
|
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
|
|
dataViewBuilder.AddColumn("Text", "a");
|
|
dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
|
|
var dataView = dataViewBuilder.GetDataView();
|
|
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
|
|
null, task);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateProhibitedFeatureColumnType()
|
|
{
|
|
var schemaBuilder = new DataViewSchema.Builder();
|
|
schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
|
|
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
|
|
var schema = schemaBuilder.ToSchema();
|
|
var dataView = DataViewTestFixture.BuildDummyDataView(schema);
|
|
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
|
|
null, TaskKind.Regression));
|
|
Assert.StartsWith("Only supported feature column types are Boolean, Single, and String. Please change the feature column UInt64 of type UInt64 to one of the supported types.", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateEmptyTrainingDataThrows()
|
|
{
|
|
var schemaBuilder = new DataViewSchema.Builder();
|
|
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
|
|
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
|
|
var schema = schemaBuilder.ToSchema();
|
|
var dataView = DataViewTestFixture.BuildDummyDataView(schema, createDummyRow: false);
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
|
|
null, TaskKind.Regression));
|
|
Assert.StartsWith("Training data has 0 rows", ex.Message);
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateEmptyValidationDataThrows()
|
|
{
|
|
// Training data
|
|
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext(1));
|
|
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
|
|
dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
|
|
var trainingData = dataViewBuilder.GetDataView();
|
|
|
|
// Validation data
|
|
var schemaBuilder = new DataViewSchema.Builder();
|
|
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
|
|
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
|
|
var schema = schemaBuilder.ToSchema();
|
|
var validationData = DataViewTestFixture.BuildDummyDataView(schema, createDummyRow: false);
|
|
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainingData, new ColumnInformation(),
|
|
validationData, task));
|
|
Assert.StartsWith("Validation data has 0 rows", ex.Message);
|
|
}
|
|
}
|
|
|
|
|
|
[Fact]
|
|
public void TestValidationDataSchemaChecksIgnoreHiddenColumns()
|
|
{
|
|
var mlContext = new MLContext(1);
|
|
|
|
// Build training data where label column is a float.
|
|
var trainDataBuilder = new ArrayDataViewBuilder(mlContext);
|
|
trainDataBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
|
|
trainDataBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
|
|
var trainingData = trainDataBuilder.GetDataView();
|
|
|
|
// In the training data, transform the label column from a float to a Boolean. This has the effect of
|
|
// creating a hidden column named 'Label' of type float and an additional column named 'Label' of type Boolean.
|
|
var convertLabelToBoolEstimator = mlContext.Transforms.Conversion.MapValue(DefaultColumnNames.Label,
|
|
new List<KeyValuePair<float, bool>>() { new KeyValuePair<float, bool>(1, true) });
|
|
trainingData = convertLabelToBoolEstimator.Fit(trainingData).Transform(trainingData);
|
|
|
|
// Build validation data where label column is a Boolean.
|
|
var validationDataBuilder = new ArrayDataViewBuilder(mlContext);
|
|
validationDataBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
|
|
validationDataBuilder.AddColumn(DefaultColumnNames.Label, BooleanDataViewType.Instance, false);
|
|
var validationData = validationDataBuilder.GetDataView();
|
|
|
|
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainingData, new ColumnInformation(), validationData, TaskKind.BinaryClassification);
|
|
}
|
|
|
|
|
|
private static void ValidateLabelTypeTestCore<TLabelRawType>(TaskKind task, PrimitiveDataViewType labelType, bool labelTypeShouldBeValid)
|
|
{
|
|
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext(1));
|
|
dataViewBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, 0f);
|
|
if (labelType == TextDataViewType.Instance)
|
|
{
|
|
dataViewBuilder.AddColumn(DefaultColumnNames.Label, string.Empty);
|
|
}
|
|
else
|
|
{
|
|
dataViewBuilder.AddColumn(DefaultColumnNames.Label, labelType, Activator.CreateInstance<TLabelRawType>());
|
|
}
|
|
var dataView = dataViewBuilder.GetDataView();
|
|
var validationExceptionThrown = false;
|
|
try
|
|
{
|
|
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, task);
|
|
}
|
|
catch
|
|
{
|
|
validationExceptionThrown = true;
|
|
}
|
|
Assert.Equal(labelTypeShouldBeValid, !validationExceptionThrown);
|
|
}
|
|
[Fact]
|
|
public void ValidateTrainDataColumnTestMultipleMismatchLessThan5()
|
|
{
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var originalColumnName = _data.Schema.First().Name;
|
|
var mistypedColumnName = "a" + originalColumnName + "b";
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
|
|
new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task));
|
|
|
|
Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data. Did you mean '{originalColumnName}'.",
|
|
ex.Message);
|
|
}
|
|
}
|
|
|
|
[Fact]
|
|
public void ValidateTrainDataColumnTestMultipleMismatchMoreThan5()
|
|
{
|
|
foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
|
|
{
|
|
var originalColumnName = _data.Schema.First().Name;
|
|
var mistypedColumnName = "a" + originalColumnName + "bcdvfnfmsm";
|
|
var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
|
|
new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task));
|
|
|
|
Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data.",
|
|
ex.Message);
|
|
}
|
|
}
|
|
|
|
}
|
|
}
|