226 строки
11 KiB
C#
226 строки
11 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
|
|
namespace CNTK.CSTrainingExamples
|
|
{
|
|
/// <summary>
|
|
/// This class shows how to build a recurrent neural network model from the ground up and how to train it.
|
|
/// </summary>
|
|
public class LSTMSequenceClassifier
|
|
{
|
|
/// <summary>
|
|
/// Execution folder is: CNTK/x64/BuildFolder
|
|
/// Data folder is: CNTK/Tests/EndToEndTests/Text/SequenceClassification/Data
|
|
/// </summary>
|
|
// NOTE(review): the path is built by plain string concatenation, so this presumably relies on
// TestCommon.TestDataDirPrefix already ending with a directory separator - confirm.
public static string DataFolder = TestCommon.TestDataDirPrefix + "Tests/EndToEndTests/Text/SequenceClassification/Data";
|
|
|
|
/// <summary>
|
|
/// Build and train an RNN model.
|
|
/// </summary>
|
|
/// <param name="device">CPU or GPU device to train and run the model</param>
|
|
public static void Train(DeviceDescriptor device)
{
    // Model dimensions.
    const int inputDim = 2000;
    const int cellDim = 25;
    const int hiddenDim = 25;
    const int embeddingDim = 50;
    const int numOutputClasses = 5;

    // Names shared by the network inputs and the reader streams.
    const string featuresName = "features";
    const string labelsName = "labels";

    // Network inputs: the features input is sparse; the labels input is given
    // Axis.DefaultBatchAxis() as its only axis.
    Variable features = Variable.InputVariable(
        new int[] { inputDim }, DataType.Float, featuresName, null, true /*isSparse*/);
    Variable labels = Variable.InputVariable(
        new int[] { numOutputClasses }, DataType.Float, labelsName,
        new List<Axis>() { Axis.DefaultBatchAxis() }, true);

    // The classifier, plus the loss and the evaluation criterion computed from its output.
    Function classifierOutput = LSTMSequenceClassifierNet(
        features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, "classifierOutput");
    Function trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labels, "lossFunction");
    Function prediction = CNTKLib.ClassificationError(classifierOutput, labels, "classificationError");

    // Reader over Train.ctf: stream "x" feeds the features, stream "y" feeds the labels.
    IList<StreamConfiguration> streamConfigurations = new StreamConfiguration[]
    {
        new StreamConfiguration(featuresName, inputDim, true, "x"),
        new StreamConfiguration(labelsName, numOutputClasses, false, "y")
    };
    string trainingFile = Path.Combine(DataFolder, "Train.ctf");
    var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
        trainingFile, streamConfigurations, MinibatchSource.InfinitelyRepeat, true);
    var featureStreamInfo = minibatchSource.StreamInfo(featuresName);
    var labelStreamInfo = minibatchSource.StreamInfo(labelsName);

    // Momentum SGD learner over all parameters of the classifier.
    var learningRatePerSample = new TrainingParameterScheduleDouble(0.0005, 1);
    TrainingParameterScheduleDouble momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256);
    var momentumSgd = Learner.MomentumSGDLearner(
        classifierOutput.Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true);
    IList<Learner> parameterLearners = new List<Learner>() { momentumSgd };
    var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);

    // Training loop.
    uint minibatchSize = 200;
    int outputFrequencyInMinibatches = 20;
    int miniBatchCount = 0;
    int totalEpochs = 5;
    int finishedEpochs = 0;
    while (finishedEpochs < totalEpochs)
    {
        var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

        var arguments = new Dictionary<Variable, MinibatchData>();
        arguments.Add(features, minibatchData[featureStreamInfo]);
        arguments.Add(labels, minibatchData[labelStreamInfo]);

        trainer.TrainMinibatch(arguments, device);
        TestHelper.PrintTrainingProgress(trainer, miniBatchCount, outputFrequencyInMinibatches);
        miniBatchCount++;

        // The source was created with MinibatchSource.InfinitelyRepeat, so it never runs out
        // of minibatches. Whenever it completes a sweep (epoch), the last minibatch of that
        // sweep is marked as the end of the sweep; that flag is what counts the epochs here.
        if (TestHelper.MiniBatchDataIsSweepEnd(minibatchData.Values))
        {
            finishedEpochs++;
        }
    }
}
|
|
|
|
/// <summary>
/// Multiplies <paramref name="x"/> element-wise by beta = (1/f) * log(1 + exp(f * p)),
/// where f is the constant 4 and p is a scalar Parameter initialized to 0.99537863.
/// </summary>
/// <typeparam name="ElementType">float creates f from a float literal; any other type creates it from a double literal</typeparam>
/// <param name="x">the variable to scale</param>
/// <param name="device">device passed when creating f and the scalar parameter</param>
/// <returns>the element-wise product of beta and x</returns>
static Function Stabilize<ElementType>(Variable x, DeviceDescriptor device)
{
    // Only the literal used for f depends on ElementType; everything else follows f.DataType.
    Constant f = typeof(ElementType) == typeof(float)
        ? Constant.Scalar(4.0f, device)
        : Constant.Scalar(4.0, device);
    Constant fInv = Constant.Scalar(f.DataType, 1.0 / 4.0f);

    Constant one = Constant.Scalar(f.DataType, 1.0);
    var scale = new Parameter(new NDShape(), f.DataType, 0.99537863 /* 1/f*ln (e^f-1) */, device);

    // log(1 + exp(f * p)), then scaled by 1/f.
    Function logTerm = CNTKLib.Log(one + CNTKLib.Exp(CNTKLib.ElementTimes(f, scale)));
    Function beta = CNTKLib.ElementTimes(fInv, logTerm);

    return CNTKLib.ElementTimes(beta, x);
}
|
|
|
|
/// <summary>
/// Builds a single LSTM cell step in which the previous output, the previous cell state, and the
/// new cell state are passed through Stabilize before being used by the gates. When the output
/// dimension differs from the cell dimension, the cell output is projected to the output dimension.
/// </summary>
/// <typeparam name="ElementType">float creates Float parameters; any other type creates Double parameters</typeparam>
/// <param name="input">input to the cell</param>
/// <param name="prevOutput">previous output; its first dimension is taken as the output dimension</param>
/// <param name="prevCellState">previous cell state; its first dimension is taken as the cell dimension</param>
/// <param name="device">device passed when creating the parameters</param>
/// <returns>a tuple of (output h, cell state c)</returns>
static Tuple<Function, Function> LSTMPCellWithSelfStabilization<ElementType>(
    Variable input, Variable prevOutput, Variable prevCellState, DeviceDescriptor device)
{
    int outputDim = prevOutput.Shape[0];
    int cellDim = prevCellState.Shape[0];

    bool isFloatType = typeof(ElementType) == typeof(float);
    DataType dataType = DataType.Double;
    if (isFloatType)
    {
        dataType = DataType.Float;
    }

    // Bias vectors are initialized to 0.01; the literal's type follows ElementType.
    Func<int, Parameter> createBiasParam = (dim) => isFloatType
        ? new Parameter(new int[] { dim }, 0.01f, device, "")
        : new Parameter(new int[] { dim }, 0.01, device, "");

    // Each Glorot-initialized parameter takes the next value of this shared seed counter,
    // so the order in which the factories below are invoked determines the seeds.
    uint nextSeed = 1;
    Func<int, Parameter> createProjectionParam = (oDim) => new Parameter(
        new int[] { oDim, NDShape.InferredDimension },
        dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, nextSeed++), device);
    Func<int, Parameter> createDiagWeightParam = (dim) => new Parameter(
        new int[] { dim }, dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, nextSeed++), device);

    Function stabilizedPrevOutput = Stabilize<ElementType>(prevOutput, device);
    Function stabilizedPrevCellState = Stabilize<ElementType>(prevCellState, device);

    // bias + W * input, with a fresh bias and a fresh projection on every call.
    Func<Variable> projectInput = () =>
        createBiasParam(cellDim) + (createProjectionParam(cellDim) * input);

    // projectInput() + R * stabilizedPrevOutput, again with fresh parameters on every call.
    Func<Variable> projectInputAndPrevOutput = () =>
        projectInput() + (createProjectionParam(cellDim) * stabilizedPrevOutput);

    // Input gate
    Variable inputGateSum = projectInputAndPrevOutput();
    Function it = CNTKLib.Sigmoid(
        inputGateSum + CNTKLib.ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
    Function bit = CNTKLib.ElementTimes(it, CNTKLib.Tanh(projectInputAndPrevOutput()));

    // Forget gate
    Variable forgetGateSum = projectInputAndPrevOutput();
    Function ft = CNTKLib.Sigmoid(
        forgetGateSum + CNTKLib.ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
    Function bft = CNTKLib.ElementTimes(ft, prevCellState);

    // New cell state: the gated previous state plus the gated candidate.
    Function ct = (Variable)bft + bit;

    // Output gate
    Variable outputGateSum = projectInputAndPrevOutput();
    Function ot = CNTKLib.Sigmoid(
        outputGateSum + CNTKLib.ElementTimes(createDiagWeightParam(cellDim), Stabilize<ElementType>(ct, device)));
    Function ht = CNTKLib.ElementTimes(ot, CNTKLib.Tanh(ct));

    // Project the cell output only when the output and cell dimensions differ.
    Function h = ht;
    if (outputDim != cellDim)
    {
        h = createProjectionParam(outputDim) * Stabilize<ElementType>(ht, device);
    }

    return Tuple.Create(h, ct);
}
|
|
|
|
|
|
/// <summary>
/// Builds an LSTM cell on placeholder variables for the previous output and previous cell state,
/// then closes the recurrence loop by replacing those placeholders with the results of the
/// recurrence hooks applied to the cell's own output and cell state.
/// </summary>
/// <typeparam name="ElementType">element type forwarded to LSTMPCellWithSelfStabilization</typeparam>
/// <param name="input">input to the cell; its dynamic axes are reused for the placeholders</param>
/// <param name="outputShape">shape of the placeholder for the previous output</param>
/// <param name="cellShape">shape of the placeholder for the previous cell state</param>
/// <param name="recurrenceHookH">applied to the cell output to obtain the previous-output value</param>
/// <param name="recurrenceHookC">applied to the cell state to obtain the previous-cell-state value</param>
/// <param name="device">device forwarded to LSTMPCellWithSelfStabilization</param>
/// <returns>a tuple of (cell output, cell state)</returns>
static Tuple<Function, Function> LSTMPComponentWithSelfStabilization<ElementType>(Variable input,
    NDShape outputShape, NDShape cellShape,
    Func<Variable, Function> recurrenceHookH,
    Func<Variable, Function> recurrenceHookC,
    DeviceDescriptor device)
{
    Variable prevOutputPlaceholder = Variable.PlaceholderVariable(outputShape, input.DynamicAxes);
    Variable prevCellPlaceholder = Variable.PlaceholderVariable(cellShape, input.DynamicAxes);

    Tuple<Function, Function> cell = LSTMPCellWithSelfStabilization<ElementType>(
        input, prevOutputPlaceholder, prevCellPlaceholder, device);
    Function cellOutput = cell.Item1;
    Function cellState = cell.Item2;

    Function hookedOutput = recurrenceHookH(cellOutput);
    Function hookedCellState = recurrenceHookC(cellState);

    // Form the recurrence loop: each placeholder is replaced by the corresponding hooked value.
    var replacements = new Dictionary<Variable, Variable>();
    replacements.Add(prevOutputPlaceholder, hookedOutput);
    replacements.Add(prevCellPlaceholder, hookedCellState);
    cellOutput.ReplacePlaceholders(replacements);

    return Tuple.Create(cellOutput, cellState);
}
|
|
|
|
/// <summary>
/// Multiplies the input by a Float parameter of shape { embeddingDim, inputDim } that is
/// initialized with CNTKLib.GlorotUniformInitializer().
/// </summary>
/// <param name="input">the input variable; asserted (in debug builds) to have rank 1</param>
/// <param name="embeddingDim">first dimension of the embedding parameter</param>
/// <param name="device">device passed when creating the embedding parameter</param>
/// <returns>the product of the embedding parameter and the input</returns>
private static Function Embedding(Variable input, int embeddingDim, DeviceDescriptor device)
{
    var inputShape = input.Shape;
    System.Diagnostics.Debug.Assert(inputShape.Rank == 1);

    int[] weightDims = { embeddingDim, inputShape[0] };
    var weights = new Parameter(weightDims, DataType.Float, CNTKLib.GlorotUniformInitializer(), device);

    return CNTKLib.Times(weights, input);
}
|
|
|
|
/// <summary>
|
|
/// Build a unidirectional recurrent neural network (RNN) with long short-term memory (LSTM) cells.
|
|
/// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
|
|
/// </summary>
|
|
/// <param name="input">the input variable</param>
|
|
/// <param name="numOutputClasses">number of output classes</param>
|
|
/// <param name="embeddingDim">dimension of the embedding layer</param>
|
|
/// <param name="LSTMDim">LSTM output dimension</param>
|
|
/// <param name="cellDim">cell dimension</param>
|
|
/// <param name="device">CPU or GPU device to run the model</param>
|
|
/// <param name="outputName">name of the model output</param>
|
|
/// <returns>the RNN model</returns>
|
|
static Function LSTMSequenceClassifierNet(Variable input, int numOutputClasses, int embeddingDim, int LSTMDim, int cellDim, DeviceDescriptor device,
    string outputName)
{
    // Pipeline: embedding -> LSTM recurrence -> SequenceLast -> fully connected linear layer.
    Function embedded = Embedding(input, embeddingDim, device);

    // The same PastValue hook feeds back both the LSTM output and the cell state.
    Func<Variable, Function> pastValueHook = (x) => CNTKLib.PastValue(x);
    Tuple<Function, Function> lstm = LSTMPComponentWithSelfStabilization<float>(
        embedded,
        new int[] { LSTMDim },
        new int[] { cellDim },
        pastValueHook,
        pastValueHook,
        device);

    // Only the LSTM output (Item1) is used from here on; the cell state (Item2) is not.
    Function thoughtVector = CNTKLib.SequenceLast(lstm.Item1);

    return TestHelper.FullyConnectedLinearLayer(thoughtVector, numOutputClasses, device, outputName);
}
|
|
}
|
|
}
|