From 81f9ac56ee2278be8fa06a2635b53c77821a8dc1 Mon Sep 17 00:00:00 2001 From: kaisheny Date: Sat, 27 Jun 2015 08:45:00 -0700 Subject: [PATCH] local changes for alignment node for attention-based mechanism --- CNTK.sln | 1 + Common/Include/DataReader.h | 2 +- .../LMSequenceReader/SequenceReader.cpp | 78 ++- DataReader/LMSequenceReader/SequenceReader.h | 1 - MachineLearning/CNTK/ComputationNetwork.h | 45 +- MachineLearning/CNTK/ComputationNode.h | 2 +- MachineLearning/CNTK/IComputationNetBuilder.h | 4 +- MachineLearning/CNTK/InputAndParamNodes.h | 153 ++++++ MachineLearning/CNTK/LinearAlgebraNodes.h | 6 +- MachineLearning/CNTK/MultiNetworksSGD.h | 492 ++++++++++++++++-- MachineLearning/CNTK/NDLNetworkBuilder.h | 6 +- .../CNTK/NetworkDescriptionLanguage.cpp | 4 + MachineLearning/CNTK/RecurrentNodes.h | 270 +++++++++- MachineLearning/CNTK/SimpleEvaluator.h | 196 +++++++ MachineLearning/CNTK/SimpleNetworkBuilder.cpp | 119 +++++ MachineLearning/CNTK/SimpleNetworkBuilder.h | 17 +- MachineLearning/CNTK/TrainingCriterionNodes.h | 2 +- 17 files changed, 1303 insertions(+), 95 deletions(-) diff --git a/CNTK.sln b/CNTK.sln index f252eb01e..ac765de21 100644 --- a/CNTK.sln +++ b/CNTK.sln @@ -241,6 +241,7 @@ Global {014DA766-B37B-4581-BC26-963EA5507931} = {33EBFE78-A1A8-4961-8938-92A271941F94} {D667AF32-028A-4A5D-BE19-F46776F0F6B2} = {33EBFE78-A1A8-4961-8938-92A271941F94} {3ED0465D-23E7-4855-9694-F788717B6533} = {39E42C4B-A078-4CA4-9D92-B883D8129601} + {065AF55D-AF02-448B-BFCD-52619FDA4BD0} = {39E42C4B-A078-4CA4-9D92-B883D8129601} {98D2C32B-0C1F-4E19-A626-65F7BA4600CF} = {065AF55D-AF02-448B-BFCD-52619FDA4BD0} {EA67F51F-1FE8-462D-9F3E-01161685AD59} = {065AF55D-AF02-448B-BFCD-52619FDA4BD0} {DE1A06BA-EC5C-4E0D-BCA8-3EA555310C58} = {065AF55D-AF02-448B-BFCD-52619FDA4BD0} diff --git a/Common/Include/DataReader.h b/Common/Include/DataReader.h index 8ec191bf4..a92817be7 100644 --- a/Common/Include/DataReader.h +++ b/Common/Include/DataReader.h @@ -78,7 +78,7 @@ public: virtual bool GetData(const std::wstring&, size_t, void*, size_t&, size_t) { NOT_IMPLEMENTED; }; virtual bool DataEnd(EndDataType) { NOT_IMPLEMENTED; }; virtual void SetSentenceSegBatch(Matrix&, Matrix&) { NOT_IMPLEMENTED; }; - virtual void SetRandomSeed(int) { NOT_IMPLEMENTED; }; + virtual void SetRandomSeed(unsigned seed = 0) { m_seed = seed; }; virtual bool GetProposalObs(std::map*>&, const size_t, vector&) { return false; } virtual void InitProposals(std::map*>&) { } virtual bool CanReadFor(wstring /* nodeName */) { diff --git a/DataReader/LMSequenceReader/SequenceReader.cpp b/DataReader/LMSequenceReader/SequenceReader.cpp index 13ca51f3f..40c2cfa4f 100644 --- a/DataReader/LMSequenceReader/SequenceReader.cpp +++ b/DataReader/LMSequenceReader/SequenceReader.cpp @@ -543,8 +543,8 @@ void SequenceReader::Init(const ConfigParameters& readerConfig) std::wstring m_file = readerConfig("file"); if (m_traceLevel > 0) { - //fprintf(stderr, "reading sequence file %ls\n", m_file.c_str()); - std::wcerr << "reading sequence file" << m_file.c_str() << endl; + fprintf(stderr, "reading sequence file %ls\n", m_file.c_str()); + //std::wcerr << "reading sequence file" << m_file.c_str() << endl; } const LabelInfo& labelIn = m_labelInfo[labelInfoIn]; @@ -1503,8 +1503,8 @@ void BatchSequenceReader::Init(const ConfigParameters& readerConfig) std::wstring m_file = readerConfig("file"); if (m_traceLevel > 0) { - //fwprintf(stderr, L"reading sequence file %s\n", m_file.c_str()); - std::wcerr << "reading sequence file " << m_file.c_str() << endl; + fwprintf(stderr, 
L"reading sequence file %s\n", m_file.c_str()); + //std::wcerr << "reading sequence file " << m_file.c_str() << endl; } const LabelInfo& labelIn = m_labelInfo[labelInfoIn]; @@ -1986,8 +1986,8 @@ bool BatchSequenceReader::DataEnd(EndDataType endDataType) /// notice that indices are defined as follows [begining ending_indx) of the class /// i.e., the ending_index is 1 plus of the true ending index template -void BatchSequenceReader::GetLabelOutput(std::map*>& matrices, +void BatchSequenceReader::GetLabelOutput(std::map < std::wstring, + Matrix* > & matrices, size_t m_mbStartSample, size_t actualmbsize) { size_t j = 0; @@ -2007,51 +2007,47 @@ void BatchSequenceReader::GetLabelOutput(std::mapTransferFromDeviceToDevice(curDevId, CPUDEVICE, true, false, false); if (labels->GetCurrentMatrixLocation() == CPU) - for (size_t jSample = m_mbStartSample; j < actualmbsize; ++j, ++jSample) - { - // pick the right sample with randomization if desired - size_t jRand = jSample; - int wrd = m_labelIdData[jRand]; - labels->SetValue(0, j, (ElemType)wrd); - SetSentenceEnd(wrd, j, actualmbsize); - - if (readerMode == ReaderMode::NCE) + for (size_t jSample = m_mbStartSample; j < actualmbsize; ++j, ++jSample) { - labels->SetValue(1, j, (ElemType)m.logprob(wrd)); - for (size_t noiseid = 0; noiseid < this->noise_sample_size; noiseid++) + // pick the right sample with randomization if desired + size_t jRand = jSample; + int wrd = m_labelIdData[jRand]; + labels->SetValue(0, j, (ElemType)wrd); + SetSentenceEnd(wrd, j, actualmbsize); + + if (readerMode == ReaderMode::NCE) { - int wid = m.sample(); - labels->SetValue(2 * (noiseid + 1), j, (ElemType)wid); - labels->SetValue(2 * (noiseid + 1) + 1, j, -(ElemType)m.logprob(wid)); - } - } - else if (readerMode == ReaderMode::Class) - { - int clsidx = idx4class[wrd]; - if (class_size > 0){ - - labels->SetValue(1, j, (ElemType)clsidx); - - /// save the [begining ending_indx) of the class - size_t lft = (size_t) (*m_classInfoLocal)(0, clsidx); - size_t rgt = (size_t) (*m_classInfoLocal)(1, clsidx); - if (wrd < lft || lft > rgt || wrd >= rgt) + labels->SetValue(1, j, (ElemType)m.logprob(wrd)); + for (size_t noiseid = 0; noiseid < this->noise_sample_size; noiseid++) { - LogicError("LMSequenceReader::GetLabelOutput word %d should be at least equal to or larger than its class's left index %d; right index %d of its class should be larger or equal to left index %d of its class; word index %d should be smaller than its class's right index %d.\n", wrd, lft, rgt, lft, wrd, rgt); + int wid = m.sample(); + labels->SetValue(2 * (noiseid + 1), j, (ElemType)wid); + labels->SetValue(2 * (noiseid + 1) + 1, j, -(ElemType)m.logprob(wid)); + } + } + else if (readerMode == ReaderMode::Class) + { + int clsidx = idx4class[wrd]; + if (class_size > 0){ + + labels->SetValue(1, j, (ElemType)clsidx); + + /// save the [begining ending_indx) of the class + size_t lft = (size_t)(*m_classInfoLocal)(0, clsidx); + size_t rgt = (size_t)(*m_classInfoLocal)(1, clsidx); + if (wrd < lft || lft > rgt || wrd >= rgt) + { + LogicError("LMSequenceReader::GetLabelOutput word %d should be at least equal to or larger than its class's left index %d; right index %d of its class should be larger or equal to left index %d of its class; word index %d should be smaller than its class's right index %d.\n", wrd, lft, rgt, lft, wrd, rgt); + } + labels->SetValue(2, j, (*m_classInfoLocal)(0, clsidx)); /// begining index of the class + labels->SetValue(3, j, (*m_classInfoLocal)(1, clsidx)); /// end index of the class } - 
labels->SetValue(2, j, (*m_classInfoLocal)(0, clsidx)); /// begining index of the class - labels->SetValue(3, j, (*m_classInfoLocal)(1, clsidx)); /// end index of the class } } - } else // GPU { RuntimeError("GetLabelOutput::should use CPU for labels "); } - if (curDevId != CPUDEVICE) - { - labels->TransferFromDeviceToDevice(CPUDEVICE, curDevId, true, false, false); - } } template diff --git a/DataReader/LMSequenceReader/SequenceReader.h b/DataReader/LMSequenceReader/SequenceReader.h index 798ebd55d..8578d3b7c 100644 --- a/DataReader/LMSequenceReader/SequenceReader.h +++ b/DataReader/LMSequenceReader/SequenceReader.h @@ -252,7 +252,6 @@ public: virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0); virtual bool DataEnd(EndDataType endDataType); - void SetRandomSeed(int) { NOT_IMPLEMENTED; } }; template diff --git a/MachineLearning/CNTK/ComputationNetwork.h b/MachineLearning/CNTK/ComputationNetwork.h index 06137da63..cb6b5d204 100644 --- a/MachineLearning/CNTK/ComputationNetwork.h +++ b/MachineLearning/CNTK/ComputationNetwork.h @@ -524,7 +524,7 @@ public: } virtual void LoadFromFile(const std::wstring& fileName, const FileOptions fileFormat = FileOptions::fileOptionsBinary, - const bool bAllowNoCriterionNode = false) + const bool bAllowNoCriterionNode = false, ComputationNetwork* anotherNetwork=nullptr) { ClearNet(); @@ -574,7 +574,7 @@ public: std::vector childrenNodes; childrenNodes.resize(numChildren); for (int j = 0; j < numChildren; j++) - childrenNodes[j] = GetNodeFromName(childrenNames[j]); + childrenNodes[j] = GetNodeFromName(childrenNames[j], anotherNetwork); if (nodePtr->OperationName() == RowStackNode::TypeName()) //allow for variable input nodes nodePtr->AttachInputs(childrenNodes); @@ -1074,6 +1074,10 @@ public: newNode = new TimeReverseNode(fstream, modelVersion, m_deviceId, nodeName); else if (nodeType == ParallelNode::TypeName()) newNode = new ParallelNode(fstream, modelVersion, m_deviceId, nodeName); + else if (nodeType == AlignmentNode::TypeName()) + newNode = new AlignmentNode(fstream, modelVersion, m_deviceId, nodeName); + else if (nodeType == PairNetworkNode::TypeName()) + newNode = new PairNetworkNode(fstream, modelVersion, m_deviceId, nodeName); else { fprintf(stderr, "Error creating new ComputationNode of type %ls, with name %ls\n", nodeType.c_str(), nodeName.c_str()); @@ -1106,6 +1110,14 @@ public: return newNode; } + ComputationNodePtr PairNetwork(const ComputationNodePtr & a, const std::wstring nodeName = L"") + { + ComputationNodePtr newNode(new PairNetworkNode(m_deviceId, nodeName)); + newNode->AttachInputs(a); + AddNodeToNet(newNode); + return newNode; + } + ComputationNodePtr CreateSparseInputNode(const std::wstring inputName, const size_t rows, const size_t cols) { ComputationNodePtr newNode(new SparseInputValue(rows, cols, m_deviceId, inputName)); @@ -1128,7 +1140,14 @@ public: return newNode; } - ComputationNodePtr CreateConvolutionNode(const std::wstring nodeName, + ComputationNodePtr CreatePairNetworkNode(const std::wstring inputName, const size_t rows, const size_t cols) + { + ComputationNodePtr newNode(new PairNetworkNode(rows, cols, m_deviceId, inputName)); + AddNodeToNet(newNode); + return newNode; + } + + ComputationNodePtr CreateConvolutionNode(const std::wstring nodeName, const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels, const size_t horizontalSubsample, const size_t verticalSubsample, const bool zeroPadding = false, const size_t 
maxTempMemSizeInSamples = 0) @@ -1247,6 +1266,10 @@ public: newNode = new ParallelNode(m_deviceId, nodeName); else if (nodeType == RowStackNode::TypeName()) newNode = new RowStackNode(m_deviceId, nodeName); + else if (nodeType == AlignmentNode::TypeName()) + newNode = new AlignmentNode(m_deviceId, nodeName); + else if (nodeType == PairNetworkNode::TypeName()) + newNode = new PairNetworkNode(m_deviceId, nodeName); else { fprintf(stderr, "Error creating new ComputationNode of type %ls, with name %ls\n", nodeType.c_str(), nodeName.c_str()); @@ -1653,19 +1676,29 @@ public: return newNode; } + ComputationNodePtr Alignment(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName = L"") + { + ComputationNodePtr newNode(new AlignmentNode(m_deviceId, nodeName)); + newNode->AttachInputs(a, b, c); + AddNodeToNet(newNode); + return newNode; + } + bool NodeNameExist(const std::wstring& name) const { auto iter = m_nameToNodeMap.find(name); return (iter != m_nameToNodeMap.end()); } - ComputationNodePtr GetNodeFromName(const std::wstring& name) const + ComputationNodePtr GetNodeFromName(const std::wstring& name, ComputationNetwork* anotherNetwork = nullptr) const { auto iter = m_nameToNodeMap.find(name); if (iter != m_nameToNodeMap.end()) //found return iter->second; - else //should never try to get a node from nonexisting name - throw std::runtime_error("GetNodeFromName: Node name does not exist."); + if (anotherNetwork != nullptr) + return anotherNetwork->GetNodeFromName(name); + + RuntimeError("GetNodeFromName: Node name %s does not exist.", name.c_str()); } // GetNodesFromName - Get all the nodes from a name that may match a wildcard '*' pattern diff --git a/MachineLearning/CNTK/ComputationNode.h b/MachineLearning/CNTK/ComputationNode.h index 9b7462ab8..8f5a723be 100644 --- a/MachineLearning/CNTK/ComputationNode.h +++ b/MachineLearning/CNTK/ComputationNode.h @@ -867,7 +867,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { return false; } - void EnumerateNodesForEval(std::unordered_set& visited, std::list& result, + virtual void EnumerateNodesForEval(std::unordered_set& visited, std::list& result, std::vector& sourceRecurrentNodePtr, const bool bFromDelayNode) { if (visited.find(this) == visited.end()) //not visited diff --git a/MachineLearning/CNTK/IComputationNetBuilder.h b/MachineLearning/CNTK/IComputationNetBuilder.h index b92549124..da0aaf685 100644 --- a/MachineLearning/CNTK/IComputationNetBuilder.h +++ b/MachineLearning/CNTK/IComputationNetBuilder.h @@ -15,8 +15,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { { public: virtual ComputationNetwork& LoadNetworkFromFile(const std::wstring& modelFileName, bool forceLoad = true, - bool bAllowNoCriterion = false) = 0; - virtual ComputationNetwork& BuildNetworkFromDescription() = 0; + bool bAllowNoCriterion = false, ComputationNetwork* = nullptr) = 0; + virtual ComputationNetwork& BuildNetworkFromDescription(ComputationNetwork* = nullptr) = 0; virtual ~IComputationNetBuilder() {}; }; }}} \ No newline at end of file diff --git a/MachineLearning/CNTK/InputAndParamNodes.h b/MachineLearning/CNTK/InputAndParamNodes.h index 5b9c8ccdc..19104c285 100644 --- a/MachineLearning/CNTK/InputAndParamNodes.h +++ b/MachineLearning/CNTK/InputAndParamNodes.h @@ -586,4 +586,157 @@ namespace Microsoft { namespace MSR { namespace CNTK { template class LookupTableNode; template class LookupTableNode; + /** + pair this node to a node in another network + */ + template + class PairNetworkNode : public 
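The GetNodeFromName change above gives name resolution a fallback into a paired network, which is what lets a decoder reference encoder nodes while loading. A minimal sketch of the same lookup-with-fallback pattern, with std::map standing in for ComputationNetwork (illustrative names only):

#include <map>
#include <stdexcept>
#include <string>

// Illustrative stand-in for a network's name -> node map with the
// cross-network fallback: look locally first, then in the paired
// network, and only then fail.
struct ToyNetwork
{
    std::map<std::wstring, int> nodes;  // node payload reduced to an int

    int GetNodeFromName(const std::wstring& name, const ToyNetwork* other = nullptr) const
    {
        auto iter = nodes.find(name);
        if (iter != nodes.end())
            return iter->second;
        if (other != nullptr)
            return other->GetNodeFromName(name);  // e.g. decoder falls back to encoder
        throw std::runtime_error("GetNodeFromName: node name does not exist.");
    }
};

int main()
{
    ToyNetwork encoder, decoder;
    encoder.nodes[L"encoderOutput"] = 1;
    decoder.nodes[L"decoderInput"] = 2;
    // Resolves via the fallback even though the decoder does not own the node.
    return decoder.GetNodeFromName(L"encoderOutput", &encoder) == 1 ? 0 : 1;
}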
ComputationNode + { + UsingComputationNodeMembers; + public: + PairNetworkNode(const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"") + : ComputationNode(deviceId) + { + m_nodeName = (name == L"" ? CreateUniqNodeName() : name); + m_deviceId = deviceId; + MoveMatricesToDevice(deviceId); + m_reqMultiSeqHandling = true; + m_functionValues.Resize(1, 1); + InitRecurrentNode(); + } + + PairNetworkNode(File& fstream, const size_t modelVersion, const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"") + : ComputationNode(deviceId) + { + m_nodeName = (name == L"" ? CreateUniqNodeName() : name); + + m_functionValues.Resize(1, 1); + m_reqMultiSeqHandling = true; + + LoadFromFile(fstream, modelVersion, deviceId); + } + + PairNetworkNode(const DEVICEID_TYPE deviceId, size_t row_size, size_t col_size, const std::wstring name = L"") : ComputationNode(deviceId) + { + m_nodeName = (name == L"" ? CreateUniqNodeName() : name); + m_deviceId = deviceId; + MoveMatricesToDevice(deviceId); + m_reqMultiSeqHandling = true; + + m_functionValues.Resize(row_size, col_size); + + m_gradientValues.Resize(row_size, col_size); + m_gradientValues.SetValue(0.0f); + + InitRecurrentNode(); + } + + virtual const std::wstring OperationName() const { return TypeName(); } + + /// to-do: need to change to the new way of resetting state + virtual void ComputeInputPartial(const size_t inputIndex) + { + if (inputIndex > 0) + throw std::invalid_argument("PairNetwork operation only takes one input."); + + Matrix::ScaleAndAdd(1.0, GradientValues(), Inputs(inputIndex)->GradientValues()); + } + + virtual void ComputeInputPartial(const size_t inputIndex, const size_t timeIdxInSeq) + { + if (inputIndex > 0) + throw std::invalid_argument("Delay operation only takes one input."); + assert(m_functionValues.GetNumRows() == GradientValues().GetNumRows()); // original used m_functionValues.GetNumRows() for loop dimension + assert(m_sentenceSeg != nullptr); + assert(m_existsSentenceBeginOrNoLabels != nullptr); + + Matrix mTmp = Inputs(inputIndex)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); + Matrix::ScaleAndAdd(1.0, GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep), + mTmp); + } + + virtual void EvaluateThisNode() + { + m_functionValues.SetValue(Inputs(0)->FunctionValues()); + } + + virtual void EvaluateThisNode(const size_t timeIdxInSeq) + { + Matrix mTmp = FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); + mTmp.SetValue(Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep)); + } + + virtual void Validate() + { + PrintSelfBeforeValidation(true); + + if (m_children.size() != 1) + throw std::logic_error("PairNetwork operation should have one input."); + + if (!(Inputs(0) == nullptr)) + { + size_t rows0 = Inputs(0)->FunctionValues().GetNumRows(), cols0 = Inputs(0)->FunctionValues().GetNumCols(); + + if (rows0 > 0 && cols0 > 0) FunctionValues().Resize(rows0, cols0); + } + CopyImageSizeFromInputs(); + } + + virtual void AttachInputs(const ComputationNodePtr inputNode) + { + m_children.resize(1); + m_children[0] = inputNode; + } + + void EnumerateNodesForEval(std::unordered_set& visited, std::list& result, + std::vector& sourceRecurrentNodePtr, const bool bFromDelayNode) + { + if (visited.find(this) == visited.end()) //not visited + { + visited.insert(this); // have visited tagged here to avoid infinite loop 
over children, children's children, etc + + //children first for function evaluation + if (!IsLeaf()) + { + if (ChildrenNeedGradient()) //only nodes that require gradient calculation is included in gradient calculation + m_needGradient = true; + else + m_needGradient = false; + } + + result.push_back(ComputationNodePtr(this)); //we put this in the list even if it's leaf since we need to use it to determine learnable params + this->m_visitedOrder = result.size(); + } + else + { + if (!IsLeaf() && bFromDelayNode) + sourceRecurrentNodePtr.push_back(this); + } + } + + static const std::wstring TypeName() { return L"PairNetwork"; } + + // copy constructor + PairNetworkNode(const PairNetworkNode* node, const std::wstring& newName, const CopyNodeFlags flags) + : ComputationNode(node->m_deviceId) + { + node->CopyTo(this, newName, flags); + } + + virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const + { + const std::wstring& name = (newName == L"") ? NodeName() : newName; + + ComputationNodePtr node = new PairNetworkNode(this, name, flags); + return node; + } + + protected: + virtual bool UseCustomizedMultiSeqHandling() { return true; } + + }; + + template class PairNetworkNode; + template class PairNetworkNode; + }}} diff --git a/MachineLearning/CNTK/LinearAlgebraNodes.h b/MachineLearning/CNTK/LinearAlgebraNodes.h index 15971e375..9ac3a7262 100644 --- a/MachineLearning/CNTK/LinearAlgebraNodes.h +++ b/MachineLearning/CNTK/LinearAlgebraNodes.h @@ -772,7 +772,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { } } - static void WINAPI ComputeInputPartialLeft(Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues) + static void WINAPI ComputeInputPartialLeft(const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues) { #if DUMPOUTPUT gradientValues.Print("Gradient-in"); @@ -2578,8 +2578,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { inputFunctionValues.Print("child Function values"); #endif - if (ones.GetNumRows() != inputGradientValues.GetNumRows() || ones.GetNumCols() != inputGradientValues.GetNumCols()) - ones = Matrix::Ones(inputGradientValues.GetNumRows(), inputGradientValues.GetNumCols(), inputGradientValues.GetDeviceId()); + if (ones.GetNumRows() != inputGradientValues.GetNumRows() || ones.GetNumCols() != inputGradientValues.GetNumRows()) + ones = Matrix::Ones(inputGradientValues.GetNumRows(), inputGradientValues.GetNumRows(), inputGradientValues.GetDeviceId()); Matrix::MultiplyAndAdd(ones, false, gradientValues, true, inputGradientValues); #if DUMPOUTPUT inputGradientValues.Print("child Gradient-out"); diff --git a/MachineLearning/CNTK/MultiNetworksSGD.h b/MachineLearning/CNTK/MultiNetworksSGD.h index 27b0df26f..5be8e0406 100644 --- a/MachineLearning/CNTK/MultiNetworksSGD.h +++ b/MachineLearning/CNTK/MultiNetworksSGD.h @@ -158,7 +158,7 @@ namespace Microsoft { fprintf(stderr, "Starting from checkpoint. Load Decoder Network From File %ws.\n", modelFileName.c_str()); ComputationNetwork& decoderNet = - startEpoch<0 ? decoderNetBuilder->BuildNetworkFromDescription() : decoderNetBuilder->LoadNetworkFromFile(modelFileName); + startEpoch<0 ? 
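For orientation, the PairNetworkNode defined above behaves like an identity node whose input lives in another network: the forward pass copies the paired node's activations, and the backward pass accumulates this node's gradient back onto the source (the Matrix::ScaleAndAdd calls). A toy sketch of that contract on plain buffers, with hypothetical types rather than the CNTK classes:

#include <cassert>
#include <vector>

// Toy pair-node semantics: Evaluate copies the paired node's function
// values; Backprop adds this node's gradient onto the paired node's
// gradient instead of overwriting it.
struct ToyPairNode
{
    std::vector<float> functionValues, gradientValues;

    void Evaluate(const std::vector<float>& sourceValues)
    {
        functionValues = sourceValues;               // value pass-through
        gradientValues.assign(sourceValues.size(), 0.0f);
    }

    void Backprop(std::vector<float>& sourceGradient) const
    {
        assert(sourceGradient.size() == gradientValues.size());
        for (size_t i = 0; i < gradientValues.size(); i++)
            sourceGradient[i] += gradientValues[i];  // accumulate, don't overwrite
    }
};

int main()
{
    std::vector<float> encOut = { 1.0f, 2.0f }, encGrad(2, 0.0f);
    ToyPairNode pair;
    pair.Evaluate(encOut);
    pair.gradientValues = { 0.5f, -0.5f };  // pretend the decoder produced this
    pair.Backprop(encGrad);                 // encGrad is now {0.5f, -0.5f}
    return 0;
}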
decoderNetBuilder->BuildNetworkFromDescription(&encoderNet) : decoderNetBuilder->LoadNetworkFromFile(modelFileName, true, false, &encoderNet); startEpoch = max(startEpoch, 0); @@ -373,6 +373,264 @@ namespace Microsoft { fprintf(stderr, "Finished Epoch[%lu]: Evaluation Node [%ws] Per Sample = %.8g\n", i + 1, evalNodeNames[j].c_str(), epochEvalErrors[j]); } + if (decoderValidationSetDataReader != decoderTrainSetDataReader && decoderValidationSetDataReader != nullptr && + encoderValidationSetDataReader != encoderTrainSetDataReader && encoderValidationSetDataReader != nullptr) + { + SimpleEvaluator evalforvalidation(decoderNet); + vector cvEncoderSetTrainAndEvalNodes; + cvEncoderSetTrainAndEvalNodes.push_back(encoderEvaluationNodes[0]->NodeName()); + + vector cvDecoderSetTrainAndEvalNodes; + cvDecoderSetTrainAndEvalNodes.push_back(decoderCriterionNodes[0]->NodeName()); + cvDecoderSetTrainAndEvalNodes.push_back(decoderEvaluationNodes[0]->NodeName()); + + vector vScore = evalforvalidation.EvaluateEncoderDecoderWithHiddenStates( + encoderNet, decoderNet, + *encoderValidationSetDataReader, + *decoderValidationSetDataReader, cvEncoderSetTrainAndEvalNodes, + cvDecoderSetTrainAndEvalNodes, m_mbSize[i]); + fprintf(stderr, "Finished Epoch[%lu]: [Validation Set] Train Loss Per Sample = %.8g EvalErr Per Sample = %.8g\n", + i + 1, vScore[0], vScore[1]); + + epochCriterion[0] = vScore[0]; //the first one is the decoder training criterion. + } + + bool loadedPrevModel = false; + size_t epochsSinceLastLearnRateAdjust = i % m_learnRateAdjustInterval + 1; + if (avgCriterion == std::numeric_limits::infinity()) + avgCriterion = epochCriterion[0]; + else + avgCriterion = ((epochsSinceLastLearnRateAdjust - 1 - epochsNotCountedInAvgCriterion)* avgCriterion + epochCriterion[0]) / (epochsSinceLastLearnRateAdjust - epochsNotCountedInAvgCriterion); + + if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && m_learningRatesPerSample.size() <= i && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) + { + if (prevCriterion - avgCriterion < 0 && prevCriterion != std::numeric_limits::infinity()) + { + if (m_loadBestModel) + { + encoderNet.LoadPersistableParametersFromFile(GetEncoderModelNameForEpoch(i - 1), + false); + decoderNet.LoadPersistableParametersFromFile(GetDecoderModelNameForEpoch(i - 1), + m_validateAfterModelReloading); + encoderNet.ResetEvalTimeStamp(); + decoderNet.ResetEvalTimeStamp(); + LoadCheckPointInfo(i - 1, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion); + fprintf(stderr, "Loaded the previous model which has better training criterion.\n"); + loadedPrevModel = true; + } + } + + if (m_continueReduce) + { + if (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && prevCriterion != std::numeric_limits::infinity()) + { + if (learnRateReduced == false) + { + learnRateReduced = true; + } + else + { + decoderNet.SaveToFile(GetDecoderModelNameForEpoch(i, true)); + encoderNet.SaveToFile(GetEncoderModelNameForEpoch(i, true)); + fprintf(stderr, "Finished training and saved final model\n\n"); + break; + } + } + if (learnRateReduced) + { + learnRatePerSample *= m_learnRateDecreaseFactor; + fprintf(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample); + } + } + else + { + if (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && prevCriterion != std::numeric_limits::infinity()) + { + + learnRatePerSample *= m_learnRateDecreaseFactor; + fprintf(stderr, "learnRatePerSample reduced 
to %.8g\n", learnRatePerSample); + } + else if (prevCriterion - avgCriterion > m_increaseLearnRateIfImproveMoreThan*prevCriterion && prevCriterion != std::numeric_limits::infinity()) + { + learnRatePerSample *= m_learnRateIncreaseFactor; + fprintf(stderr, "learnRatePerSample increased to %.8g\n", learnRatePerSample); + } + } + } + + if (!loadedPrevModel && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) //not loading previous values then set them + { + prevCriterion = avgCriterion; + epochsNotCountedInAvgCriterion = 0; + } + + //persist model and check-point info + decoderNet.SaveToFile(GetDecoderModelNameForEpoch(i)); + encoderNet.SaveToFile(GetEncoderModelNameForEpoch(i)); + SaveCheckPointInfo(i, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion); + if (!m_keepCheckPointFiles) + _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); //delete previous checkpiont file to save space + + if (learnRatePerSample < 1e-12) + fprintf(stderr, "learnRate per sample is reduced to %.8g which is below 1e-12. stop training.\n", learnRatePerSample); + } + } + + void sfbTrainEncoderDecoderModel(int startEpoch, ComputationNetwork& encoderNet, + ComputationNetwork& decoderNet, + IDataReader* encoderTrainSetDataReader, + IDataReader* decoderTrainSetDataReader, + IDataReader* encoderValidationSetDataReader, + IDataReader* decoderValidationSetDataReader) + { + std::vector & encoderFeatureNodes = encoderNet.FeatureNodes(); + std::vector & encoderEvaluationNodes = encoderNet.OutputNodes(); + + std::vector & decoderFeatureNodes = decoderNet.FeatureNodes(); + std::vector & decoderLabelNodes = decoderNet.LabelNodes(); + std::vector decoderCriterionNodes = GetTrainCriterionNodes(decoderNet); + std::vector decoderEvaluationNodes = GetEvalCriterionNodes(decoderNet); + + std::map*> encoderInputMatrices, decoderInputMatrices; + for (size_t i = 0; iNodeName()] = + &encoderFeatureNodes[i]->FunctionValues(); + } + for (size_t i = 0; iNodeName()] = + &decoderFeatureNodes[i]->FunctionValues(); + } + for (size_t i = 0; iNodeName()] = &decoderLabelNodes[i]->FunctionValues(); + } + + //initializing weights and gradient holder + std::list& encoderLearnableNodes = encoderNet.LearnableNodes(encoderEvaluationNodes[0]); //only one criterion so far TODO: support multiple ones? 
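The AdjustAfterEpoch logic above compares the improvement of the averaged criterion against thresholds that are relative to the previous criterion before shrinking or growing the learning rate. A compact sketch of that rule (the 0.618 and 1.382 factors below are example values, not necessarily the configured ones):

#include <cstdio>
#include <limits>

// Sketch of the adjust-after-epoch rule: shrink the rate when the averaged
// criterion improved too little over the previous epoch, grow it when it
// improved a lot. Both thresholds are relative to prevCriterion.
double AdjustLearnRate(double learnRate, double prevCriterion, double avgCriterion,
                       double reduceIfLessThan, double increaseIfMoreThan,
                       double decreaseFactor, double increaseFactor)
{
    if (prevCriterion == std::numeric_limits<double>::infinity())
        return learnRate;  // no previous epoch to compare against
    double improvement = prevCriterion - avgCriterion;
    if (improvement <= reduceIfLessThan * prevCriterion)
        learnRate *= decreaseFactor;   // e.g. 0.618
    else if (improvement > increaseIfMoreThan * prevCriterion)
        learnRate *= increaseFactor;   // e.g. 1.382
    return learnRate;
}

int main()
{
    // Improvement of 0.02 on a criterion of 4.0 with a 1% reduce threshold.
    double lr = AdjustLearnRate(0.1, 4.0, 3.98, 0.01, 0.1, 0.618, 1.382);
    printf("adjusted learnRatePerSample = %.8g\n", lr);  // reduced
    return 0;
}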
+ std::list& decoderLearnableNodes = decoderNet.LearnableNodes(decoderCriterionNodes[0]); + std::list learnableNodes; + for (auto nodeIter = encoderLearnableNodes.begin(); nodeIter != encoderLearnableNodes.end(); nodeIter++) + { + ComputationNodePtr node = (*nodeIter); + learnableNodes.push_back(node); + } + for (auto nodeIter = decoderLearnableNodes.begin(); nodeIter != decoderLearnableNodes.end(); nodeIter++) + { + ComputationNodePtr node = (*nodeIter); + learnableNodes.push_back(node); + } + + std::list> smoothedGradients; + for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++) + { + ComputationNodePtr node = (*nodeIter); + smoothedGradients.push_back(Matrix(node->FunctionValues().GetNumRows(), node->FunctionValues().GetNumCols(), node->FunctionValues().GetDeviceId())); + } + + vector epochCriterion; + ElemType avgCriterion, prevCriterion; + for (size_t i = 0; i < 2; i++) + epochCriterion.push_back(std::numeric_limits::infinity()); + avgCriterion = prevCriterion = std::numeric_limits::infinity(); + + size_t epochsNotCountedInAvgCriterion = startEpoch % m_learnRateAdjustInterval; + + std::vector epochEvalErrors(decoderEvaluationNodes.size(), std::numeric_limits::infinity()); + + std::vector evalNodeNames; + for (size_t i = 0; iNodeName()); + + size_t totalSamplesSeen = 0; + ElemType learnRatePerSample = 0.5f / m_mbSize[startEpoch]; + + int m_numPrevLearnRates = 5; //used to control the upper learnining rate in LR search to reduce computation + vector prevLearnRates; + prevLearnRates.resize(m_numPrevLearnRates); + for (int i = 0; i::infinity(); + + //precompute mean and invStdDev nodes and save initial model + if (/// to-do doesn't support pre-compute such as MVN here + /// PreCompute(net, encoderTrainSetDataReader, encoderFeatureNodes, encoderlabelNodes, encoderInputMatrices) || + startEpoch == 0) + { + encoderNet.SaveToFile(GetEncoderModelNameForEpoch(int(startEpoch) - 1)); + decoderNet.SaveToFile(GetDecoderModelNameForEpoch(int(startEpoch) - 1)); + } + + bool learnRateInitialized = false; + if (startEpoch > 0) + { + learnRateInitialized = LoadCheckPointInfo(startEpoch - 1, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion); + setMomentum(m_momentumInputPerMB[m_momentumInputPerMB.size() - 1]); + } + + if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && !learnRateInitialized && m_learningRatesPerSample.size() <= startEpoch) + throw std::invalid_argument("When using \"AdjustAfterEpoch\", there must either exist a checkpoint file, or an explicit learning rate must be specified in config for the starting epoch."); + + ULONG dropOutSeed = 1; + ElemType prevDropoutRate = 0; + + bool learnRateReduced = false; + + for (int i = int(startEpoch); i 0 && m_learningRatesPerSample.size() > i)) + { + learnRatePerSample = m_learningRatesPerSample[i]; + setMomentum(m_momentumInputPerMB[i]); + } + else if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::SearchBeforeEpoch) + { + NOT_IMPLEMENTED; + } + + learnRateInitialized = true; + + if (learnRatePerSample < m_minLearnRate) + { + fprintf(stderr, "Learn Rate Per Sample for Epoch[%lu] = %.8g is less than minLearnRate %.8g. 
Training stops.\n", i + 1, learnRatePerSample, m_minLearnRate); + break; + } + + TrainOneEpochEncoderDecoderWithHiddenStates(encoderNet, decoderNet, i, m_epochSize, encoderTrainSetDataReader, + decoderTrainSetDataReader, learnRatePerSample, + encoderFeatureNodes, encoderEvaluationNodes, encoderInputMatrices, + decoderFeatureNodes, decoderLabelNodes, decoderCriterionNodes, decoderEvaluationNodes, + decoderInputMatrices, learnableNodes, smoothedGradients, + epochCriterion, epochEvalErrors, totalSamplesSeen); + + + auto t_end_epoch = clock(); + ElemType epochTime = ElemType(1.0)*(t_end_epoch - t_start_epoch) / (CLOCKS_PER_SEC); + + // fprintf(stderr, "Finished Epoch[%lu]: [Training Set] Train Loss Per Sample = %.8g ", i + 1, epochCriterion); + fprintf(stderr, "Finished Epoch[%lu]: [Training Set] Decoder Train Loss Per Sample = %.8g ", i + 1, epochCriterion[0]); + if (epochEvalErrors.size() == 1) + { + fprintf(stderr, "EvalErr Per Sample = %.8g Ave Learn Rate Per Sample = %.10g Epoch Time=%.8g\n", epochEvalErrors[0], learnRatePerSample, epochTime); + } + else + { + fprintf(stderr, "EvalErr Per Sample "); + for (size_t j = 0; jNodeName().c_str(), epochCriterion[i + 1]); + for (size_t j = 0; j localEpochCriterion(1, 2, decoderNet.GetDeviceID()); //assume only one training criterion node for each epoch + Matrix localEpochEvalErrors(1, numEvalNodes, decoderNet.GetDeviceID()); + + localEpochCriterion.SetValue(0); + localEpochEvalErrors.SetValue(0); + + encoderTrainSetDataReader->StartMinibatchLoop(m_mbSize[epochNumber], epochNumber, m_epochSize); + decoderTrainSetDataReader->StartMinibatchLoop(m_mbSize[epochNumber], epochNumber, m_epochSize); + + startReadMBTime = clock(); + Matrix mEncoderOutput(encoderEvaluationNodes[0]->FunctionValues().GetDeviceId()); + Matrix mDecoderInput(decoderEvaluationNodes[0]->FunctionValues().GetDeviceId()); + + unsigned uSeedForDataReader = epochNumber; + + bool bContinueDecoding = true; + while (bContinueDecoding) + { + try{ + encoderTrainSetDataReader->SetRandomSeed(uSeedForDataReader); + encoderTrainSetDataReader->GetMinibatch(encoderInputMatrices); + + /// now gradients on decoder network + decoderTrainSetDataReader->SetRandomSeed(uSeedForDataReader); + if (decoderTrainSetDataReader->GetMinibatch(decoderInputMatrices) == false) + break; + } + catch (...) 
+ { + RuntimeError("Errors in reading features "); + } + + size_t actualMBSize = decoderNet.GetActualMBSize(); + if (actualMBSize == 0) + LogicError("decoderTrainSetDataReader read data but decoderNet reports no data read"); + + UpdateEvalTimeStamps(encoderFeatureNodes); + UpdateEvalTimeStamps(decoderFeatureNodes); + UpdateEvalTimeStamps(decoderLabelNodes); + + endReadMBTime = clock(); + startComputeMBTime = clock(); + + /// not the sentence begining, because the initial hidden layer activity is from the encoder network + // decoderTrainSetDataReader->SetSentenceBegin(false); + // decoderTrainSetDataReader->SetSentenceSegBatch(decoderNet.m_sentenceSeg); + // decoderTrainSetDataReader->SetSentenceSegBatch(decoderNet.m_sentenceBegin); + + if (m_doGradientCheck) + { + if (EncoderDecoderGradientCheck(encoderNet, + decoderNet, encoderTrainSetDataReader, + decoderTrainSetDataReader, encoderEvaluationNodes, + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors) == false) + { + throw runtime_error("SGD::TrainOneEpochEncoderDecoderWithHiddenStates gradient check not passed!"); + } + localEpochCriterion.SetValue(0); + localEpochEvalErrors.SetValue(0); + } + + EncoderDecoderWithHiddenStatesForwardPass(encoderNet, + decoderNet, encoderTrainSetDataReader, + decoderTrainSetDataReader, encoderEvaluationNodes, + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); + + EncoderDecoderWithHiddenStatesErrorProp(encoderNet, + decoderNet, encoderEvaluationNodes, + decoderCriterionNodes, + historyMat, m_lst_pair_encoder_decoder_nodes); + + //update model parameters + if (learnRatePerSample > m_minLearnRate * 0.01) + { + auto smoothedGradientIter = smoothedGradients.begin(); + for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, smoothedGradientIter++) + { + ComputationNodePtr node = (*nodeIter); + Matrix& smoothedGradient = (*smoothedGradientIter); + + UpdateWeights(node, smoothedGradient, learnRatePerSample, actualMBSize, m_mbSize[epochNumber], m_L2RegWeight, m_L1RegWeight, m_needAveMultiplier); + } + } + + + endComputeMBTime = clock(); + numMBsRun++; + if (m_traceLevel > 0) + { + ElemType MBReadTime = (ElemType)(endReadMBTime - startReadMBTime) / (CLOCKS_PER_SEC); + ElemType MBComputeTime = (ElemType)(endComputeMBTime - startComputeMBTime) / CLOCKS_PER_SEC; + + readTimeInMBs += MBReadTime; + ComputeTimeInMBs += MBComputeTime; + numSamplesLastMBs += int(actualMBSize); + + if (numMBsRun % m_numMBsToShowResult == 0) + { + + epochCriterion[0] = localEpochCriterion.Get00Element(); + for (size_t i = 0; i< numEvalNodes; i++) + epochEvalErrors[i] = (const ElemType)localEpochEvalErrors(0, i); + + ElemType llk = (epochCriterion[0] - epochCriterionLastMBs[0]) / numSamplesLastMBs; + ElemType ppl = exp(llk); + fprintf(stderr, "Epoch[%d]-Minibatch[%d-%d]: Samples Seen = %d Decoder Train Loss Per Sample = %.8g PPL = %.4e ", epochNumber + 1, numMBsRun - m_numMBsToShowResult + 1, numMBsRun, numSamplesLastMBs, + llk, ppl); + for (size_t i = 0; i= epochSize) + break; + + /// call DataEnd function + /// DataEnd does reader specific process if sentence ending is reached + // encoderTrainSetDataReader->SetSentenceEnd(true); + // decoderTrainSetDataReader->SetSentenceEnd(true); + encoderTrainSetDataReader->DataEnd(endDataSentence); + decoderTrainSetDataReader->DataEnd(endDataSentence); + + uSeedForDataReader++; + } + + localEpochCriterion /= float(totalEpochSamples); + 
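UpdateWeights above folds each minibatch gradient into a per-parameter smoothed gradient before applying it, along with optional L1/L2 regularization. A minimal momentum-style sketch of that role (this is an assumption about the shape of the update, not the exact CNTK formula in SGD.h):

#include <cstdio>
#include <vector>

// Hypothetical momentum-SGD step mirroring the role of smoothedGradients
// above: one smoothed-gradient buffer per learnable parameter, updated
// every minibatch, then applied with the per-sample learning rate.
void UpdateWeightsSketch(std::vector<double>& weights,
                         const std::vector<double>& gradient,
                         std::vector<double>& smoothedGradient,
                         double learnRatePerSample, double momentum,
                         double l2RegWeight)
{
    for (size_t i = 0; i < weights.size(); i++)
    {
        double g = gradient[i] + l2RegWeight * weights[i];     // L2 term
        smoothedGradient[i] = momentum * smoothedGradient[i]
                            + (1.0 - momentum) * g;            // smoothing
        weights[i] -= learnRatePerSample * smoothedGradient[i];
    }
}

int main()
{
    std::vector<double> w = { 1.0 }, g = { 0.2 }, s = { 0.0 };
    UpdateWeightsSketch(w, g, s, 0.1, 0.9, 0.0);
    printf("w = %.6f, smoothed = %.6f\n", w[0], s[0]);
    return 0;
}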
localEpochEvalErrors /= float(totalEpochSamples); + + epochCriterion[0] = localEpochCriterion.Get00Element(); + for (size_t i = 0; i < numEvalNodes; i++) + { + epochEvalErrors[i] = (const ElemType)localEpochEvalErrors(0, i); + } + fprintf(stderr, "total samples in epoch[%d] = %d\n", epochNumber, totalEpochSamples); + } + + /// use hidden states between encoder and decoder to communicate between two networks + void sfbTrainOneEpochEncoderDecoderWithHiddenStates( + ComputationNetwork& encoderNet, /// encoder network + ComputationNetwork& decoderNet, + const int epochNumber, const size_t epochSize, + IDataReader* encoderTrainSetDataReader, + IDataReader* decoderTrainSetDataReader, + const ElemType learnRatePerSample, + const std::vector& encoderFeatureNodes, + const std::vector& encoderEvaluationNodes, + std::map*>& encoderInputMatrices, + const std::vector& decoderFeatureNodes, + const std::vector& decoderLabelNodes, + const std::vector& decoderCriterionNodes, + const std::vector& decoderEvaluationNodes, + std::map*>& decoderInputMatrices, + const std::list& learnableNodes, + std::list>& smoothedGradients, + vector& epochCriterion, std::vector& epochEvalErrors, size_t& totalSamplesSeen) + { + assert(encoderEvaluationNodes.size() == 1); + + Matrix historyMat(encoderNet.GetDeviceID()); + + ElemType readTimeInMBs = 0, ComputeTimeInMBs = 0; + vector epochCriterionLastMBs; + for (size_t i = 0; i < epochCriterion.size(); i++) + epochCriterionLastMBs.push_back(0); + + int numSamplesLastMBs = 0; + std::vector epochEvalErrorsLastMBs(epochEvalErrors.size(), 0); + + clock_t startReadMBTime = 0, startComputeMBTime = 0; + clock_t endReadMBTime = 0, endComputeMBTime = 0; + + //initialize statistics + size_t totalEpochSamples = 0; + + int numMBsRun = 0; + /// get the pair of encode and decoder nodes if (m_lst_pair_encoder_decoder_nodes.size() == 0 && m_lst_pair_encoder_decode_node_names.size() > 0) { @@ -600,14 +1050,14 @@ namespace Microsoft { } EncoderDecoderWithHiddenStatesForwardPass(encoderNet, - decoderNet, encoderTrainSetDataReader, - decoderTrainSetDataReader, encoderEvaluationNodes, - decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors); + decoderNet, encoderTrainSetDataReader, + decoderTrainSetDataReader, encoderEvaluationNodes, + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); EncoderDecoderWithHiddenStatesErrorProp(encoderNet, - decoderNet, encoderEvaluationNodes, - decoderCriterionNodes, - historyMat, m_lst_pair_encoder_decoder_nodes); + decoderNet, encoderEvaluationNodes, + decoderCriterionNodes, + historyMat, m_lst_pair_encoder_decoder_nodes); //update model parameters if (learnRatePerSample > m_minLearnRate * 0.01) @@ -741,7 +1191,7 @@ namespace Microsoft { EncoderDecoderWithHiddenStatesForwardPass(encoderNet, decoderNet, encoderTrainSetDataReader, decoderTrainSetDataReader, encoderEvaluationNodes, - decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors); + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); ElemType score1 = localEpochCriterion.Get00Element(); @@ -759,7 +1209,7 @@ namespace Microsoft { EncoderDecoderWithHiddenStatesForwardPass(encoderNet, decoderNet, encoderTrainSetDataReader, decoderTrainSetDataReader, encoderEvaluationNodes, - decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, 
localEpochCriterion, localEpochEvalErrors); + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); ElemType score2 = localEpochCriterion.Get00Element(); @@ -776,7 +1226,7 @@ namespace Microsoft { EncoderDecoderWithHiddenStatesForwardPass(encoderNet, decoderNet, encoderTrainSetDataReader, decoderTrainSetDataReader, encoderEvaluationNodes, - decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors); + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); EncoderDecoderWithHiddenStatesErrorProp(encoderNet, decoderNet, encoderEvaluationNodes, @@ -836,7 +1286,7 @@ namespace Microsoft { EncoderDecoderWithHiddenStatesForwardPass(encoderNet, decoderNet, encoderTrainSetDataReader, decoderTrainSetDataReader, encoderEvaluationNodes, - decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors); + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); ElemType score1 = localEpochCriterion.Get00Element(); @@ -852,7 +1302,7 @@ namespace Microsoft { EncoderDecoderWithHiddenStatesForwardPass(encoderNet, decoderNet, encoderTrainSetDataReader, decoderTrainSetDataReader, encoderEvaluationNodes, - decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors); + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); ElemType score1r = localEpochCriterion.Get00Element(); @@ -869,7 +1319,7 @@ namespace Microsoft { EncoderDecoderWithHiddenStatesForwardPass(encoderNet, decoderNet, encoderTrainSetDataReader, decoderTrainSetDataReader, encoderEvaluationNodes, - decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors); + decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors); EncoderDecoderWithHiddenStatesErrorProp(encoderNet, decoderNet, encoderEvaluationNodes, @@ -906,7 +1356,6 @@ namespace Microsoft { const std::vector& decoderFeatureNodes, const std::vector& decoderCriterionNodes, const std::vector& decoderEvaluationNodes, - Matrix& historyMat, Matrix& localEpochCriterion, Matrix& localEpochEvalErrors ) @@ -928,21 +1377,6 @@ namespace Microsoft { /// not the sentence begining, because the initial hidden layer activity is from the encoder network decoderTrainSetDataReader->SetSentenceSegBatch(decoderNet.mSentenceBoundary, decoderNet.mExistsBeginOrNoLabels); - /// get the pair of encode and decoder nodes - for (typename list>::iterator iter = m_lst_pair_encoder_decoder_nodes.begin(); iter != m_lst_pair_encoder_decoder_nodes.end(); iter++) - { - /// past hidden layer activity from encoder network to decoder network - ComputationNodePtr encoderNode = iter->first; - ComputationNodePtr decoderNode = iter->second; - - encoderNode->GetHistory(historyMat, true); /// get the last state activity - decoderNode->SetHistory(historyMat); -#ifdef DEBUG_DECODER - fprintf(stderr, "LSTM past output norm = %.8e\n", historyMat.ColumnSlice(0, nstreams).FrobeniusNorm()); - fprintf(stderr, "LSTM past state norm = %.8e\n", historyMat.ColumnSlice(nstreams, nstreams).FrobeniusNorm()); -#endif - } - UpdateEvalTimeStamps(decoderFeatureNodes); decoderNet.Evaluate(decoderCriterionNodes[0]); diff --git 
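EncoderDecoderGradientCheck above perturbs one parameter by +/- epsilon, reruns the coupled forward passes (the paired calls to EncoderDecoderWithHiddenStatesForwardPass), and compares (score1 - score2) / (2 * eps) against the analytic gradient. The same numerical check on a scalar toy function, purely for illustration:

#include <cmath>
#include <cstdio>

// Central-difference gradient check: f evaluated at w + eps and w - eps,
// compared against the closed-form derivative.
int main()
{
    auto f = [](double w) { return 0.5 * w * w * w; };  // toy criterion
    double w = 1.3, eps = 1e-5;
    double numeric  = (f(w + eps) - f(w - eps)) / (2.0 * eps);
    double analytic = 1.5 * w * w;                      // d/dw of 0.5 w^3
    double relErr = std::fabs(numeric - analytic)
                  / std::fmax(std::fabs(numeric), std::fabs(analytic));
    printf("numeric=%.10f analytic=%.10f relErr=%.3e\n", numeric, analytic, relErr);
    return relErr < 1e-6 ? 0 : 1;
}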
a/MachineLearning/CNTK/NDLNetworkBuilder.h b/MachineLearning/CNTK/NDLNetworkBuilder.h index 1c9ea6deb..24bb3288e 100644 --- a/MachineLearning/CNTK/NDLNetworkBuilder.h +++ b/MachineLearning/CNTK/NDLNetworkBuilder.h @@ -153,10 +153,10 @@ namespace Microsoft { namespace MSR { namespace CNTK { delete m_executionEngine; } virtual ComputationNetwork& LoadNetworkFromFile(const wstring& modelFileName, bool forceLoad = true, - bool bAllowNoCriterionNode = false) + bool bAllowNoCriterionNode = false, ComputationNetwork* anotherNetwork = nullptr) { if (m_net->GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load - m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterionNode); + m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterionNode, anotherNetwork); m_net->ResetEvalTimeStamp(); return *m_net; @@ -211,7 +211,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { ndlUtil.ProcessNDLConfig(config, true); } - virtual ComputationNetwork& BuildNetworkFromDescription() + virtual ComputationNetwork& BuildNetworkFromDescription(ComputationNetwork* = nullptr) { if (m_net->GetTotalNumberOfNodes() < 1) //not built yet { diff --git a/MachineLearning/CNTK/NetworkDescriptionLanguage.cpp b/MachineLearning/CNTK/NetworkDescriptionLanguage.cpp index 17e922416..7c2c9b116 100644 --- a/MachineLearning/CNTK/NetworkDescriptionLanguage.cpp +++ b/MachineLearning/CNTK/NetworkDescriptionLanguage.cpp @@ -238,6 +238,10 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable) ret = true; else if (EqualInsensitive(nodeType, LSTMNode::TypeName(), L"LSTM")) ret = true; + else if (EqualInsensitive(nodeType, AlignmentNode::TypeName(), L"Alignment")) + ret = true; + else if (EqualInsensitive(nodeType, AlignmentNode::TypeName(), L"PairNetwork")) + ret = true; // return the actual node name in the parameter if we found something if (ret) diff --git a/MachineLearning/CNTK/RecurrentNodes.h b/MachineLearning/CNTK/RecurrentNodes.h index 4673c9f89..4316da56f 100644 --- a/MachineLearning/CNTK/RecurrentNodes.h +++ b/MachineLearning/CNTK/RecurrentNodes.h @@ -27,7 +27,7 @@ to-dos: delay_node : has another input that points to additional observations. memory_node: M x N node, with a argument telling whether to save the last observation, or save a window size of observations, or save all observations pair_node : copy function values and gradient values from one node in source network to target network - +sequential_alignment_node: compute similarity of the previous time or any matrix, versus a block of input, and output a weighted average from the input decoder delay_node -> memory_node -> pair(source, target) pair(source, target) -> memory_node -> encoder output node @@ -581,7 +581,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { Developed by Kaisheng Yao Used in the following works: - K. Yao, G. Zweig, "Sequence to sequence neural net models for graphone to phoneme conversion", submitted to Interspeech 2015 + K. Yao, G. Zweig, "Sequence to sequence neural net models for graphone to phoneme conversion", in Interspeech 2015 */ template class LSTMNode : public ComputationNode @@ -1793,4 +1793,270 @@ namespace Microsoft { namespace MSR { namespace CNTK { template class LSTMNode; template class LSTMNode; + + /** + This node uses softmax to compute the similarity of an input versus the second input, which is a block of memory, and outputs + the weighed average of the second input. 
+ */ + template + class AlignmentNode : public ComputationNode + { + UsingComputationNodeMembers; + public: + AlignmentNode(const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"") + : ComputationNode(deviceId), m_memoryBlk4EachUtt(deviceId), m_softmax(deviceId), m_ones(deviceId) + { + m_nodeName = (name == L"" ? CreateUniqNodeName() : name); + m_deviceId = deviceId; + MoveMatricesToDevice(deviceId); + InitRecurrentNode(); + } + + AlignmentNode(File& fstream, const size_t modelVersion, const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"") + : ComputationNode(deviceId), m_memoryBlk4EachUtt(deviceId), m_softmax(deviceId), m_ones(deviceId) + { + m_nodeName = (name == L"" ? CreateUniqNodeName() : name); + LoadFromFile(fstream, modelVersion, deviceId); + } + + virtual const std::wstring OperationName() const { return TypeName(); } + static const std::wstring TypeName() { return L"Alignment"; } + + virtual void ComputeInputPartial(const size_t ) + { + NOT_IMPLEMENTED; + } + + virtual void ComputeInputPartial(const size_t inputIndex, const size_t timeIdxInSeq) + { + if (inputIndex == 0) + return; + + if (inputIndex > 2) + throw std::invalid_argument("Alignment has three inputs."); + + Matrix sliceOutputGrad = GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); + Matrix mTmp(m_deviceId); + Matrix mTmp1(m_deviceId); + Matrix mTmp2(m_deviceId); + Matrix mTmp3(m_deviceId); + Matrix mGBeforeSoftmax(m_deviceId); + Matrix mTmp4(m_deviceId); + Matrix mGToMemblk(m_deviceId); + size_t T = Inputs(1)->FunctionValues().GetNumCols() / m_samplesInRecurrentStep; + size_t e = Inputs(0)->FunctionValues().GetNumRows(); + size_t d = Inputs(1)->FunctionValues().GetNumRows(); + Matrix mGBeforeSoftmaxTimes(m_deviceId); + mGBeforeSoftmaxTimes.Resize(T, e); + mGBeforeSoftmaxTimes.SetValue(0); + mGToMemblk.Resize(d, T); + mGToMemblk.SetValue(0); + + if (m_ones.GetNumRows() != e || m_ones.GetNumCols() != e) + { + m_ones = Matrix::Ones(e, e, m_deviceId); + } + + mGBeforeSoftmax.Resize(m_softmax.GetNumRows(),1); + mGBeforeSoftmax.SetValue(0); + for (size_t k = 0; k < m_samplesInRecurrentStep; k++) + { + size_t i = timeIdxInSeq * m_samplesInRecurrentStep + k; + + /// right branch with softmax + mTmp4 = m_memoryBlk4EachUtt.ColumnSlice(k*T, T); + TimesNode::ComputeInputPartialRight(mTmp4, mTmp, sliceOutputGrad.ColumnSlice(k, 1)); /// before times + SoftmaxNode::ComputeInputPartialS(mTmp1, mTmp2, mGBeforeSoftmax, mTmp, m_softmax.ColumnSlice(k, 1)); /// before softmax + TimesNode::ComputeInputPartialLeft(Inputs(0)->FunctionValues().ColumnSlice(i, 1), mGBeforeSoftmaxTimes, mGBeforeSoftmax); /// before times + + switch (inputIndex) + { + case 0: + LogicError("no gradients should be backpropagated to past observation"); + case 1: //derivative to memory block + TimesNode::ComputeInputPartialLeft(m_softmax.ColumnSlice(k, 1), mGToMemblk, + sliceOutputGrad.ColumnSlice(k,1)); + + mTmp4.Resize(T,e); + mTmp4.SetValue(0); + TimesNode::ComputeInputPartialLeft(Inputs(2)->FunctionValues(), mTmp4, mGBeforeSoftmaxTimes); + TransposeNode::ComputeInputPartial(mGToMemblk, m_ones, mTmp4); + + for (size_t j = 0; j < T; j++) + Inputs(1)->GradientValues().ColumnSlice(j*m_samplesInRecurrentStep + k, 1) += mGToMemblk.ColumnSlice(j, 1); + + break; + case 2: // derivative to similarity matrix + mTmp2 = m_memoryBlk4EachUtt.ColumnSlice(k*T, T); + Matrix::MultiplyAndAdd(mTmp2, false, mGBeforeSoftmaxTimes, false, Inputs(2)->GradientValues()); /// before times + + break; 
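The key step in the backward pass above is routing the output gradient through the softmax (the SoftmaxNode::ComputeInputPartialS call). The underlying identity, with w = softmax(s) and g the gradient arriving at w, is dL/ds_i = w_i * (g_i - w.g). A standalone sketch of that step (toy code, not the SoftmaxNode API):

#include <cmath>
#include <cstdio>
#include <vector>

// Softmax backward: with w = softmax(s) and g = dL/dw,
// dL/ds_i = w_i * (g_i - dot(w, g)).
std::vector<double> SoftmaxBackward(const std::vector<double>& w,
                                    const std::vector<double>& g)
{
    double wDotG = 0.0;
    for (size_t i = 0; i < w.size(); i++) wDotG += w[i] * g[i];
    std::vector<double> ds(w.size());
    for (size_t i = 0; i < w.size(); i++) ds[i] = w[i] * (g[i] - wDotG);
    return ds;
}

int main()
{
    std::vector<double> s = { 0.1, 0.7, 0.2 }, w(3), g = { 1.0, 0.0, 0.0 };
    double z = 0.0;
    for (double v : s) z += std::exp(v);
    for (size_t i = 0; i < 3; i++) w[i] = std::exp(s[i]) / z;  // forward softmax
    std::vector<double> ds = SoftmaxBackward(w, g);
    printf("ds = (%.6f, %.6f, %.6f)\n", ds[0], ds[1], ds[2]);
    return 0;
}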
+ } + } + } + + virtual void EvaluateThisNode() + { + EvaluateThisNodeS(m_functionValues, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(2)->FunctionValues(), + m_memoryBlk4EachUtt, m_softmax); + } + + virtual void EvaluateThisNode(const size_t timeIdxInSeq) + { + Matrix sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); + Matrix sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); + + EvaluateThisNodeS(sliceOutputValue, sliceInputValue, Inputs(1)->FunctionValues(), Inputs(2)->FunctionValues(), m_memoryBlk4EachUtt, m_softmax); + } + + static void WINAPI EvaluateThisNodeS(Matrix& functionValues, + const Matrix& refFunction, const Matrix& memoryBlk, + const Matrix& wgtMatrix, + Matrix& tmpMemoryBlk4EachUtt, Matrix& tmpSoftMax) + { + size_t e = wgtMatrix.GetNumCols(); + size_t nbrUttPerSample = refFunction.GetNumCols(); + size_t T = memoryBlk.GetNumCols() / nbrUttPerSample; + + tmpMemoryBlk4EachUtt.Resize(memoryBlk.GetNumRows(), memoryBlk.GetNumCols()); + tmpSoftMax.Resize(T, nbrUttPerSample); + Matrix tmpMat(tmpMemoryBlk4EachUtt.GetDeviceId()); + Matrix tmpMat3(tmpMemoryBlk4EachUtt.GetDeviceId()); + tmpMat3.Resize(e, T); + Matrix tmpMat4(tmpMemoryBlk4EachUtt.GetDeviceId()); + Matrix tmpMat2(tmpMemoryBlk4EachUtt.GetDeviceId()); + + for (size_t k = 0; k < nbrUttPerSample; k++) + { + for (size_t t = 0; t < T; t++) + { + size_t i = t * nbrUttPerSample + k; + tmpMat3.ColumnSlice(t, 1).SetValue(memoryBlk.ColumnSlice(i, 1)); + } + /// d x T + tmpMemoryBlk4EachUtt.ColumnSlice(k*T, T) = tmpMat3; + + Matrix::Multiply(tmpMat3, true, wgtMatrix, false, tmpMat); + /// T x d x (d x e) = T x e + + Matrix::Multiply(tmpMat, false, refFunction.ColumnSlice(k,1), false, tmpMat2); + /// T x e x (e x 1) = T x 1 + + tmpSoftMax.ColumnSlice(k, 1) = tmpMat2; + tmpMat2.InplaceLogSoftmax(true); + tmpMat2.InplaceExp(); + + Matrix::Multiply(tmpMat3, false, tmpMat2, false, tmpMat4); + functionValues.ColumnSlice(k, 1).SetValue(tmpMat4); + /// d x 1 + } + /// d x k + +#if NANCHECK + functionValues.HasNan("Alignment"); +#endif + } + + /** + input 0, denoted as r (in d x k) : this is an input that is treated as given observation, so no gradient is backpropagated into it. 
+ input 1, denoted as M (in e x k x T) : this is a block of memory + input 2, denoted as W (in d x e) : this is a matrix to compute similarity + d : input 0 feature dimension + k : number of utterances per minibatch + T : input 1 time dimension + e : input 1 feature dimension + the operation is + s = r^T W M in k x T + w = softmax(s) in k x T + o = M w^T in e x k + */ + virtual void Validate() + { + PrintSelfBeforeValidation(); + size_t k, T, e, d, i; + + if (m_children.size() != 3) + throw std::logic_error("AlignmentNode operation should have three input."); + + if (Inputs(0)->FunctionValues().GetNumElements() == 0 || + Inputs(1)->FunctionValues().GetNumElements() == 0 || + Inputs(2)->FunctionValues().GetNumElements() == 0) + throw std::logic_error("AlignmentNode operation: the input nodes have 0 element."); + + d = Inputs(0)->FunctionValues().GetNumRows(); + k = Inputs(0)->FunctionValues().GetNumCols(); + i = Inputs(1)->FunctionValues().GetNumCols(); + e = Inputs(1)->FunctionValues().GetNumRows(); + T = i / k; + if (Inputs(2)->FunctionValues().GetNumRows() != d || + Inputs(2)->FunctionValues().GetNumCols() != e) + LogicError("AlignmentNode operation: the weight matrix dimension doesn't match input feature dimensions."); + + FunctionValues().Resize(e, k); + + CopyImageSizeFromInputs(); + } + + virtual void AttachInputs(const ComputationNodePtr refFeature, const ComputationNodePtr memoryBlk, const ComputationNodePtr wgtMatrix) + { + m_children.resize(3); + m_children[0] = refFeature; + m_children[1] = memoryBlk; + m_children[2] = wgtMatrix; + } + + virtual void MoveMatricesToDevice(const DEVICEID_TYPE deviceId) + { + ComputationNode::MoveMatricesToDevice(deviceId); + + if (deviceId != AUTOPLACEMATRIX) + { + if (m_memoryBlk4EachUtt.GetDeviceId() != deviceId) + m_memoryBlk4EachUtt.TransferFromDeviceToDevice(m_memoryBlk4EachUtt.GetDeviceId(), deviceId); + if (m_softmax.GetDeviceId() != deviceId) + m_softmax.TransferFromDeviceToDevice(m_softmax.GetDeviceId(), deviceId); + if (m_weight.GetDeviceId() != deviceId) + m_weight.TransferFromDeviceToDevice(m_weight.GetDeviceId(), deviceId); + } + } + + virtual void CopyTo(const ComputationNodePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const + { + ComputationNode::CopyTo(nodeP, newName, flags); + AlignmentNode* node = (AlignmentNode*) nodeP; + + if (flags & CopyNodeFlags::copyNodeValue) + { + node->m_memoryBlk4EachUtt = m_memoryBlk4EachUtt; + node->m_softmax = m_softmax; + node->m_weight = m_weight; + node->m_ones = m_ones; + } + } + + // copy constructor + AlignmentNode(const AlignmentNode* node, const std::wstring& newName, const CopyNodeFlags flags) + : ComputationNode(node->m_deviceId), m_memoryBlk4EachUtt(node->m_deviceId), m_softmax(node->m_deviceId), m_ones(node->m_deviceId) + { + node->CopyTo(this, newName, flags); + } + + virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const + { + const std::wstring& name = (newName == L"") ? 
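Putting the documented operation above into executable form: for a single utterance (k = 1) the node computes s = r^T W M, w = softmax(s), and o = M w^T. A self-contained numeric sketch with tiny hand-picked dimensions (illustrative values only):

#include <cmath>
#include <cstdio>

// Single-utterance (k = 1) version of the AlignmentNode math:
// s_t = r^T W m_t, w = softmax(s), o = sum_t w_t m_t.
// d = 2 (query dim), e = 2 (memory dim), T = 3 (memory length).
int main()
{
    const int d = 2, e = 2, T = 3;
    double r[d] = { 1.0, 0.0 };                       // query (input 0 column)
    double W[d][e] = { { 1.0, 0.0 }, { 0.0, 1.0 } };  // similarity weights (input 2)
    double M[e][T] = { { 0.9, 0.1, 0.5 },             // memory block (input 1)
                       { 0.1, 0.9, 0.5 } };

    double s[T], w[T], o[e] = { 0.0, 0.0 }, z = 0.0;
    for (int t = 0; t < T; t++)                       // s = r^T W M
    {
        s[t] = 0.0;
        for (int i = 0; i < d; i++)
            for (int j = 0; j < e; j++)
                s[t] += r[i] * W[i][j] * M[j][t];
        z += std::exp(s[t]);
    }
    for (int t = 0; t < T; t++) w[t] = std::exp(s[t]) / z;  // w = softmax(s)
    for (int j = 0; j < e; j++)                             // o = M w^T
        for (int t = 0; t < T; t++)
            o[j] += w[t] * M[j][t];

    printf("attention weights: %.4f %.4f %.4f\n", w[0], w[1], w[2]);
    printf("output: %.4f %.4f\n", o[0], o[1]);
    return 0;
}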
NodeName() : newName; + + ComputationNodePtr node = new AlignmentNode(this, name, flags); + return node; + } + + private: + Matrix m_memoryBlk4EachUtt; + Matrix m_softmax; + Matrix m_weight; + Matrix m_ones; + }; + + template class AlignmentNode; + template class AlignmentNode; + }}} diff --git a/MachineLearning/CNTK/SimpleEvaluator.h b/MachineLearning/CNTK/SimpleEvaluator.h index ed0a319eb..a50d6218c 100644 --- a/MachineLearning/CNTK/SimpleEvaluator.h +++ b/MachineLearning/CNTK/SimpleEvaluator.h @@ -431,6 +431,202 @@ namespace Microsoft { namespace MSR { namespace CNTK { } } + //initialize eval results + std::vector evalResults; + for (int i = 0; i < decoderEvalNodes.size(); i++) + { + evalResults.push_back((ElemType)0); + } + + //prepare features and labels + std::vector & encoderFeatureNodes = encoderNet.FeatureNodes(); + + std::vector & decoderFeatureNodes = decoderNet.FeatureNodes(); + std::vector & decoderLabelNodes = decoderNet.LabelNodes(); + + std::map*> encoderInputMatrices; + for (size_t i = 0; i < encoderFeatureNodes.size(); i++) + { + encoderInputMatrices[encoderFeatureNodes[i]->NodeName()] = &encoderFeatureNodes[i]->FunctionValues(); + } + + std::map*> decoderInputMatrices; + for (size_t i = 0; i < decoderFeatureNodes.size(); i++) + { + decoderInputMatrices[decoderFeatureNodes[i]->NodeName()] = &decoderFeatureNodes[i]->FunctionValues(); + } + for (size_t i = 0; i < decoderLabelNodes.size(); i++) + { + decoderInputMatrices[decoderLabelNodes[i]->NodeName()] = &decoderLabelNodes[i]->FunctionValues(); + } + + //evaluate through minibatches + size_t totalEpochSamples = 0; + size_t numMBsRun = 0; + size_t actualMBSize = 0; + size_t numSamplesLastMBs = 0; + size_t lastMBsRun = 0; //MBs run before this display + + std::vector evalResultsLastMBs; + for (int i = 0; i < evalResults.size(); i++) + evalResultsLastMBs.push_back((ElemType)0); + + encoderDataReader.StartMinibatchLoop(mbSize, 0, testSize); + decoderDataReader.StartMinibatchLoop(mbSize, 0, testSize); + + Matrix mEncoderOutput(encoderEvalNodes[0]->FunctionValues().GetDeviceId()); + Matrix historyMat(encoderEvalNodes[0]->FunctionValues().GetDeviceId()); + + bool bContinueDecoding = true; + while (bContinueDecoding){ + /// first evaluate encoder network + if (encoderDataReader.GetMinibatch(encoderInputMatrices) == false) + break; + if (decoderDataReader.GetMinibatch(decoderInputMatrices) == false) + break; + UpdateEvalTimeStamps(encoderFeatureNodes); + UpdateEvalTimeStamps(decoderFeatureNodes); + + actualMBSize = decoderNet.GetActualMBSize(); + if (actualMBSize == 0) + LogicError("decoderTrainSetDataReader read data but decoderNet reports no data read"); + + encoderNet.SetActualMiniBatchSize(actualMBSize); + encoderNet.SetActualNbrSlicesInEachRecIter(encoderDataReader.NumberSlicesInEachRecurrentIter()); + encoderDataReader.SetSentenceSegBatch(encoderNet.mSentenceBoundary, encoderNet.mExistsBeginOrNoLabels); + + assert(encoderEvalNodes.size() == 1); + for (int i = 0; i < encoderEvalNodes.size(); i++) + { + encoderNet.Evaluate(encoderEvalNodes[i]); + } + + + /// not the sentence begining, because the initial hidden layer activity is from the encoder network + decoderNet.SetActualNbrSlicesInEachRecIter(decoderDataReader.NumberSlicesInEachRecurrentIter()); + decoderDataReader.SetSentenceSegBatch(decoderNet.mSentenceBoundary, decoderNet.mExistsBeginOrNoLabels); + + for (int i = 0; iFunctionValues().Get00Element(); //criterionNode should be a scalar + } + + totalEpochSamples += actualMBSize; + numMBsRun++; + + if (m_traceLevel > 0) + { 
diff --git a/MachineLearning/CNTK/SimpleEvaluator.h b/MachineLearning/CNTK/SimpleEvaluator.h
index ed0a319eb..a50d6218c 100644
--- a/MachineLearning/CNTK/SimpleEvaluator.h
+++ b/MachineLearning/CNTK/SimpleEvaluator.h
@@ -431,6 +431,202 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             }
         }
 
+        //initialize eval results
+        std::vector<ElemType> evalResults;
+        for (int i = 0; i < decoderEvalNodes.size(); i++)
+        {
+            evalResults.push_back((ElemType)0);
+        }
+
+        //prepare features and labels
+        std::vector<ComputationNodePtr>& encoderFeatureNodes = encoderNet.FeatureNodes();
+
+        std::vector<ComputationNodePtr>& decoderFeatureNodes = decoderNet.FeatureNodes();
+        std::vector<ComputationNodePtr>& decoderLabelNodes = decoderNet.LabelNodes();
+
+        std::map<std::wstring, Matrix<ElemType>*> encoderInputMatrices;
+        for (size_t i = 0; i < encoderFeatureNodes.size(); i++)
+        {
+            encoderInputMatrices[encoderFeatureNodes[i]->NodeName()] = &encoderFeatureNodes[i]->FunctionValues();
+        }
+
+        std::map<std::wstring, Matrix<ElemType>*> decoderInputMatrices;
+        for (size_t i = 0; i < decoderFeatureNodes.size(); i++)
+        {
+            decoderInputMatrices[decoderFeatureNodes[i]->NodeName()] = &decoderFeatureNodes[i]->FunctionValues();
+        }
+        for (size_t i = 0; i < decoderLabelNodes.size(); i++)
+        {
+            decoderInputMatrices[decoderLabelNodes[i]->NodeName()] = &decoderLabelNodes[i]->FunctionValues();
+        }
+
+        //evaluate through minibatches
+        size_t totalEpochSamples = 0;
+        size_t numMBsRun = 0;
+        size_t actualMBSize = 0;
+        size_t numSamplesLastMBs = 0;
+        size_t lastMBsRun = 0; //MBs run before this display
+
+        std::vector<ElemType> evalResultsLastMBs;
+        for (int i = 0; i < evalResults.size(); i++)
+            evalResultsLastMBs.push_back((ElemType)0);
+
+        encoderDataReader.StartMinibatchLoop(mbSize, 0, testSize);
+        decoderDataReader.StartMinibatchLoop(mbSize, 0, testSize);
+
+        Matrix<ElemType> mEncoderOutput(encoderEvalNodes[0]->FunctionValues().GetDeviceId());
+        Matrix<ElemType> historyMat(encoderEvalNodes[0]->FunctionValues().GetDeviceId());
+
+        bool bContinueDecoding = true;
+        while (bContinueDecoding)
+        {
+            /// first evaluate the encoder network
+            if (encoderDataReader.GetMinibatch(encoderInputMatrices) == false)
+                break;
+            if (decoderDataReader.GetMinibatch(decoderInputMatrices) == false)
+                break;
+            UpdateEvalTimeStamps(encoderFeatureNodes);
+            UpdateEvalTimeStamps(decoderFeatureNodes);
+
+            actualMBSize = decoderNet.GetActualMBSize();
+            if (actualMBSize == 0)
+                LogicError("decoderDataReader read data but decoderNet reports no data read");
+
+            encoderNet.SetActualMiniBatchSize(actualMBSize);
+            encoderNet.SetActualNbrSlicesInEachRecIter(encoderDataReader.NumberSlicesInEachRecurrentIter());
+            encoderDataReader.SetSentenceSegBatch(encoderNet.mSentenceBoundary, encoderNet.mExistsBeginOrNoLabels);
+
+            assert(encoderEvalNodes.size() == 1);
+            for (int i = 0; i < encoderEvalNodes.size(); i++)
+            {
+                encoderNet.Evaluate(encoderEvalNodes[i]);
+            }
+
+            /// not the sentence beginning, because the initial hidden layer activity comes from the encoder network
+            decoderNet.SetActualNbrSlicesInEachRecIter(decoderDataReader.NumberSlicesInEachRecurrentIter());
+            decoderDataReader.SetSentenceSegBatch(decoderNet.mSentenceBoundary, decoderNet.mExistsBeginOrNoLabels);
+
+            for (int i = 0; i < decoderEvalNodes.size(); i++)
+            {
+                decoderNet.Evaluate(decoderEvalNodes[i]);
+                evalResults[i] += (ElemType)decoderEvalNodes[i]->FunctionValues().Get00Element(); //criterionNode should be a scalar
+            }
+
+            totalEpochSamples += actualMBSize;
+            numMBsRun++;
+
+            if (m_traceLevel > 0)
+            {
+                numSamplesLastMBs += actualMBSize;
+
+                if (numMBsRun % m_numMBsToShowResult == 0)
+                {
+                    DisplayEvalStatistics(lastMBsRun + 1, numMBsRun, numSamplesLastMBs, decoderEvalNodes, evalResults, evalResultsLastMBs);
+
+                    for (int i = 0; i < evalResults.size(); i++)
+                    {
+                        evalResultsLastMBs[i] = evalResults[i];
+                    }
+                    numSamplesLastMBs = 0;
+                    lastMBsRun = numMBsRun;
+                }
+            }
+
+            /// call DataEnd to check whether the end of a sentence is reached;
+            /// the data reader performs any reader-specific end-of-sentence processing
+            encoderDataReader.DataEnd(endDataSentence);
+            decoderDataReader.DataEnd(endDataSentence);
+        }
+
+        // show the last batch of results
+        if (m_traceLevel > 0 && numSamplesLastMBs > 0)
+        {
+            DisplayEvalStatistics(lastMBsRun + 1, numMBsRun, numSamplesLastMBs, decoderEvalNodes, evalResults, evalResultsLastMBs);
+        }
+
+        //final statistics
+        for (int i = 0; i < evalResultsLastMBs.size(); i++)
+        {
+            evalResultsLastMBs[i] = 0;
+        }
+
+        fprintf(stderr, "Final Results: ");
+        DisplayEvalStatistics(1, numMBsRun, totalEpochSamples, decoderEvalNodes, evalResults, evalResultsLastMBs);
+
+        for (int i = 0; i < evalResults.size(); i++)
+        {
+            evalResults[i] /= totalEpochSamples;
+        }
+
+        return evalResults;
+    }
+
+    /// this evaluates the encoder network and the decoder network
+    vector<ElemType> sfbEvaluateEncoderDecoderWithHiddenStates(
+        ComputationNetwork<ElemType>& encoderNet,
+        ComputationNetwork<ElemType>& decoderNet,
+        IDataReader<ElemType>& encoderDataReader,
+        IDataReader<ElemType>& decoderDataReader,
+        const vector<wstring>& encoderEvalNodeNames,
+        const vector<wstring>& decoderEvalNodeNames,
+        const size_t mbSize,
+        const size_t testSize = requestDataSize)
+    {
+        //specify evaluation nodes
+        std::vector<ComputationNodePtr> encoderEvalNodes;
+        std::vector<ComputationNodePtr> decoderEvalNodes;
+
+        if (encoderEvalNodeNames.size() == 0)
+        {
+            fprintf(stderr, "evalNodeNames are not specified; using the default evaluation and training criterion nodes.\n");
+            if (encoderNet.EvaluationNodes().size() == 0)
+                throw std::logic_error("There is no default evaluation node specified in the network.");
+
+            for (int i = 0; i < encoderNet.EvaluationNodes().size(); i++)
+                encoderEvalNodes.push_back(encoderNet.EvaluationNodes()[i]);
+        }
+        else
+        {
+            for (int i = 0; i < encoderEvalNodeNames.size(); i++)
+            {
+                ComputationNodePtr node = encoderNet.GetNodeFromName(encoderEvalNodeNames[i]);
+                encoderNet.BuildAndValidateNetwork(node);
+                if (node->FunctionValues().GetNumElements() != 1)
+                {
+                    throw std::logic_error("The nodes passed to SimpleEvaluator::Evaluate function must be either eval or training criterion nodes (which evaluate to a 1x1 value).");
+                }
+                encoderEvalNodes.push_back(node);
+            }
+        }
+
+        if (decoderEvalNodeNames.size() == 0)
+        {
+            fprintf(stderr, "evalNodeNames are not specified; using the default evaluation and training criterion nodes.\n");
+            if (decoderNet.EvaluationNodes().size() == 0)
+                throw std::logic_error("There is no default evaluation node specified in the network.");
+            if (decoderNet.FinalCriterionNodes().size() == 0)
+                throw std::logic_error("There is no default criterion node specified in the network.");
+
+            for (int i = 0; i < decoderNet.EvaluationNodes().size(); i++)
+                decoderEvalNodes.push_back(decoderNet.EvaluationNodes()[i]);
+
+            for (int i = 0; i < decoderNet.FinalCriterionNodes().size(); i++)
+                decoderEvalNodes.push_back(decoderNet.FinalCriterionNodes()[i]);
+        }
+        else
+        {
+            for (int i = 0; i < decoderEvalNodeNames.size(); i++)
+            {
+                ComputationNodePtr node = decoderNet.GetNodeFromName(decoderEvalNodeNames[i]);
+                decoderNet.BuildAndValidateNetwork(node);
+                if (node->FunctionValues().GetNumElements() != 1)
+                {
+                    throw std::logic_error("The nodes passed to SimpleEvaluator::Evaluate function must be either eval or training criterion nodes (which evaluate to a 1x1 value).");
+                }
+                decoderEvalNodes.push_back(node);
+            }
+        }
+
         if (m_lst_pair_encoder_decoder_nodes.size() == 0)
             throw runtime_error("TrainOneEpochEncoderDecoderWithHiddenStates: no encoder and decoder node pairs");
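Both evaluation entry points above drive the encoder and decoder readers in lockstep: an iteration proceeds only if GetMinibatch succeeds on both readers, whichever reader runs out of data first ends the loop, and the accumulated criterion is divided by the total sample count at the end. A stripped-down sketch of that control flow follows; Reader and Net are hypothetical stand-ins, not CNTK types.

    #include <cstdio>

    struct Reader { int left; bool GetMinibatch(int& mb) { if (left <= 0) return false; mb = left--; return true; } };
    struct Net    { double Eval(int mb) { return 0.1 * mb; } };

    int main()
    {
        Reader encoderReader{ 3 }, decoderReader{ 5 }; // unequal lengths: the shorter one ends the loop
        Net encoderNet, decoderNet;
        double total = 0;
        int samples = 0;
        for (;;)
        {
            int encMB, decMB;
            // both readers must deliver a minibatch; either one ending stops evaluation
            if (!encoderReader.GetMinibatch(encMB)) break;
            if (!decoderReader.GetMinibatch(decMB)) break;
            encoderNet.Eval(encMB);          // encoder runs first to produce hidden states
            total += decoderNet.Eval(decMB); // decoder criterion accumulates per minibatch
            samples += decMB;
        }
        if (samples > 0)
            printf("average criterion = %.4f over %d samples\n", total / samples, samples);
        return 0;
    }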
diff --git a/MachineLearning/CNTK/SimpleNetworkBuilder.cpp b/MachineLearning/CNTK/SimpleNetworkBuilder.cpp
index 2b32cce9c..090411f2f 100644
--- a/MachineLearning/CNTK/SimpleNetworkBuilder.cpp
+++ b/MachineLearning/CNTK/SimpleNetworkBuilder.cpp
@@ -355,6 +355,125 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         return *m_net;
     }
 
+    /**
+    this builds an alignment-based LM generator;
+    the alignment node takes a variable-length input and relates each element of it to a variable-length output
+    */
+    template <class ElemType>
+    ComputationNetwork<ElemType>& SimpleNetworkBuilder<ElemType>::BuildAlignmentDecoderNetworkFromDescription(ComputationNetwork<ElemType>* encoderNet, size_t mbSize)
+    {
+        if (m_net->GetTotalNumberOfNodes() < 1) //not built yet
+        {
+            unsigned long randomSeed = 1;
+
+            size_t numHiddenLayers = m_layerSizes.size() - 2;
+
+            size_t numRecurrentLayers = m_recurrentLayers.size();
+
+            ComputationNodePtr input = nullptr, encoderOutput = nullptr, e = nullptr,
+                b = nullptr, w = nullptr, u = nullptr, delay = nullptr, output = nullptr, label = nullptr, alignoutput = nullptr;
+            ComputationNodePtr clslogpostprob = nullptr;
+            ComputationNodePtr clsweight = nullptr;
+
+            input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize);
+            m_net->FeatureNodes().push_back(input);
+
+            if (m_lookupTableOrder > 0)
+            {
+                e = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"E%d", 0), m_layerSizes[1], m_layerSizes[0] / m_lookupTableOrder);
+                m_net->InitLearnableParameters(e, m_uniformInit, randomSeed++, m_initValueScale);
+                output = m_net->LookupTable(e, input, L"LookupTable");
+
+                if (m_addDropoutNodes)
+                    input = m_net->Dropout(output);
+                else
+                    input = output;
+            }
+            else
+            {
+                LogicError("BuildAlignmentDecoderNetworkFromDescription: LSTMNode cannot take sparse input. Need to project the sparse input to a continuous vector using LookupTable. Suggested setup:\n layerSizes=$VOCABSIZE$:100:$HIDDIM$:$VOCABSIZE$ \nfor a 100-dimension projection, with lookupTableOrder=1\n to project to a single window. For a larger context window, set e.g. lookupTableOrder=3 for a width-3 context window.\n");
+            }
+
+            int recur_idx = 0;
+            int offset = m_lookupTableOrder > 0 ? 1 : 0;
+
+            /// the source network side output dimension needs to match the 1st layer dimension in the decoder network
+            std::vector<ComputationNodePtr>& encoderEvaluationNodes = encoderNet->OutputNodes();
+            if (encoderEvaluationNodes.size() != 1)
+                LogicError("BuildAlignmentDecoderNetworkFromDescription: encoder network should have only one output node as the source node for the decoder network.");
+
+            encoderOutput = m_net->PairNetwork(encoderEvaluationNodes[0], L"pairNetwork");
+
+            if (numHiddenLayers > 0)
+            {
+                int i = 1 + offset;
+                u = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"U%d", i), m_layerSizes[i], m_layerSizes[offset] * (offset ? m_lookupTableOrder : 1));
+                m_net->InitLearnableParameters(u, m_uniformInit, randomSeed++, m_initValueScale);
+                w = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"W%d", i), m_layerSizes[i], m_layerSizes[i]);
+                m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);
+
+                delay = m_net->Delay(NULL, m_defaultHiddenActivity, (size_t)m_layerSizes[i], mbSize);
+//                output = (ComputationNodePtr)BuildLSTMNodeComponent(randomSeed, 0, m_layerSizes[offset] * (offset ? m_lookupTableOrder : 1), m_layerSizes[offset + 1], input);
+//                output = (ComputationNodePtr)BuildLSTMComponent(randomSeed, mbSize, 0, m_layerSizes[offset] * (offset ? m_lookupTableOrder : 1), m_layerSizes[offset + 1], input);
+
+                /// alignment node to get weights from source to target.
+                /// this alignment node computes weights from the current hidden state (after the special encoder ending symbol) to all
+                /// states before the special encoder ending symbol. The weights are used to summarize all encoder inputs;
+                /// the weighted sum of encoder inputs is then used as an additional input to the LSTM in the next layer.
+                e = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"MatForSimilarity%d", i), m_layerSizes[i], m_layerSizes[i]);
+                m_net->InitLearnableParameters(e, m_uniformInit, randomSeed++, m_initValueScale);
+
+                alignoutput = (ComputationNodePtr)m_net->Alignment(delay, encoderOutput, e, L"alignment");
+
+                output = ApplyNonlinearFunction(
+                    m_net->Plus(
+                        m_net->Times(u, input), m_net->Times(w, alignoutput)), 0);
+                delay->AttachInputs(output);
+                input = output;
+
+                for (; i < numHiddenLayers; i++)
+                {
+                    output = (ComputationNodePtr)BuildLSTMNodeComponent(randomSeed, i, m_layerSizes[i], m_layerSizes[i + 1], input);
+                    //output = (ComputationNodePtr)BuildLSTMComponent(randomSeed, mbSize, i, m_layerSizes[i], m_layerSizes[i + 1], input);
+
+                    if (m_addDropoutNodes)
+                        input = m_net->Dropout(output);
+                    else
+                        input = output;
+                }
+            }
+
+            /// need a [input_dim x output_dim] matrix,
+            /// e.g., [200 x 10000], where 10000 is the vocabulary size;
+            /// this is a speed-up: the per-word matrix can then be obtained with a simple column slice
+            w = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"OW%d", numHiddenLayers), m_layerSizes[numHiddenLayers], m_layerSizes[numHiddenLayers + 1]);
+            m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);
+
+            /// the label is a dense matrix; each element is a word index
+            label = m_net->CreateInputNode(L"labels", 4, mbSize);
+
+            clsweight = m_net->CreateLearnableParameter(L"WeightForClassPostProb", m_nbrCls, m_layerSizes[numHiddenLayers]);
+            m_net->InitLearnableParameters(clsweight, m_uniformInit, randomSeed++, m_initValueScale);
+            clslogpostprob = m_net->Times(clsweight, input, L"ClassPostProb");
+
+            output = AddTrainAndEvalCriterionNodes(input, label, w, L"TrainNodeClassBasedCrossEntropy", L"EvalNodeClassBasedCrossEntrpy",
+                clslogpostprob);
+
+            output = m_net->Times(m_net->Transpose(w), input, L"outputs");
+
+            m_net->OutputNodes().push_back(output);
+
+            //add a softmax layer (if a probability is needed or KL-regularized adaptation is needed)
+            output = m_net->Softmax(output, L"PosteriorProb");
+        }
+
+        m_net->ResetEvalTimeStamp();
+
+        return *m_net;
+    }
+
     template<class ElemType>
     ComputationNetwork<ElemType>& SimpleNetworkBuilder<ElemType>::BuildLogBilinearNetworkFromDescription(size_t mbSize)
     {
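Two layout decisions in the hunk above are worth spelling out. The output weight is stored as [hidden_dim x vocab_size] (e.g. [200 x 10000]) so that one word's score vector is a single contiguous column slice, and each label column carries four rows: word id, class id, and the class's begin/end word-index range. The sketch below illustrates the payoff for class-based cross entropy, where scoring a token only touches the columns of its class; all names and numbers are illustrative, not CNTK API.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int hidden = 4, vocab = 10;
        // one contiguous column of W per word, so W[word] is a cheap column slice
        std::vector<std::vector<double>> W(vocab, std::vector<double>(hidden, 0.01));
        std::vector<double> h = { 0.1, 0.2, 0.3, 0.4 }; // final hidden activation

        // per-token label layout (four rows, as the reader fills them):
        // row 0 = word id, row 1 = class id, row 2 = class begin, row 3 = class end
        int wrd = 7, clsBegin = 5, clsEnd = 10; // this class spans words [5, 10)

        double num = 0, den = 0;
        for (int v = clsBegin; v < clsEnd; v++) // only the class's columns
        {
            double s = 0;
            for (int k = 0; k < hidden; k++)
                s += W[v][k] * h[k];
            double es = std::exp(s);
            den += es;
            if (v == wrd) num = es;
        }
        printf("P(word | class) = %.4f, touched %d of %d vocab columns\n",
               num / den, clsEnd - clsBegin, vocab);
        return 0;
    }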
diff --git a/MachineLearning/CNTK/SimpleNetworkBuilder.h b/MachineLearning/CNTK/SimpleNetworkBuilder.h
index a37045f4e..70bf4bf19 100644
--- a/MachineLearning/CNTK/SimpleNetworkBuilder.h
+++ b/MachineLearning/CNTK/SimpleNetworkBuilder.h
@@ -36,7 +36,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         NPLM = 32, CLASSLSTM = 64, NCELSTM = 128,
         CLSTM = 256, RCRF = 512,
         UNIDIRECTIONALLSTM=19,
-        BIDIRECTIONALLSTM= 20} RNNTYPE;
+        BIDIRECTIONALLSTM= 20,
+        ALIGNMENTSIMILARITYGENERATOR=21
+    } RNNTYPE;
 
 
     enum class TrainingCriterion : int
 
@@ -179,6 +181,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             if (std::find(strType.begin(), strType.end(), L"JOINTCONDITIONALBILSTMSTREAMS") != strType.end() ||
                 std::find(strType.begin(), strType.end(), L"BIDIRECTIONALLSTMWITHPASTPREDICTION") != strType.end())
                 m_rnnType = BIDIRECTIONALLSTM;
+            if (std::find(strType.begin(), strType.end(), L"ALIGNMENTSIMILARITYGENERATOR") != strType.end())
+                m_rnnType = ALIGNMENTSIMILARITYGENERATOR;
         }
 
     // Init - Builder Initialize for multiple data sets
@@ -235,7 +239,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         }
 
         virtual ComputationNetwork<ElemType>& LoadNetworkFromFile(const wstring& modelFileName, bool forceLoad = true,
-            bool bAllowNoCriterion = false)
+            bool bAllowNoCriterion = false, ComputationNetwork<ElemType>* anotherNetwork=nullptr)
         {
             if (m_net->GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load
             {
@@ -252,7 +256,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             }
             else
             {
-                m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterion);
+                m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterion, anotherNetwork);
             }
         }
 
@@ -260,7 +264,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             return *m_net;
         }
 
-        ComputationNetwork<ElemType>& BuildNetworkFromDescription()
+        ComputationNetwork<ElemType>& BuildNetworkFromDescription(ComputationNetwork<ElemType>* encoderNet)
         {
             size_t mbSize = 1;
 
@@ -288,6 +292,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 return BuildUnidirectionalLSTMNetworksFromDescription(mbSize);
             if (m_rnnType == BIDIRECTIONALLSTM)
                 return BuildBiDirectionalLSTMNetworksFromDescription(mbSize);
+            if (m_rnnType == ALIGNMENTSIMILARITYGENERATOR)
+                return BuildAlignmentDecoderNetworkFromDescription(encoderNet, mbSize);
 
             if (m_net->GetTotalNumberOfNodes() < 1) //not built yet
             {
@@ -421,7 +427,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
         ComputationNetwork<ElemType>& BuildNCELSTMNetworkFromDescription(size_t mbSize = 1);
 
-
+        ComputationNetwork<ElemType>& BuildAlignmentDecoderNetworkFromDescription(ComputationNetwork<ElemType>* encoderNet, size_t mbSize = 1);
+
         ComputationNetwork<ElemType>& BuildNetworkFromDbnFile(const std::wstring& dbnModelFileName)
         {
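The header changes above thread an optional second network through the build and load paths: the decoder is built or loaded with a pointer to the already-built encoder so that its PairNetwork/Alignment nodes can reference the encoder's output node. The sketch below shows the intended call sequence; it is hypothetical driver code written against the signatures in this patch, not code from the patch itself, and it will not compile outside CNTK.

    // Hypothetical driver: build the encoder first, then hand it to the decoder builder.
    template <class ElemType>
    void BuildEncoderDecoderPair(SimpleNetworkBuilder<ElemType>& encoderBuilder,
                                 SimpleNetworkBuilder<ElemType>& decoderBuilder)
    {
        // the encoder has no upstream network, so it receives a null pointer
        ComputationNetwork<ElemType>& encoderNet = encoderBuilder.BuildNetworkFromDescription(nullptr);

        // for m_rnnType == ALIGNMENTSIMILARITYGENERATOR this dispatches to
        // BuildAlignmentDecoderNetworkFromDescription(&encoderNet, mbSize), whose
        // PairNetwork node references the encoder's single output node
        ComputationNetwork<ElemType>& decoderNet = decoderBuilder.BuildNetworkFromDescription(&encoderNet);

        (void)decoderNet; // training or evaluation would proceed with both networks
    }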
diff --git a/MachineLearning/CNTK/TrainingCriterionNodes.h b/MachineLearning/CNTK/TrainingCriterionNodes.h
index a6742a29d..70533f45b 100644
--- a/MachineLearning/CNTK/TrainingCriterionNodes.h
+++ b/MachineLearning/CNTK/TrainingCriterionNodes.h
@@ -1408,7 +1408,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 size_t i = t % nS;
                 if (m_existsSentenceBeginOrNoLabels->ColumnSlice(j, 1).Get00Element() == EXISTS_SENTENCE_BEGIN_OR_NO_LABELS)
                 {
-                    if ((*m_sentenceSeg)(j, i) == NO_LABELS)
+                    if ((*m_sentenceSeg)(i,j) == NO_LABELS)
                     {
                         matrixToBeMasked.ColumnSlice(t,1).SetValue(0);
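On the index-order fix above: a minibatch packs nS parallel sequences column-interleaved, so flattened column t belongs to stream i = t % nS (and, assuming the usual packing, time j = t / nS), and the sentence-segmentation matrix appears to be addressed as (stream, time); columns flagged NO_LABELS are zeroed so they contribute nothing to the criterion. A minimal sketch of that masking with illustrative values only:

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int nS = 2, T = 3;             // 2 streams, 3 time steps -> 6 columns
        const int NO_LABELS = -1;
        // seg(stream, time) flags; stream 1 has no label at time 2
        int seg[nS][T] = { { 0, 0, 0 }, { 0, 0, NO_LABELS } };
        std::vector<double> row(nS * T, 1.0); // stand-in for one row of the masked matrix

        for (int t = 0; t < nS * T; t++)
        {
            int i = t % nS;                   // stream index
            int j = t / nS;                   // time index
            if (seg[i][j] == NO_LABELS)
                row[t] = 0;                   // mask this column out of the loss
        }
        for (double v : row) printf("%.0f ", v);
        printf("\n");                         // prints: 1 1 1 1 1 0
        return 0;
    }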