machinelearning/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs

199 строки
7.6 KiB
C#

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
using Microsoft.ML.Model;
using Microsoft.ML.RunTests;
using Microsoft.ML.StaticPipe;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.StaticPipe;
using Xunit;
using Xunit.Abstractions;
namespace Microsoft.ML.Tests
{
public class DnnImageFeaturizerTests : TestDataPipeBase
{
private const int inputSize = 3 * 224 * 224;
private class TestData
{
[VectorType(inputSize)]
public float[] data_0;
}
private class TestDataSize
{
[VectorType(2)]
public float[] data_0;
}
private class TestDataXY
{
[VectorType(inputSize)]
public float[] A;
}
private class TestDataDifferntType
{
[VectorType(inputSize)]
public string[] data_0;
}
private float[] GetSampleArrayData()
{
var samplevector = new float[inputSize];
for (int i = 0; i < inputSize; i++)
samplevector[i] = (i / ((float) inputSize));
return samplevector;
}
public DnnImageFeaturizerTests(ITestOutputHelper helper) : base(helper)
{
}
// Onnx is only supported on x64 Windows
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
void TestDnnImageFeaturizer()
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return;
var samplevector = GetSampleArrayData();
var dataView = DataViewConstructionUtils.CreateFromList(Env,
new TestData[] {
new TestData()
{
data_0 = samplevector
},
});
var xyData = new List<TestDataXY> { new TestDataXY() { A = new float[inputSize] } };
var stringData = new List<TestDataDifferntType> { new TestDataDifferntType() { data_0 = new string[inputSize] } };
var sizeData = new List<TestDataSize> { new TestDataSize() { data_0 = new float[2] } };
var pipe = new DnnImageFeaturizerEstimator(Env, m => m.ModelSelector.ResNet18(m.Environment, m.InputColumn, m.OutputColumn), "data_0", "output_1");
var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData);
var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData);
var invalidDataWrongVectorSize = ComponentCreation.CreateDataView(Env, sizeData);
TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames);
TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes);
pipe.GetOutputSchema(SchemaShape.Create(invalidDataWrongVectorSize.Schema));
try
{
pipe.Fit(invalidDataWrongVectorSize);
Assert.False(true);
}
catch (ArgumentOutOfRangeException) { }
catch (InvalidOperationException) { }
}
// Onnx is only supported on x64 Windows
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
public void OnnxStatic()
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return;
var env = new MLContext(null, 1);
var imageHeight = 224;
var imageWidth = 224;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
var data = TextLoaderStatic.CreateReader(env, ctx => (
imagePath: ctx.LoadText(0),
name: ctx.LoadText(1)))
.Read(dataFile);
var pipe = data.MakeNewEstimator()
.Append(row => (
row.name,
data_0: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
.Append(row => (row.name, output_1: row.data_0.DnnImageFeaturizer(m => m.ModelSelector.ResNet18(m.Environment, m.InputColumn, m.OutputColumn))));
TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
var result = pipe.Fit(data).Transform(data).AsDynamic;
result.Schema.TryGetColumnIndex("output_1", out int output);
using (var cursor = result.GetRowCursor(result.Schema["output_1"]))
{
var buffer = default(VBuffer<float>);
var getter = cursor.GetGetter<VBuffer<float>>(output);
var numRows = 0;
while (cursor.MoveNext())
{
getter(ref buffer);
Assert.Equal(512, buffer.Length);
numRows += 1;
}
Assert.Equal(4, numRows);
}
}
// Onnx is only supported on x64 Windows
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
public void TestOldSavingAndLoading()
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return;
var samplevector = GetSampleArrayData();
var dataView = ComponentCreation.CreateDataView(Env,
new TestData[] {
new TestData()
{
data_0 = samplevector
}
});
var inputNames = "data_0";
var outputNames = "output_1";
var est = new DnnImageFeaturizerEstimator(Env, m => m.ModelSelector.ResNet18(m.Environment, m.InputColumn, m.OutputColumn), inputNames, outputNames);
var transformer = est.Fit(dataView);
var result = transformer.Transform(dataView);
var resultRoles = new RoleMappedData(result);
using (var ms = new MemoryStream())
{
TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles);
ms.Position = 0;
var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
loadedView.Schema.TryGetColumnIndex(outputNames, out int softMaxOut1);
using (var cursor = loadedView.GetRowCursor(loadedView.Schema[outputNames]))
{
VBuffer<float> softMaxValue = default;
var softMaxGetter = cursor.GetGetter<VBuffer<float>>(softMaxOut1);
float sum = 0f;
int i = 0;
while (cursor.MoveNext())
{
softMaxGetter(ref softMaxValue);
var values = softMaxValue.DenseValues();
foreach (var val in values)
{
sum += val;
if (i == 0)
Assert.InRange(val, 0.0, 0.00001);
if (i == 7)
Assert.InRange(val, 0.62935, 0.62940);
if (i == 500)
Assert.InRange(val, 0.15521, 0.155225);
i++;
}
}
Assert.InRange(sum, 83.50, 84.50);
}
}
}
}
}