// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using System.Linq;
|
|
using ApprovalTests;
|
|
using ApprovalTests.Namers;
|
|
using ApprovalTests.Reporters;
|
|
using FluentAssertions;
|
|
using Microsoft.ML.Data;
|
|
using Microsoft.ML.TestFramework;
|
|
using Newtonsoft.Json;
|
|
using Xunit;
|
|
using Xunit.Abstractions;
|
|
|
|
namespace Microsoft.ML.AutoML.Test
{
    /// <summary>
    /// Tests for AutoML column inference: the <c>InferColumns</c> entry points reached through
    /// <c>MLContext.Auto()</c> and <c>ColumnInferenceApi</c>, text-file column splitting, and
    /// JSON round-tripping of <c>ColumnInferenceResults</c>.
    /// </summary>
    public class ColumnInferenceTests : BaseTestClass
    {
        /// <summary>
        /// Seed passed to every <see cref="MLContext"/> created by these tests so that all of
        /// them run against an identically configured context.
        /// </summary>
        private const int Seed = 1;

        public ColumnInferenceTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// With <c>groupColumns: false</c> every inferred column must map to exactly one source
        /// slot, and the ungrouped result must contain more columns than the grouped one.
        /// </summary>
        [Fact]
        public void UnGroupReturnsMoreColumnsThanGroup()
        {
            var dataPath = DatasetUtil.GetUciAdultDataset();
            var context = new MLContext(Seed);

            var columnInferenceWithoutGrouping = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel, groupColumns: false);
            foreach (var col in columnInferenceWithoutGrouping.TextLoaderOptions.Columns)
            {
                // Exactly one source range, and that range spans a single slot.
                // Assert.Single also reports an empty Source as a test failure instead of
                // letting an index-out-of-range exception escape.
                var range = Assert.Single(col.Source);
                Assert.True(range.Min == range.Max);
            }

            var columnInferenceWithGrouping = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel, groupColumns: true);
            Assert.True(columnInferenceWithGrouping.TextLoaderOptions.Columns.Count() < columnInferenceWithoutGrouping.TextLoaderOptions.Columns.Count());
        }

        /// <summary>
        /// Asking for a label column name that is not "Junk"-free, i.e. the literal name
        /// <c>"Junk"</c>, must raise <see cref="ArgumentException"/>.
        /// </summary>
        [Fact]
        public void IncorrectLabelColumnThrows()
        {
            var dataPath = DatasetUtil.GetUciAdultDataset();
            var context = new MLContext(Seed);

            Assert.Throws<ArgumentException>(() => context.Auto().InferColumns(dataPath, "Junk", groupColumns: false));
        }

        /// <summary>
        /// Passing label column index 100 for the UCI adult dataset must raise
        /// <see cref="ArgumentOutOfRangeException"/>.
        /// </summary>
        [Fact]
        public void LabelIndexOutOfBoundsThrows()
        {
            Assert.Throws<ArgumentOutOfRangeException>(() => new MLContext(Seed).Auto().InferColumns(DatasetUtil.GetUciAdultDataset(), 100));
        }

        /// <summary>
        /// When the label is selected by index and <c>hasHeader: true</c> is passed, the label
        /// column is expected to be named "hours-per-week".
        /// </summary>
        [Fact]
        public void IdentifyLabelColumnThroughIndexWithHeader()
        {
            const int LabelIndex = 14;
            const string ExpectedLabelName = "hours-per-week";

            var result = new MLContext(Seed).Auto().InferColumns(DatasetUtil.GetUciAdultDataset(), LabelIndex, hasHeader: true);

            Assert.True(result.TextLoaderOptions.HasHeader);
            var labelCol = result.TextLoaderOptions.Columns.First(c => c.Source[0].Min == LabelIndex && c.Source[0].Max == LabelIndex);
            Assert.Equal(ExpectedLabelName, labelCol.Name);
            Assert.Equal(ExpectedLabelName, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// When the label is selected by index and no header flag is passed, the label column
        /// is expected to receive the default label name.
        /// </summary>
        [Fact]
        public void IdentifyLabelColumnThroughIndexWithoutHeader()
        {
            var result = new MLContext(Seed).Auto().InferColumns(DatasetUtil.GetIrisDataset(), DatasetUtil.IrisDatasetLabelColIndex);

            Assert.False(result.TextLoaderOptions.HasHeader);
            var labelCol = result.TextLoaderOptions.Columns.First(c =>
                c.Source[0].Min == DatasetUtil.IrisDatasetLabelColIndex &&
                c.Source[0].Max == DatasetUtil.IrisDatasetLabelColIndex);
            Assert.Equal(DefaultColumnNames.Label, labelCol.Name);
            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// The column named "Empty" in DatasetWithEmptyColumn.txt is expected to be inferred
        /// as <see cref="DataKind.Single"/>.
        /// </summary>
        [Fact]
        public void DatasetWithEmptyColumn()
        {
            var dataPath = Path.Combine("TestData", "DatasetWithEmptyColumn.txt");

            var result = new MLContext(Seed).Auto().InferColumns(dataPath, DefaultColumnNames.Label, groupColumns: false);

            var emptyColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Empty");
            Assert.Equal(DataKind.Single, emptyColumn.DataKind);
        }

        /// <summary>
        /// A Boolean feature column is expected to be loaded as <see cref="DataKind.Single"/>
        /// and reported as numeric, while the label column stays <see cref="DataKind.Boolean"/>.
        /// </summary>
        [Fact]
        public void DatasetWithBoolColumn()
        {
            var dataPath = Path.Combine("TestData", "BinaryDatasetWithBoolColumn.txt");

            var result = new MLContext(Seed).Auto().InferColumns(dataPath, DefaultColumnNames.Label);
            Assert.Equal(2, result.TextLoaderOptions.Columns.Count());

            var boolColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Bool");
            var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == DefaultColumnNames.Label);

            // ensure non-label Boolean column is detected as R4
            Assert.Equal(DataKind.Single, boolColumn.DataKind);
            Assert.Equal(DataKind.Boolean, labelColumn.DataKind);

            // ensure the column information lists "Bool" as the only numeric column
            // and keeps the label name
            Assert.Single(result.ColumnInformation.NumericColumnNames);
            Assert.Equal("Bool", result.ColumnInformation.NumericColumnNames.First());
            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// Runs <c>ColumnInferenceApi.InferColumns</c> on DatasetWithoutHeader.txt with columns
        /// addressed by the names "col0".."col4", then checks the inferred data kinds and the
        /// categorical/text column lists.
        /// </summary>
        [Fact]
        public void InferDatasetWithoutHeader()
        {
            var context = new MLContext(Seed);
            var filePath = Path.Combine("TestData", "DatasetWithoutHeader.txt");
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = "col0",
                UserIdColumnName = "col1",
                ItemIdColumnName = "col2",
            };
            columnInfo.IgnoredColumnNames.Add("col4");

            // NOTE(review): the meaning of the trailing positional arguments
            // (null, null, false, false, false) is defined by ColumnInferenceApi.InferColumns,
            // which is not visible from this file - consider switching to named arguments.
            var result = ColumnInferenceApi.InferColumns(context, filePath, columnInfo, ',', null, null, false, false, false);
            Assert.Equal(6, result.TextLoaderOptions.Columns.Count());

            var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col0");
            var userColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col1");
            var itemColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col2");
            var ignoreColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col4");

            Assert.Equal(DataKind.String, labelColumn.DataKind);
            Assert.Equal(DataKind.Single, userColumn.DataKind);
            Assert.Equal(DataKind.Single, itemColumn.DataKind);
            Assert.Equal(DataKind.Single, ignoreColumn.DataKind);

            Assert.Single(result.ColumnInformation.CategoricalColumnNames);
            Assert.Empty(result.ColumnInformation.TextColumnNames);
        }

        /// <summary>
        /// When "Username" is the only non-label column, it is expected to be loaded as a
        /// string column and reported as the single text column.
        /// </summary>
        [Fact]
        public void WhereNameColumnIsOnlyFeature()
        {
            var dataPath = Path.Combine("TestData", "NameColumnIsOnlyFeatureDataset.txt");

            var result = new MLContext(Seed).Auto().InferColumns(dataPath, DefaultColumnNames.Label);
            Assert.Equal(2, result.TextLoaderOptions.Columns.Count());

            var nameColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Username");
            var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == DefaultColumnNames.Label);
            Assert.Equal(DataKind.String, nameColumn.DataKind);
            Assert.Equal(DataKind.Boolean, labelColumn.DataKind);

            Assert.Single(result.ColumnInformation.TextColumnNames);
            Assert.Equal("Username", result.ColumnInformation.TextColumnNames.First());
            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// With <c>groupColumns: false</c>, the label, weight, user and item column names
        /// supplied by the caller are expected to be echoed back, alongside three numeric
        /// columns.
        /// </summary>
        [Fact]
        public void DefaultColumnNamesInferredCorrectly()
        {
            var dataPath = Path.Combine("TestData", "DatasetWithDefaultColumnNames.txt");
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = DefaultColumnNames.Label,
                ExampleWeightColumnName = DefaultColumnNames.Weight,
                UserIdColumnName = DefaultColumnNames.User,
                ItemIdColumnName = DefaultColumnNames.Item,
            };

            var result = new MLContext(Seed).Auto().InferColumns(dataPath, columnInfo, groupColumns: false);

            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
            Assert.Equal(DefaultColumnNames.Weight, result.ColumnInformation.ExampleWeightColumnName);
            Assert.Equal(DefaultColumnNames.User, result.ColumnInformation.UserIdColumnName);
            Assert.Equal(DefaultColumnNames.Item, result.ColumnInformation.ItemIdColumnName);
            Assert.Equal(3, result.ColumnInformation.NumericColumnNames.Count());
        }

        /// <summary>
        /// Calls <c>InferColumns</c> without a <c>groupColumns</c> argument and expects a single
        /// numeric column carrying the default features name.
        /// </summary>
        [Fact]
        public void DefaultColumnNamesNoGrouping()
        {
            var dataPath = Path.Combine("TestData", "DatasetWithDefaultColumnNames.txt");
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = DefaultColumnNames.Label,
                ExampleWeightColumnName = DefaultColumnNames.Weight,
            };

            var result = new MLContext(Seed).Auto().InferColumns(dataPath, columnInfo);

            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
            Assert.Equal(DefaultColumnNames.Weight, result.ColumnInformation.ExampleWeightColumnName);
            Assert.Single(result.ColumnInformation.NumericColumnNames);
            Assert.Equal(DefaultColumnNames.Features, result.ColumnInformation.NumericColumnNames.First());
        }

        /// <summary>
        /// Supplying only a label name through <c>ColumnInformation</c> is expected to yield a
        /// <see cref="DataKind.Single"/> label, one numeric column with the default features
        /// name, and no example-weight column.
        /// </summary>
        [Fact]
        public void InferColumnsColumnInfoParam()
        {
            var columnInfo = new ColumnInformation() { LabelColumnName = DatasetUtil.MlNetGeneratedRegressionLabel };

            var result = new MLContext(Seed).Auto().InferColumns(DatasetUtil.GetMlNetGeneratedRegressionDataset(), columnInfo);

            var labelCol = result.TextLoaderOptions.Columns.First(c => c.Name == DatasetUtil.MlNetGeneratedRegressionLabel);
            Assert.Equal(DataKind.Single, labelCol.DataKind);
            Assert.Equal(DatasetUtil.MlNetGeneratedRegressionLabel, result.ColumnInformation.LabelColumnName);
            Assert.Single(result.ColumnInformation.NumericColumnNames);
            Assert.Equal(DefaultColumnNames.Features, result.ColumnInformation.NumericColumnNames.First());
            Assert.Null(result.ColumnInformation.ExampleWeightColumnName);
        }

        /// <summary>
        /// <c>TextFileContents.TrySplitColumns</c> is expected to succeed on
        /// DatasetWithNewlineBetweenQuotes.txt, reporting four columns separated by ','.
        /// </summary>
        [Fact]
        public void TrySplitColumns_should_split_on_dataset_with_newline_between_double_quotes()
        {
            var context = new MLContext(Seed);
            var dataset = Path.Combine("TestData", "DatasetWithNewlineBetweenQuotes.txt");
            var sample = TextFileSample.CreateFromFullFile(dataset);

            var result = TextFileContents.TrySplitColumns(context, sample, TextFileContents.DefaultSeparators);

            // Check success first so a failed split is reported as such rather than as a
            // confusing column-count or separator mismatch.
            result.IsSuccess.Should().BeTrue();
            result.ColumnCount.Should().Be(4);
            result.Separator.Should().Be(',');
        }

        /// <summary>
        /// Column information is expected to be inferable from multiline.csv, an input file
        /// that has escaped newlines inside quotes.
        /// </summary>
        [Fact]
        public void InferColumnsFromMultilineInputFile()
        {
            // Check if we can infer the column information
            // from an input file which has escaped newlines inside quotes
            var dataPath = GetDataPath("multiline.csv");
            var mlContext = new MLContext(Seed);
            var inputColumnInformation = new ColumnInformation() { LabelColumnName = "id" };

            var result = mlContext.Auto().InferColumns(dataPath, inputColumnInformation);

            // File has 3 columns: "id", "description" and "animal"
            Assert.NotNull(result.ColumnInformation.LabelColumnName);
            Assert.Single(result.ColumnInformation.TextColumnNames);
            Assert.Single(result.ColumnInformation.CategoricalColumnNames);

            Assert.Equal("id", result.ColumnInformation.LabelColumnName);
            Assert.Equal("description", result.ColumnInformation.TextColumnNames.First());
            Assert.Equal("animal", result.ColumnInformation.CategoricalColumnNames.First());
        }

        /// <summary>
        /// Deserializes wiki-column-inference.json into <c>ColumnInferenceResults</c>,
        /// serializes it back as indented JSON, and verifies the output against the approved
        /// file.
        /// </summary>
        [Fact]
        [UseReporter(typeof(DiffReporter))]
        [UseApprovalSubdirectory("ApprovalTests")]
        public void Wiki_column_inference_result_should_be_serializable()
        {
            // DiffEngine can't detect Helix on its own, so the HELIX_CORRELATION_ID
            // environment variable is used to detect it instead.
            if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null)
            {
                Approvals.UseAssemblyLocationForApprovedFiles();
            }

            var wiki = Path.Combine("TestData", "wiki-column-inference.json");
            var json = File.ReadAllText(wiki);

            var columnInferenceResults = JsonConvert.DeserializeObject<ColumnInferenceResults>(json);
            Approvals.Verify(JsonConvert.SerializeObject(columnInferenceResults, Formatting.Indented));
        }
    }
}
|