using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace CNTK.CSTrainingExamples
{
    /// <summary>
    /// This class shows how to do image classification with a convolutional network.
    /// It closely follows this CNTK Python tutorial:
    /// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_201B_CIFAR-10_ImageHandsOn.ipynb
    /// </summary>
    public class CifarResNetClassifier
    {
        /// <summary>
        /// The execution folder is CNTK/x64/BuildFolder;
        /// the data folder is CNTK/Examples/Image/DataSets.
        /// </summary>
        public static string CifarDataFolder;

        /// <summary>
        /// Number of epochs (sweeps over the training data) for training.
        /// </summary>
        public static uint MaxEpochs = 1;

        // CIFAR-10: 32x32 RGB images in 10 classes.
        private static readonly int[] imageDim = { 32, 32, 3 };
        private static readonly int numClasses = 10;
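
        // A minimal usage sketch (hypothetical caller, not part of the original
        // sample); the data folder path and device choice here are assumptions:
        //
        //   CifarResNetClassifier.CifarDataFolder = "../../Examples/Image/DataSets/CIFAR-10";
        //   CifarResNetClassifier.TrainAndEvaluate(DeviceDescriptor.CPUDevice, forceRetrain: true);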

        /// <summary>
        /// Train and evaluate an image classifier on the CIFAR-10 data.
        /// The classification model is saved after training.
        /// For repeated runs, the caller may choose whether to retrain the model or
        /// just validate an existing one.
        /// </summary>
        /// <param name="device">CPU or GPU device to run on</param>
        /// <param name="forceRetrain">whether to override an existing model:
        /// if true, any existing model is overridden and the new one is evaluated;
        /// if false and a model exists, the existing model is evaluated.</param>
        public static void TrainAndEvaluate(DeviceDescriptor device, bool forceRetrain)
        {
            string modelFile = "Cifar10ResNet.model";

            // If a model already exists and retraining is not forced, validate it and return.
            if (File.Exists(modelFile) && !forceRetrain)
            {
                ValidateModel(device, modelFile);
                return;
            }

            // Prepare the training data.
            var minibatchSource = CreateMinibatchSource(Path.Combine(CifarDataFolder, "train_map.txt"),
                Path.Combine(CifarDataFolder, "CIFAR-10_mean.xml"), imageDim, numClasses, MaxEpochs);
            var imageStreamInfo = minibatchSource.StreamInfo("features");
            var labelStreamInfo = minibatchSource.StreamInfo("labels");

            // Build the model.
            var imageInput = CNTKLib.InputVariable(imageDim, imageStreamInfo.m_elementType, "Images");
            var labelsVar = CNTKLib.InputVariable(new int[] { numClasses }, labelStreamInfo.m_elementType, "Labels");
            var classifierOutput = ResNetClassifier(imageInput, numClasses, device, "classifierOutput");

            // Prepare for training: cross-entropy loss and top-5 classification error.
            var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labelsVar, "lossFunction");
            var prediction = CNTKLib.ClassificationError(classifierOutput, labelsVar, 5, "predictionError");

            var learningRatePerSample = new TrainingParameterScheduleDouble(0.0078125, 1);
            var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction,
                new List<Learner> { Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample) });

            uint minibatchSize = 64;
            int outputFrequencyInMinibatches = 20, miniBatchCount = 0;

            // Feed data to the trainer until the source has completed MaxEpochs sweeps.
            while (true)
            {
                var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

                // Stop training once the maximum number of epochs is reached.
                if (minibatchData.empty())
                {
                    break;
                }

                trainer.TrainMinibatch(new Dictionary<Variable, MinibatchData>()
                    { { imageInput, minibatchData[imageStreamInfo] }, { labelsVar, minibatchData[labelStreamInfo] } }, device);
                TestHelper.PrintTrainingProgress(trainer, miniBatchCount++, outputFrequencyInMinibatches);
            }

            // Save the model, combining loss, error, and output into one function.
            var imageClassifier = Function.Combine(new List<Variable>() { trainingLoss, prediction, classifierOutput }, "ImageClassifier");
            imageClassifier.Save(modelFile);

            // Validate the model.
            ValidateModel(device, modelFile);
        }
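
        // A minimal sketch of reloading the saved model, assuming the file written
        // above; LoadSavedModelSketch is a hypothetical helper added for
        // illustration and is not part of the original sample.
        private static Function LoadSavedModelSketch(DeviceDescriptor device)
        {
            // Function.Load restores the combined (loss, error, classifier) graph;
            // the classifier output can be recovered by the name given at build time.
            Function model = Function.Load("Cifar10ResNet.model", device);
            Console.WriteLine($"Loaded model; number of outputs: {model.Outputs.Count}");
            return model;
        }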

        private static void ValidateModel(DeviceDescriptor device, string modelFile)
        {
            MinibatchSource testMinibatchSource = CreateMinibatchSource(
                Path.Combine(CifarDataFolder, "test_map.txt"),
                Path.Combine(CifarDataFolder, "CIFAR-10_mean.xml"), imageDim, numClasses, 1);
            TestHelper.ValidateModelWithMinibatchSource(modelFile, testMinibatchSource,
                imageDim, numClasses, "features", "labels", "classifierOutput", device);
        }

        // Convolution -> batch normalization -> ReLU: the basic building block.
        private static Function ConvBatchNormalizationReLULayer(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, int hStride, int vStride,
            double wScale, double bValue, double scValue, int bnTimeConst, bool spatial, DeviceDescriptor device)
        {
            var convBNFunction = ConvBatchNormalizationLayer(input, outFeatureMapCount, kernelWidth, kernelHeight, hStride, vStride, wScale, bValue, scValue, bnTimeConst, spatial, device);
            return CNTKLib.ReLU(convBNFunction);
        }

        // Convolution followed by batch normalization. The bias, scale, and running
        // statistics have their dimension inferred from the convolution output.
        private static Function ConvBatchNormalizationLayer(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, int hStride, int vStride,
            double wScale, double bValue, double scValue, int bnTimeConst, bool spatial, DeviceDescriptor device)
        {
            int numInputChannels = input.Shape[input.Shape.Rank - 1];

            var convParams = new Parameter(new int[] { kernelWidth, kernelHeight, numInputChannels, outFeatureMapCount },
                DataType.Float, CNTKLib.GlorotUniformInitializer(wScale, -1, 2), device);
            var convFunction = CNTKLib.Convolution(convParams, input, new int[] { hStride, vStride, numInputChannels });

            var biasParams = new Parameter(new int[] { NDShape.InferredDimension }, (float)bValue, device, "");
            var scaleParams = new Parameter(new int[] { NDShape.InferredDimension }, (float)scValue, device, "");
            var runningMean = new Constant(new int[] { NDShape.InferredDimension }, 0.0f, device);
            var runningInvStd = new Constant(new int[] { NDShape.InferredDimension }, 0.0f, device);
            var runningCount = Constant.Scalar(0.0f, device);
            return CNTKLib.BatchNormalization(convFunction, scaleParams, biasParams, runningMean, runningInvStd, runningCount,
                spatial, (double)bnTimeConst, 0.0, 1e-5 /* epsilon */);
        }
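
        // What this composite computes, roughly, for each output channel c:
        //   y_c = scale_c * (conv(x)_c - mean_c) / sqrt(var_c + epsilon) + bias_c
        // where mean/var are estimated per minibatch during training and blended
        // into the running statistics with time constant bnTimeConst.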

        // A residual block that keeps the input shape: two conv-BN layers plus the
        // identity shortcut, followed by ReLU.
        private static Function ResNetNode(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, double wScale, double bValue,
            double scValue, int bnTimeConst, bool spatial, DeviceDescriptor device)
        {
            var c1 = ConvBatchNormalizationReLULayer(input, outFeatureMapCount, kernelWidth, kernelHeight, 1, 1, wScale, bValue, scValue, bnTimeConst, spatial, device);
            var c2 = ConvBatchNormalizationLayer(c1, outFeatureMapCount, kernelWidth, kernelHeight, 1, 1, wScale, bValue, scValue, bnTimeConst, spatial, device);
            var p = CNTKLib.Plus(c2, input);
            return CNTKLib.ReLU(p);
        }

        // A residual block that widens the feature maps and halves the spatial
        // dimensions (stride 2); the shortcut uses a 1x1 projection to match shapes.
        private static Function ResNetNodeInc(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, double wScale, double bValue,
            double scValue, int bnTimeConst, bool spatial, Variable wProj, DeviceDescriptor device)
        {
            var c1 = ConvBatchNormalizationReLULayer(input, outFeatureMapCount, kernelWidth, kernelHeight, 2, 2, wScale, bValue, scValue, bnTimeConst, spatial, device);
            var c2 = ConvBatchNormalizationLayer(c1, outFeatureMapCount, kernelWidth, kernelHeight, 1, 1, wScale, bValue, scValue, bnTimeConst, spatial, device);

            var cProj = ProjectLayer(wProj, input, 2, 2, bValue, scValue, bnTimeConst, device);

            var p = CNTKLib.Plus(c2, cProj);
            return CNTKLib.ReLU(p);
        }
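
        // Shape sketch: with a { 32, 32, 16 } input and outFeatureMapCount = 32,
        // the main path produces a { 16, 16, 32 } output, and the stride-2
        // projected shortcut has the same shape, so the Plus is well-formed.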

        // Shortcut projection: a strided 1x1 convolution with kernel wProj, followed
        // by batch normalization. wProj has shape { 1, 1, inputDim, outputDim }
        // (see GetProjectionMap), so the output feature map count is its last
        // dimension, not Shape[0] (which is always 1).
        private static Function ProjectLayer(Variable wProj, Variable input, int hStride, int vStride, double bValue, double scValue, int bnTimeConst,
            DeviceDescriptor device)
        {
            int outFeatureMapCount = wProj.Shape[wProj.Shape.Rank - 1];
            var b = new Parameter(new int[] { outFeatureMapCount }, (float)bValue, device, "");
            var sc = new Parameter(new int[] { outFeatureMapCount }, (float)scValue, device, "");
            var m = new Constant(new int[] { outFeatureMapCount }, 0.0f, device);
            var v = new Constant(new int[] { outFeatureMapCount }, 0.0f, device);

            var n = Constant.Scalar(0.0f, device);

            int numInputChannels = input.Shape[input.Shape.Rank - 1];

            var c = CNTKLib.Convolution(wProj, input, new int[] { hStride, vStride, numInputChannels }, new bool[] { true }, new bool[] { false });
            return CNTKLib.BatchNormalization(c, sc, b, m, v, n, true /*spatial*/, (double)bnTimeConst, 0, 1e-5, false);
        }

        // Build a 1x1 convolution kernel that embeds inputDim channels into the
        // first inputDim of outputDim output channels (identity map, zero padding).
        private static Constant GetProjectionMap(int outputDim, int inputDim, DeviceDescriptor device)
        {
            if (inputDim > outputDim)
                throw new Exception("Can only project from lower to higher dimensionality");

            // C# float arrays are zero-initialized; only the identity entries need setting.
            float[] projectionMapValues = new float[inputDim * outputDim];
            for (int i = 0; i < inputDim; ++i)
                projectionMapValues[(i * inputDim) + i] = 1.0f;

            var projectionMap = new NDArrayView(DataType.Float, new int[] { 1, 1, inputDim, outputDim }, device);
            projectionMap.CopyFrom(new NDArrayView(new int[] { 1, 1, inputDim, outputDim }, projectionMapValues, (uint)projectionMapValues.Length, device));

            return new Constant(projectionMap);
        }
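
        // For example, with inputDim = 16 and outputDim = 32 the kernel has shape
        // { 1, 1, 16, 32 }: output channel i (i < 16) copies input channel i, and
        // output channels 16..31 are zero.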

        /// <summary>
        /// Build a ResNet for image classification:
        /// https://arxiv.org/abs/1512.03385
        /// </summary>
        /// <param name="input">input variable for image data</param>
        /// <param name="numOutputClasses">number of output classes</param>
        /// <param name="device">CPU or GPU device to run on</param>
        /// <param name="outputName">name of the classifier output</param>
        /// <returns>the classifier output function</returns>
        private static Function ResNetClassifier(Variable input, int numOutputClasses, DeviceDescriptor device, string outputName)
        {
            double convWScale = 7.07;
            double convBValue = 0;

            double fc1WScale = 0.4;
            double fc1BValue = 0;

            double scValue = 1;
            int bnTimeConst = 4096;

            int kernelWidth = 3;
            int kernelHeight = 3;

            // Initial convolution: 16 feature maps at full 32x32 resolution.
            double conv1WScale = 0.26;
            int cMap1 = 16;
            var conv1 = ConvBatchNormalizationReLULayer(input, cMap1, kernelWidth, kernelHeight, 1, 1, conv1WScale, convBValue, scValue, bnTimeConst, true /*spatial*/, device);

            // Stage 1: three residual blocks with 16 feature maps.
            var rn1_1 = ResNetNode(conv1, cMap1, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);
            var rn1_2 = ResNetNode(rn1_1, cMap1, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, device);
            var rn1_3 = ResNetNode(rn1_2, cMap1, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);

            // Stage 2: downsample to 16x16 and widen to 32 feature maps.
            int cMap2 = 32;
            var rn2_1_wProj = GetProjectionMap(cMap2, cMap1, device);
            var rn2_1 = ResNetNodeInc(rn1_3, cMap2, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, rn2_1_wProj, device);
            var rn2_2 = ResNetNode(rn2_1, cMap2, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);
            var rn2_3 = ResNetNode(rn2_2, cMap2, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, device);

            // Stage 3: downsample to 8x8 and widen to 64 feature maps.
            int cMap3 = 64;
            var rn3_1_wProj = GetProjectionMap(cMap3, cMap2, device);
            var rn3_1 = ResNetNodeInc(rn2_3, cMap3, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, rn3_1_wProj, device);
            var rn3_2 = ResNetNode(rn3_1, cMap3, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);
            var rn3_3 = ResNetNode(rn3_2, cMap3, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);

            // Global average pooling over the remaining 8x8 spatial extent.
            int poolW = 8;
            int poolH = 8;
            int poolhStride = 1;
            int poolvStride = 1;
            var pool = CNTKLib.Pooling(rn3_3, PoolingType.Average,
                new int[] { poolW, poolH, 1 }, new int[] { poolhStride, poolvStride, 1 });

            // Output dense layer.
            var outTimesParams = new Parameter(new int[] { numOutputClasses, 1, 1, cMap3 }, DataType.Float,
                CNTKLib.GlorotUniformInitializer(fc1WScale, 1, 0), device);
            var outBiasParams = new Parameter(new int[] { numOutputClasses }, (float)fc1BValue, device, "");

            return CNTKLib.Plus(CNTKLib.Times(outTimesParams, pool), outBiasParams, outputName);
        }
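
        // The stack above is the standard 20-layer CIFAR ResNet pattern: one
        // initial convolution, nine residual blocks of two convolutions each, and
        // a dense output layer (1 + 18 + 1 = 20 weighted layers).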

        private static MinibatchSource CreateMinibatchSource(string mapFilePath, string meanFilePath,
            int[] imageDims, int numClasses, uint maxSweeps)
        {
            // Data augmentation: random crop and scale, then mean subtraction.
            List<CNTKDictionary> transforms = new List<CNTKDictionary>{
                CNTKLib.ReaderCrop("RandomSide",
                    new Tuple<int, int>(0, 0),
                    new Tuple<float, float>(0.8f, 1.0f),
                    new Tuple<float, float>(0.0f, 0.0f),
                    new Tuple<float, float>(1.0f, 1.0f),
                    "uniRatio"),
                CNTKLib.ReaderScale(imageDims[0], imageDims[1], imageDims[2]),
                CNTKLib.ReaderMean(meanFilePath)
            };

            var deserializerConfiguration = CNTKLib.ImageDeserializer(mapFilePath,
                "labels", (uint)numClasses,
                "features",
                transforms);

            MinibatchSourceConfig config = new MinibatchSourceConfig(new List<CNTKDictionary> { deserializerConfiguration })
            {
                MaxSweeps = maxSweeps
            };

            return CNTKLib.CreateCompositeMinibatchSource(config);
        }
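
        // Note: each line of the map file is expected to follow the CNTK
        // ImageDeserializer format "<path-to-image><tab><numeric label>", e.g.
        // a hypothetical entry: train/00001.png<TAB>5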
    }
}