using System;
using System.Collections.Generic;
using System.IO;

namespace CNTK.CSTrainingExamples
{
    /// <summary>
    /// This class shows how to build a recurrent neural network model from ground up and train the model.
    /// </summary>
    public class LSTMSequenceClassifier
    {
        /// <summary>
        /// Execution folder is: CNTK/x64/BuildFolder
        /// Data folder is: CNTK/Tests/EndToEndTests/Text/SequenceClassification/Data
        /// </summary>
        public static string DataFolder = TestCommon.TestDataDirPrefix + "Tests/EndToEndTests/Text/SequenceClassification/Data";

        /// <summary>
        /// Build and train a RNN model.
        /// </summary>
        /// <param name="device">CPU or GPU device to train and run the model</param>
        public static void Train(DeviceDescriptor device)
        {
            const int inputDim = 2000;
            const int cellDim = 25;
            const int hiddenDim = 25;
            const int embeddingDim = 50;
            const int numOutputClasses = 5;

            // build the model
            var featuresName = "features";
            var features = Variable.InputVariable(new int[] { inputDim }, DataType.Float, featuresName, null, true /*isSparse*/);
            var labelsName = "labels";
            var labels = Variable.InputVariable(new int[] { numOutputClasses }, DataType.Float, labelsName,
                new List<Axis>() { Axis.DefaultBatchAxis() }, true);

            var classifierOutput = LSTMSequenceClassifierNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, "classifierOutput");

            Function trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labels, "lossFunction");
            Function prediction = CNTKLib.ClassificationError(classifierOutput, labels, "classificationError");

            // prepare training data
            IList<StreamConfiguration> streamConfigurations = new StreamConfiguration[]
            {
                new StreamConfiguration(featuresName, inputDim, true, "x"),
                new StreamConfiguration(labelsName, numOutputClasses, false, "y")
            };

            var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
                Path.Combine(DataFolder, "Train.ctf"), streamConfigurations,
                MinibatchSource.InfinitelyRepeat, true);
            var featureStreamInfo = minibatchSource.StreamInfo(featuresName);
            var labelStreamInfo = minibatchSource.StreamInfo(labelsName);

            // prepare for training
            TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.0005, 1);
            TrainingParameterScheduleDouble momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256);
            IList<Learner> parameterLearners = new List<Learner>()
            {
                Learner.MomentumSGDLearner(classifierOutput.Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */true)
            };
            var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);

            // train the model
            uint minibatchSize = 200;
            int outputFrequencyInMinibatches = 20;
            int miniBatchCount = 0;
            int numEpochs = 5;
            while (numEpochs > 0)
            {
                var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

                var arguments = new Dictionary<Variable, MinibatchData>
                {
                    { features, minibatchData[featureStreamInfo] },
                    { labels, minibatchData[labelStreamInfo] }
                };

                trainer.TrainMinibatch(arguments, device);
                TestHelper.PrintTrainingProgress(trainer, miniBatchCount++, outputFrequencyInMinibatches);

                // Because minibatchSource is created with MinibatchSource.InfinitelyRepeat,
                // batching will not end. Each time minibatchSource completes an sweep (epoch),
                // the last minibatch data will be marked as end of a sweep. We use this flag
                // to count number of epochs.
                if (TestHelper.MiniBatchDataIsSweepEnd(minibatchData.Values))
                {
                    numEpochs--;
                }
            }
        }

        /// <summary>
        /// Applies a learned scalar "self-stabilizer" scale to <paramref name="x"/>.
        /// beta = (1/f) * log(1 + exp(f * p)) is a softplus reparameterization of a positive
        /// scale, with p initialized at 0.99537863 (= 1/f * ln(e^f - 1)) so that beta starts at 1.
        /// </summary>
        /// <typeparam name="ElementType">float or double; selects the element data type of the constants</typeparam>
        /// <param name="x">the variable to scale</param>
        /// <param name="device">CPU or GPU device to place the parameters on</param>
        /// <returns>beta * x</returns>
        static Function Stabilize<ElementType>(Variable x, DeviceDescriptor device)
        {
            bool isFloatType = typeof(ElementType).Equals(typeof(float));
            Constant f, fInv;
            if (isFloatType)
            {
                f = Constant.Scalar(4.0f, device);
                fInv = Constant.Scalar(f.DataType, 1.0 / 4.0f);
            }
            else
            {
                f = Constant.Scalar(4.0, device);
                fInv = Constant.Scalar(f.DataType, 1.0 / 4.0f);
            }

            var beta = CNTKLib.ElementTimes(
                fInv,
                CNTKLib.Log(
                    Constant.Scalar(f.DataType, 1.0) +
                    CNTKLib.Exp(CNTKLib.ElementTimes(f, new Parameter(new NDShape(), f.DataType, 0.99537863 /* 1/f*ln (e^f-1) */, device)))));
            return CNTKLib.ElementTimes(beta, x);
        }

        /// <summary>
        /// Builds one step of an LSTM cell with peephole connections and self-stabilized
        /// recurrent inputs (LSTMP). Returns the projected output h and the cell state c.
        /// </summary>
        /// <typeparam name="ElementType">float or double; selects the parameter data type</typeparam>
        /// <param name="input">input to the cell at this step</param>
        /// <param name="prevOutput">placeholder/variable for the previous step's output h(t-1)</param>
        /// <param name="prevCellState">placeholder/variable for the previous cell state c(t-1)</param>
        /// <param name="device">CPU or GPU device to place the parameters on</param>
        /// <returns>tuple of (output h, cell state c)</returns>
        static Tuple<Function, Function> LSTMPCellWithSelfStabilization<ElementType>(
            Variable input, Variable prevOutput, Variable prevCellState, DeviceDescriptor device)
        {
            int outputDim = prevOutput.Shape[0];
            int cellDim = prevCellState.Shape[0];

            bool isFloatType = typeof(ElementType).Equals(typeof(float));
            DataType dataType = isFloatType ? DataType.Float : DataType.Double;

            Func<int, Parameter> createBiasParam;
            if (isFloatType)
                createBiasParam = (dim) => new Parameter(new int[] { dim }, 0.01f, device, "");
            else
                createBiasParam = (dim) => new Parameter(new int[] { dim }, 0.01, device, "");

            // Each call creates a fresh parameter; seed2 advances so every projection
            // gets a distinct Glorot-uniform initialization.
            uint seed2 = 1;
            Func<int, Parameter> createProjectionParam = (oDim) => new Parameter(new int[] { oDim, NDShape.InferredDimension },
                dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2++), device);

            Func<int, Parameter> createDiagWeightParam = (dim) =>
                new Parameter(new int[] { dim }, dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed2++), device);

            Function stabilizedPrevOutput = Stabilize<ElementType>(prevOutput, device);
            Function stabilizedPrevCellState = Stabilize<ElementType>(prevCellState, device);

            // New bias + projection of the cell input per gate (fresh parameters on each call).
            Func<Variable> projectInput = () =>
                createBiasParam(cellDim) + (createProjectionParam(cellDim) * input);

            // Input gate
            Function it =
                CNTKLib.Sigmoid(
                    (Variable)(projectInput() + (createProjectionParam(cellDim) * stabilizedPrevOutput)) +
                    CNTKLib.ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
            Function bit = CNTKLib.ElementTimes(
                it,
                CNTKLib.Tanh(projectInput() + (createProjectionParam(cellDim) * stabilizedPrevOutput)));

            // Forget-me-not gate
            Function ft = CNTKLib.Sigmoid(
                (Variable)(
                    projectInput() + (createProjectionParam(cellDim) * stabilizedPrevOutput)) +
                    CNTKLib.ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
            Function bft = CNTKLib.ElementTimes(ft, prevCellState);

            Function ct = (Variable)bft + bit;

            // Output gate
            Function ot = CNTKLib.Sigmoid(
                (Variable)(projectInput() + (createProjectionParam(cellDim) * stabilizedPrevOutput)) +
                CNTKLib.ElementTimes(createDiagWeightParam(cellDim), Stabilize<ElementType>(ct, device)));
            Function ht = CNTKLib.ElementTimes(ot, CNTKLib.Tanh(ct));

            Function c = ct;
            // Project the hidden output only when the output and cell dimensions differ.
            Function h = (outputDim != cellDim) ? (createProjectionParam(outputDim) * Stabilize<ElementType>(ht, device)) : ht;

            return new Tuple<Function, Function>(h, c);
        }

        /// <summary>
        /// Wraps the LSTMP cell into a recurrence: placeholders for h(t-1)/c(t-1) are created,
        /// the cell is built, and the placeholders are replaced by the supplied recurrence hooks
        /// (typically PastValue) to close the loop.
        /// </summary>
        /// <typeparam name="ElementType">float or double; selects the parameter data type</typeparam>
        /// <param name="input">input sequence variable</param>
        /// <param name="outputShape">shape of the LSTM output h</param>
        /// <param name="cellShape">shape of the LSTM cell state c</param>
        /// <param name="recurrenceHookH">hook producing h(t-1) from h (e.g. PastValue)</param>
        /// <param name="recurrenceHookC">hook producing c(t-1) from c (e.g. PastValue)</param>
        /// <param name="device">CPU or GPU device to place the parameters on</param>
        /// <returns>tuple of (output h, cell state c) with the recurrence loop formed</returns>
        static Tuple<Function, Function> LSTMPComponentWithSelfStabilization<ElementType>(Variable input,
            NDShape outputShape, NDShape cellShape,
            Func<Variable, Function> recurrenceHookH,
            Func<Variable, Function> recurrenceHookC,
            DeviceDescriptor device)
        {
            var dh = Variable.PlaceholderVariable(outputShape, input.DynamicAxes);
            var dc = Variable.PlaceholderVariable(cellShape, input.DynamicAxes);

            var LSTMCell = LSTMPCellWithSelfStabilization<ElementType>(input, dh, dc, device);
            var actualDh = recurrenceHookH(LSTMCell.Item1);
            var actualDc = recurrenceHookC(LSTMCell.Item2);

            // Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
            (LSTMCell.Item1).ReplacePlaceholders(new Dictionary<Variable, Variable> { { dh, actualDh }, { dc, actualDc } });

            return new Tuple<Function, Function>(LSTMCell.Item1, LSTMCell.Item2);
        }

        /// <summary>
        /// Dense embedding layer: multiplies a rank-1 (sparse one-hot) input by a learned
        /// [embeddingDim x inputDim] parameter matrix.
        /// </summary>
        /// <param name="input">rank-1 input variable</param>
        /// <param name="embeddingDim">dimension of the embedding</param>
        /// <param name="device">CPU or GPU device to place the parameters on</param>
        /// <returns>the embedded input</returns>
        private static Function Embedding(Variable input, int embeddingDim, DeviceDescriptor device)
        {
            System.Diagnostics.Debug.Assert(input.Shape.Rank == 1);
            int inputDim = input.Shape[0];
            var embeddingParameters = new Parameter(new int[] { embeddingDim, inputDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device);
            return CNTKLib.Times(embeddingParameters, input);
        }

        /// <summary>
        /// Build a one direction recurrent neural network (RNN) with long-short-term-memory (LSTM) cells.
        /// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
        /// </summary>
        /// <param name="input">the input variable</param>
        /// <param name="numOutputClasses">number of output classes</param>
        /// <param name="embeddingDim">dimension of the embedding layer</param>
        /// <param name="LSTMDim">LSTM output dimension</param>
        /// <param name="cellDim">cell dimension</param>
        /// <param name="device">CPU or GPU device to run the model</param>
        /// <param name="outputName">name of the model output</param>
        /// <returns>the RNN model</returns>
        static Function LSTMSequenceClassifierNet(Variable input, int numOutputClasses, int embeddingDim, int LSTMDim, int cellDim, DeviceDescriptor device, string outputName)
        {
            Function embeddingFunction = Embedding(input, embeddingDim, device);
            Func<Variable, Function> pastValueRecurrenceHook = (x) => CNTKLib.PastValue(x);
            Function LSTMFunction = LSTMPComponentWithSelfStabilization<float>(
                embeddingFunction,
                new int[] { LSTMDim },
                new int[] { cellDim },
                pastValueRecurrenceHook,
                pastValueRecurrenceHook,
                device).Item1;
            // Keep only the last step of the sequence as the "thought vector" for classification.
            Function thoughtVectorFunction = CNTKLib.SequenceLast(LSTMFunction);

            return TestHelper.FullyConnectedLinearLayer(thoughtVectorFunction, numOutputClasses, device, outputName);
        }
    }
}