using System;
using System.Collections.Generic;
using System.IO;
namespace CNTK.CSTrainingExamples
{
/// <summary>
/// This class shows how to build a recurrent neural network model from the ground up and train it.
/// The model is an embedding layer, a single self-stabilized LSTM layer whose output at the last
/// sequence step feeds a linear layer producing one score per output class.
/// </summary>
public class LSTMSequenceClassifier
{
    /// <summary>
    /// Folder containing the training data (<c>Train.ctf</c>).
    /// Execution folder is: CNTK/x64/BuildFolder
    /// Data folder is: CNTK/Tests/EndToEndTests/Text/SequenceClassification/Data
    /// </summary>
    public static string DataFolder = TestCommon.TestDataDirPrefix + "Tests/EndToEndTests/Text/SequenceClassification/Data";

    /// <summary>
    /// Build an RNN sequence classifier and train it on <c>Train.ctf</c> from <see cref="DataFolder"/>
    /// with momentum SGD until 5 sweeps (epochs) over the data have completed.
    /// </summary>
    /// <param name="device">CPU or GPU device to train and run the model</param>
    public static void Train(DeviceDescriptor device)
    {
        const int inputDim = 2000;
        const int cellDim = 25;
        const int hiddenDim = 25;
        const int embeddingDim = 50;
        const int numOutputClasses = 5;

        // The sparsity of each input variable must agree with the storage format of the
        // stream that feeds it, so one flag drives both the variable and the stream config.
        const bool featuresAreSparse = true;
        const bool labelsAreSparse = false;

        // build the model
        var featuresName = "features";
        var features = Variable.InputVariable(new int[] { inputDim }, DataType.Float, featuresName, null, featuresAreSparse);
        var labelsName = "labels";
        // Labels only have the batch axis (no sequence axis); the classifier output is taken
        // from the last step of each sequence, so there is one label per sequence.
        var labels = Variable.InputVariable(new int[] { numOutputClasses }, DataType.Float, labelsName,
            new List<Axis>() { Axis.DefaultBatchAxis() }, labelsAreSparse);
        var classifierOutput = LSTMSequenceClassifierNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, "classifierOutput");
        Function trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labels, "lossFunction");
        Function prediction = CNTKLib.ClassificationError(classifierOutput, labels, "classificationError");

        // prepare training data: stream "x" feeds the features, stream "y" feeds the labels
        IList<StreamConfiguration> streamConfigurations = new StreamConfiguration[]
        {
            new StreamConfiguration(featuresName, inputDim, featuresAreSparse, "x"),
            new StreamConfiguration(labelsName, numOutputClasses, labelsAreSparse, "y")
        };
        var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
            Path.Combine(DataFolder, "Train.ctf"), streamConfigurations,
            MinibatchSource.InfinitelyRepeat, true);
        var featureStreamInfo = minibatchSource.StreamInfo(featuresName);
        var labelStreamInfo = minibatchSource.StreamInfo(labelsName);

        // prepare for training
        TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.0005, 1);
        TrainingParameterScheduleDouble momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256);
        IList<Learner> parameterLearners = new List<Learner>()
        {
            Learner.MomentumSGDLearner(classifierOutput.Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true)
        };
        var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);

        // train the model
        uint minibatchSize = 200;
        int outputFrequencyInMinibatches = 20;
        int miniBatchCount = 0;
        int numEpochs = 5;
        while (numEpochs > 0)
        {
            var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);
            var arguments = new Dictionary<Variable, MinibatchData>
            {
                { features, minibatchData[featureStreamInfo] },
                { labels, minibatchData[labelStreamInfo] }
            };
            trainer.TrainMinibatch(arguments, device);
            TestHelper.PrintTrainingProgress(trainer, miniBatchCount++, outputFrequencyInMinibatches);

            // Because minibatchSource is created with MinibatchSource.InfinitelyRepeat,
            // batching will not end. Each time minibatchSource completes a sweep (epoch),
            // the last minibatch data will be marked as end of a sweep. We use this flag
            // to count number of epochs.
            if (TestHelper.MiniBatchDataIsSweepEnd(minibatchData.Values))
            {
                numEpochs--;
            }
        }
    }

    /// <summary>
    /// Self-stabilizer: scales <paramref name="x"/> by a learnable scalar
    /// beta = (1/f) * log(1 + exp(f * p)) with f = 4. The parameter p is initialized to
    /// 1/f * ln(e^f - 1), which makes beta exactly 1 at the start of training.
    /// </summary>
    /// <typeparam name="TElement"><c>float</c> selects single-precision constants; any other type selects double.</typeparam>
    /// <param name="x">the value to stabilize</param>
    /// <param name="device">device on which the constants and the parameter are created</param>
    /// <returns>beta * x (element-wise)</returns>
    static Function Stabilize<TElement>(Variable x, DeviceDescriptor device)
    {
        bool isFloatType = typeof(TElement).Equals(typeof(float));
        // The literal type picks the constant's DataType, which the other scalars then reuse.
        Constant f = isFloatType ? Constant.Scalar(4.0f, device) : Constant.Scalar(4.0, device);
        Constant fInv = Constant.Scalar(f.DataType, 1.0 / 4.0);

        var beta = CNTKLib.ElementTimes(
            fInv,
            CNTKLib.Log(
                Constant.Scalar(f.DataType, 1.0) +
                CNTKLib.Exp(CNTKLib.ElementTimes(f, new Parameter(new NDShape(), f.DataType, 0.99537863 /* 1/f*ln (e^f-1) */, device)))));
        return CNTKLib.ElementTimes(beta, x);
    }

    /// <summary>
    /// Builds one step of an LSTM cell with peephole connections, self-stabilized recurrent
    /// inputs and an optional output projection (used when the output and cell dimensions differ).
    /// </summary>
    /// <typeparam name="TElement"><c>float</c> builds single-precision parameters; any other type builds double.</typeparam>
    /// <param name="input">input at the current step</param>
    /// <param name="prevOutput">output h of the previous step (its dimension is the output dimension)</param>
    /// <param name="prevCellState">cell state c of the previous step (its dimension is the cell dimension)</param>
    /// <param name="device">device on which the parameters are created</param>
    /// <returns>(h, c): the new output and the new cell state</returns>
    static Tuple<Function, Function> LSTMPCellWithSelfStabilization<TElement>(
        Variable input, Variable prevOutput, Variable prevCellState, DeviceDescriptor device)
    {
        int outputDim = prevOutput.Shape[0];
        int cellDim = prevCellState.Shape[0];

        bool isFloatType = typeof(TElement).Equals(typeof(float));
        DataType dataType = isFloatType ? DataType.Float : DataType.Double;

        Func<int, Parameter> createBiasParam;
        if (isFloatType)
            createBiasParam = (dim) => new Parameter(new int[] { dim }, 0.01f, device, "");
        else
            createBiasParam = (dim) => new Parameter(new int[] { dim }, 0.01, device, "");

        // Every call below creates a NEW parameter with its own seed. The seed increments per
        // call, so the order in which parameters are created determines their initial values.
        uint seed2 = 1;
        Func<int, Parameter> createProjectionParam = (oDim) => new Parameter(new int[] { oDim, NDShape.InferredDimension },
            dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2++), device);

        Func<int, Parameter> createDiagWeightParam = (dim) =>
            new Parameter(new int[] { dim }, dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2++), device);

        Function stabilizedPrevOutput = Stabilize<TElement>(prevOutput, device);
        Function stabilizedPrevCellState = Stabilize<TElement>(prevCellState, device);

        // b + W * input, with a fresh bias and weight matrix per call.
        Func<Variable> projectInput = () =>
            createBiasParam(cellDim) + (createProjectionParam(cellDim) * input);

        // (b + W * input) + (U * stabilized h_prev): the pre-activation shared by every gate,
        // again with fresh parameters per call.
        Func<Variable> gatePreActivation = () =>
            projectInput() + (createProjectionParam(cellDim) * stabilizedPrevOutput);

        // Input gate
        Function it = CNTKLib.Sigmoid(
            gatePreActivation() +
            CNTKLib.ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
        Function bit = CNTKLib.ElementTimes(it, CNTKLib.Tanh(gatePreActivation()));

        // Forget-me-not gate
        Function ft = CNTKLib.Sigmoid(
            gatePreActivation() +
            CNTKLib.ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
        Function bft = CNTKLib.ElementTimes(ft, prevCellState);

        Function ct = (Variable)bft + bit;

        // Output gate (its peephole looks at the NEW cell state)
        Function ot = CNTKLib.Sigmoid(
            gatePreActivation() +
            CNTKLib.ElementTimes(createDiagWeightParam(cellDim), Stabilize<TElement>(ct, device)));
        Function ht = CNTKLib.ElementTimes(ot, CNTKLib.Tanh(ct));

        Function c = ct;
        // Project h down/up to the output dimension only when it differs from the cell dimension.
        Function h = (outputDim != cellDim) ? (createProjectionParam(outputDim) * Stabilize<TElement>(ht, device)) : ht;

        return new Tuple<Function, Function>(h, c);
    }

    /// <summary>
    /// Builds a recurrent LSTM layer. The cell is created from placeholders for the previous
    /// output/cell state. The placeholders are then bound to the recurrence hooks applied to
    /// the cell's own outputs, which closes the recurrence loop.
    /// </summary>
    /// <typeparam name="TElement">element type forwarded to the cell builder</typeparam>
    /// <param name="input">input sequence</param>
    /// <param name="outputShape">shape of the output h</param>
    /// <param name="cellShape">shape of the cell state c</param>
    /// <param name="recurrenceHookH">maps the cell output h to the value fed back as the previous output</param>
    /// <param name="recurrenceHookC">maps the cell state c to the value fed back as the previous cell state</param>
    /// <param name="device">device on which the parameters are created</param>
    /// <returns>(h, c) of the recurrent layer</returns>
    static Tuple<Function, Function> LSTMPComponentWithSelfStabilization<TElement>(Variable input,
        NDShape outputShape, NDShape cellShape,
        Func<Variable, Function> recurrenceHookH,
        Func<Variable, Function> recurrenceHookC,
        DeviceDescriptor device)
    {
        var dh = Variable.PlaceholderVariable(outputShape, input.DynamicAxes);
        var dc = Variable.PlaceholderVariable(cellShape, input.DynamicAxes);

        var LSTMCell = LSTMPCellWithSelfStabilization<TElement>(input, dh, dc, device);
        var actualDh = recurrenceHookH(LSTMCell.Item1);
        var actualDc = recurrenceHookC(LSTMCell.Item2);

        // Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
        (LSTMCell.Item1).ReplacePlaceholders(new Dictionary<Variable, Variable> { { dh, actualDh }, { dc, actualDc } });

        return LSTMCell;
    }

    /// <summary>
    /// Embedding layer: multiplies the input by a learnable (embeddingDim x inputDim) float matrix.
    /// </summary>
    /// <param name="input">rank-1 input (asserted in debug builds)</param>
    /// <param name="embeddingDim">dimension of the embedding</param>
    /// <param name="device">device on which the embedding matrix is created</param>
    /// <returns>the embedded input</returns>
    private static Function Embedding(Variable input, int embeddingDim, DeviceDescriptor device)
    {
        System.Diagnostics.Debug.Assert(input.Shape.Rank == 1);
        int inputDim = input.Shape[0];
        var embeddingParameters = new Parameter(new int[] { embeddingDim, inputDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device);
        return CNTKLib.Times(embeddingParameters, input);
    }

    /// <summary>
    /// Build a one direction recurrent neural network (RNN) with long-short-term-memory (LSTM) cells.
    /// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
    /// </summary>
    /// <param name="input">the input variable</param>
    /// <param name="numOutputClasses">number of output classes</param>
    /// <param name="embeddingDim">dimension of the embedding layer</param>
    /// <param name="LSTMDim">LSTM output dimension</param>
    /// <param name="cellDim">cell dimension</param>
    /// <param name="device">CPU or GPU device to run the model</param>
    /// <param name="outputName">name of the model output</param>
    /// <returns>the RNN model</returns>
    static Function LSTMSequenceClassifierNet(Variable input, int numOutputClasses, int embeddingDim, int LSTMDim, int cellDim, DeviceDescriptor device,
        string outputName)
    {
        Function embeddingFunction = Embedding(input, embeddingDim, device);
        Func<Variable, Function> pastValueRecurrenceHook = (x) => CNTKLib.PastValue(x);
        // The embedding is float, so the LSTM layer is built in float as well.
        Function LSTMFunction = LSTMPComponentWithSelfStabilization<float>(
            embeddingFunction,
            new int[] { LSTMDim },
            new int[] { cellDim },
            pastValueRecurrenceHook,
            pastValueRecurrenceHook,
            device).Item1;
        // Only the output at the last step of each sequence is classified.
        Function thoughtVectorFunction = CNTKLib.SequenceLast(LSTMFunction);

        return TestHelper.FullyConnectedLinearLayer(thoughtVectorFunction, numOutputClasses, device, outputName);
    }
}
}