diff --git a/MachineLearning/CNTKComputationNetworkLib/RecurrentNodes.h b/MachineLearning/CNTKComputationNetworkLib/RecurrentNodes.h index 18394dcd4..1a965a798 100644 --- a/MachineLearning/CNTKComputationNetworkLib/RecurrentNodes.h +++ b/MachineLearning/CNTKComputationNetworkLib/RecurrentNodes.h @@ -588,1091 +588,4 @@ namespace Microsoft { namespace MSR { namespace CNTK { template class FutureValueNode; template class FutureValueNode; - // ----------------------------------------------------------------------- - // LSTMNode (obs, inputGate, forgetGate, outputGate, memoryCellWgt) - // deprecated early implementation of LSTM operating on minibatches directly - // - input(0) : child with dimension [inputdim x T] - // - input(1) : input gate [outputdim x [inputdim + outputdim + 2]] bi, Wxi, Whi, Wci - // - input(2) : forget gate [outputdim x [inputdim + outputdim + 2]] for bf, Wxf, Whf, Wcf - // - input(3) : output gate [outputdim x [inputdim + outputdim + 2]] for bo, Wxo, Who, and Wco - // - input(4) : memory cell weight [outputdim x [inputdim + outputdim + 1]] for bc, Wxc, and Whc - // - output : dimension [outputdim x T] - // ----------------------------------------------------------------------- - - /** - LSTM specific node. This node uses matrix operations to implement LSTM functionality. - It avoids using the general recurrent loop operations in ComputationNetwork. - - Developed by Kaisheng Yao - Used in the following works: - K. Yao, G. Zweig, "Sequence-to-sequence neural net models for grapheme-to-phoneme conversion", in Interspeech 2015 - */ - template - class LSTMNode : public ComputationNodeNonLooping/*ComputationNode*/, public NumInputs<5> - { - typedef ComputationNodeNonLooping Base; UsingComputationNodeMembersBoilerplate; - static const std::wstring TypeName() { return L"LSTM"; } - public: - DeclareConstructorFromConfigWithNumInputs(LSTMNode); - LSTMNode(DEVICEID_TYPE deviceId, const wstring & name) : Base(deviceId, name), - m_State(deviceId), m_PastState(deviceId), - m_PastOutput(deviceId), m_Gi(deviceId), m_Gf(deviceId), m_Go(deviceId), grdToObs(deviceId), grdToInputGate(deviceId), - grdToForgetGate(deviceId), grdToOutputGate(deviceId), grdToCellWgt(deviceId), tanhObs(deviceId), - tanhState(deviceId), m_tempMatrix(deviceId), - mSlicePrevState(deviceId), mSlicePrevOutput(deviceId), - grdBeforeInputGate(deviceId), - grdBeforeForget(deviceId), grdBeforeGo(deviceId), grdToCell(deviceId), - grdBeforeTanhInputGate(deviceId), m_obs_error_from_future_minibatch(deviceId), - m_state_error_from_future_minibatch(deviceId), mLastState(deviceId), mLastOutput(deviceId), - m_inputDim(0), - m_outputDim(0), - m_use_errors_from_future_minibatch(false), - m_DefaultState((ElemType)DEFAULT_HIDDEN_ACTIVATION) - { - } - - virtual void Save(File& fstream) const override - { - Base::Save(fstream); - fstream << m_inputDim << m_outputDim; - fstream << m_DefaultState; - } - - virtual void Load(File& fstream, size_t modelVersion) override - { - Base::Load(fstream, modelVersion); - if (modelVersion == 2) - fstream >> m_inputDim >> m_outputDim; - fstream >> m_DefaultState; - } - - virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override - { - Base::CopyTo(nodeP, newName, flags); - if (flags & CopyNodeFlags::copyNodeValue) - { - auto node = dynamic_pointer_cast>(nodeP); - node->m_inputDim = m_inputDim; - node->m_outputDim = m_outputDim; - - node->m_State = m_State; // hidden state activity - node->m_PastState = m_PastState; // 
state activity in the previous minibatch - node->m_PastOutput = m_PastOutput; // output in the previous minibatch - - node->m_Gi = m_Gi; // input gate activity - node->m_Gf = m_Gf; // forget gate activity - node->m_Go = m_Go; // output gate activity - - node->mSlicePrevOutput = mSlicePrevOutput; - node->mSlicePrevState = mSlicePrevState; - - node->m_use_errors_from_future_minibatch = m_use_errors_from_future_minibatch; - - node->m_DefaultState = m_DefaultState; - } - } - - virtual void BackpropToNonLooping(size_t inputIndex) override - { - if (inputIndex > 4) - InvalidArgument("LSTM operation only takes five inputs."); - - size_t nT = Input(0)->GetNumCols(); - size_t inputDim = Input(0)->GetNumRows(); - size_t outputDim = Input(1)->GetNumRows(); - - if (m_GradientComputed == false) - { - if (GetNumCols() != GradientValues().GetNumCols() || - GetNumRows() != GradientValues().GetNumRows()) - { - RuntimeError("LSTMNode::GradientValue size doesn't match the function value size"); - } - - // reset gradients - grdToObs.Resize(inputDim, nT); grdToObs.SetValue(0); - grdToInputGate.Resize(Input(1)->GetNumRows(), Input(1)->GetNumCols()); grdToInputGate.SetValue(0); - grdToForgetGate.Resize(Input(2)->GetNumRows(), Input(2)->GetNumCols()); grdToForgetGate.SetValue(0); - grdToOutputGate.Resize(Input(3)->GetNumRows(), Input(3)->GetNumCols()); grdToOutputGate.SetValue(0); - grdToCellWgt.Resize(Input(4)->GetNumRows(), Input(4)->GetNumCols()); grdToCellWgt.SetValue(0); - - Matrix slicePrevOutput(m_deviceId), slicePrevState(m_deviceId); - Matrix grdToPrevOutput(m_deviceId), grdToPrevState(m_deviceId); - Matrix stateError(m_deviceId); - slicePrevState.Resize(outputDim, GetNumParallelSequences()); - slicePrevOutput.Resize(outputDim, GetNumParallelSequences()); - slicePrevOutput.SetValue(0); - - stateError.Resize(slicePrevState.GetNumRows(), slicePrevState.GetNumCols()); - - grdToPrevOutput.Resize(slicePrevOutput.GetNumRows(), slicePrevOutput.GetNumCols()); - grdToPrevState.Resize(slicePrevState.GetNumRows(), slicePrevState.GetNumCols()); - grdToPrevOutput.SetValue(0); - grdToPrevState.SetValue(0); - - for (int timeIdxInSeq = nT - GetNumParallelSequences(); timeIdxInSeq >= 0; timeIdxInSeq -= GetNumParallelSequences()) - { - FrameRange frameRange(m_pMBLayout, timeIdxInSeq); - Matrix sliceObs = Input(0)->OutputFor(frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceOutput = OutputFor(frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceState = DataFor(m_State, frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - - Matrix sliceGi = DataFor(m_Gi, frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceGf = DataFor(m_Gf, frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceGo = DataFor(m_Go, frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - - Matrix sliceTanhState = DataFor(tanhState, frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceTanhObs = DataFor(tanhObs, frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - - Matrix error = GradientFor(frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)); - - Matrix grdToObsSlice(this->m_deviceId); - - -#ifdef DEBUG_DECODER - 
fprintf(stderr, "original output error [%ld] norm = %.8e\n", timeIdxInSeq, error.FrobeniusNorm()); -#endif - - PrepareThisErrorsBeforeBackProp(timeIdxInSeq, nT, error, stateError, grdToPrevOutput, grdToPrevState, - m_obs_error_from_future_minibatch, m_state_error_from_future_minibatch, GetNumParallelSequences(), &m_pMBLayout->GetM()); - -#ifdef DEBUG_DECODER - fprintf(stderr, "output error [%ld] norm = %.8e\n", timeIdxInSeq, error.FrobeniusNorm()); - fprintf(stderr, "state error [%ld] norm = %.8e\n", timeIdxInSeq, stateError.FrobeniusNorm()); -#endif - - grdToPrevOutput.Resize(slicePrevOutput.GetNumRows(), slicePrevOutput.GetNumCols()); - grdToPrevState.Resize(slicePrevState.GetNumRows(), slicePrevState.GetNumCols()); - grdToPrevOutput.SetValue(0); - grdToPrevState.SetValue(0); - - PrepareHistory(timeIdxInSeq, mSlicePrevOutput, mSlicePrevState, Output(), m_State, m_PastOutput, m_PastState, GetNumParallelSequences(), m_DefaultState, &m_pMBLayout->GetM()); - - ComputeInputGradientWrtGates( - error, - sliceObs, - grdToObsSlice, - Input(1)->Output(), - grdToInputGate, - Input(2)->Output(), - grdToForgetGate, - Input(3)->Output(), - grdToOutputGate, - Input(4)->Output(), - grdToCellWgt, - mSlicePrevOutput, - mSlicePrevState, - stateError, - sliceState, - sliceTanhState, - sliceTanhObs, - sliceGi, - sliceGf, - sliceGo, - grdToPrevOutput, - grdToPrevState, - m_tempMatrix - ); - DataFor(grdToObs, frameRange/*TODO: delete this:*/.Check(timeIdxInSeq, GetNumParallelSequences(), m_pMBLayout)).SetValue(grdToObsSlice); - - PrepareErrors(timeIdxInSeq, grdToPrevOutput, grdToPrevState, GetNumParallelSequences(), &m_pMBLayout->GetM()); - } -#ifdef DEBUG_DECODER - fprintf(stderr, "after error prop b_c norm = %.8e\n", Input(4)->Output().ColumnSlice(0, 1).FrobeniusNorm()); -#endif - m_obs_error_from_future_minibatch = grdToPrevOutput; - m_state_error_from_future_minibatch = grdToPrevState; - - -#ifdef DEBUG_DECODER - fprintf(stderr, "pass error to encoder error = %.4e state error = %.4e\n", m_obs_error_from_future_minibatch.FrobeniusNorm(), m_state_error_from_future_minibatch.FrobeniusNorm()); -#endif - m_GradientComputed = true; - } - - if (inputIndex == 0) //derivative with regard to the observation - { - if (Input(inputIndex)->GradientValues().HasNoElements()) - Input(inputIndex)->GradientValues().SetValue(grdToObs); - else - Input(inputIndex)->GradientValues() += grdToObs; - } - - if (inputIndex == 1) - { - if (Input(inputIndex)->GradientValues().HasNoElements()) - Input(inputIndex)->GradientValues().SetValue(grdToInputGate); - else - Input(inputIndex)->GradientValues() += grdToInputGate; - } - - if (inputIndex == 2) - { - if (Input(inputIndex)->GradientValues().HasNoElements()) - Input(inputIndex)->GradientValues().SetValue(grdToForgetGate); - else - Input(inputIndex)->GradientValues() += grdToForgetGate; - } - - if (inputIndex == 3) - { - if (Input(inputIndex)->GradientValues().HasNoElements()) - Input(inputIndex)->GradientValues().SetValue(grdToOutputGate); - else - Input(inputIndex)->GradientValues() += grdToOutputGate; - } - - if (inputIndex == 4) - { - if (Input(inputIndex)->GradientValues().HasNoElements()) - Input(inputIndex)->GradientValues().SetValue(grdToCellWgt); - else - Input(inputIndex)->GradientValues() += grdToCellWgt; - } -#ifdef DEBUG_DECODER - fprintf(stderr, "LSTM gradient[%d] norm = %.8e\n", inputIndex, Input(inputIndex)->GradientValues().FrobeniusNorm()); -#endif - - } - - static void WINAPI GradientOfTanh(const Matrix& functionValues, - const Matrix& gradientOut, - Matrix& 
inputGradientValues, - Matrix& extTmp) - { - Matrix mTmp(inputGradientValues.GetDeviceId()); - extTmp.AssignElementProductOf(functionValues, functionValues); // v .* v - mTmp.AssignDifferenceOf(1, extTmp); // 1-v^2 - if (inputGradientValues.GetNumRows() != functionValues.GetNumRows() || - inputGradientValues.GetNumCols() != functionValues.GetNumCols()) - LogicError("LSTMNode::GradientOfTanh : inputGradientValues needs to be pre-allocated!"); - inputGradientValues.AddElementProductOf(gradientOut, mTmp); // d .* (1 - v .* v) - } - - static void WINAPI ComputeInputGradientWrtGates( - const Matrix& outGrd, // the error to h_t from upper layer - const Matrix & obs, - Matrix &grdToObs, - const Matrix& mInputGate, - Matrix &grdToInputGate, - const Matrix &mForgetGate, - Matrix &grdToForgetGate, - const Matrix &mOutputGate, - Matrix& grdToOutputGate, - const Matrix &mCellWgt, - Matrix &grdToCellWgt, - const Matrix& prevOutput, - const Matrix& prevState, - const Matrix& stateError, // the error propagated to cell from t+1 - const Matrix &state, - const Matrix &tanhState, - const Matrix & tanhBeforeApplyingInputGating, - const Matrix &gi, - const Matrix &gf, - const Matrix &go, - Matrix &grdToPrevOutput, - Matrix &grdToPrevState, - Matrix & tmpMat - ) - { - int inputDim = obs.GetNumRows(); - int outputDim = mOutputGate.GetNumRows(); - - assert(grdToPrevOutput.FrobeniusNorm() == 0); - assert(grdToPrevState.FrobeniusNorm() == 0); - assert(state.FrobeniusNorm() > 0); - Matrix Who = mOutputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix Wco = mOutputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix Wxo = mOutputGate.ColumnSlice(1, inputDim); - Matrix grdToWho = grdToOutputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWco = grdToOutputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix grdToWxo = grdToOutputGate.ColumnSlice(1, inputDim); - Matrix grdTobo = grdToOutputGate.ColumnSlice(0, 1); - - Matrix Whf = mForgetGate.ColumnSlice(1 + inputDim, outputDim); - Matrix Wcf = mForgetGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix Wxf = mForgetGate.ColumnSlice(1, inputDim); - Matrix grdToWhf = grdToForgetGate.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWcf = grdToForgetGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix grdToWxf = grdToForgetGate.ColumnSlice(1, inputDim); - Matrix grdTobf = grdToForgetGate.ColumnSlice(0, 1); - - Matrix Wxc = mCellWgt.ColumnSlice(1, inputDim); - Matrix Whc = mCellWgt.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWxc = grdToCellWgt.ColumnSlice(1, inputDim); - Matrix grdToWhc = grdToCellWgt.ColumnSlice(1 + inputDim, outputDim); - Matrix grdTobc = grdToCellWgt.ColumnSlice(0, 1); - - Matrix Whi = mInputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix Wci = mInputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix Wxi = mInputGate.ColumnSlice(1, inputDim); - Matrix grdToWhi = grdToInputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWci = grdToInputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix grdToWxi = grdToInputGate.ColumnSlice(1, inputDim); - Matrix grdTobi = grdToInputGate.ColumnSlice(0, 1); - - // backpropagate error to the output gate - Matrix grdToGo(tmpMat.GetDeviceId()), gradientOfSigmoid(tmpMat.GetDeviceId()); - Matrix grdBeforeGo(tmpMat.GetDeviceId()), grdBeforeInputGate(tmpMat.GetDeviceId()); - Matrix grdToCell(tmpMat.GetDeviceId()); - - tmpMat.AssignElementProductOf(outGrd, tanhState); // error to o_t - gradientOfSigmoid.AssignSigmoidDerivativeOf(go); - grdBeforeGo.AssignElementProductOf(tmpMat, 
gradientOfSigmoid); // error before the sigmoid -#ifdef DEBUG_DECODER - fprintf(stderr, "output gate error = %.4e\n", grdBeforeGo(0, 0)); -#endif - Matrix::MultiplyAndAdd(Who, true, grdBeforeGo, false, grdToPrevOutput); // error to previous output - Matrix::MultiplyAndAdd(Wxo, true, grdBeforeGo, false, grdToObs); // error to observation - tmpMat = grdBeforeGo; - tmpMat.ColumnElementMultiplyWith(Wco); - grdToCell = tmpMat; // error to memory cell - - Matrix::MultiplyAndAdd(grdBeforeGo, false, prevOutput, true, grdToWho); // gradient to Who - Matrix::MultiplyAndAdd(grdBeforeGo, false, obs, true, grdToWxo); // gradient to Wxo - tmpMat.AssignInnerProductOf(grdBeforeGo, state, false); - grdToWco += tmpMat; // to Wco - for (size_t i = 0; i < grdBeforeGo.GetNumCols(); i++) - { - grdTobo += grdBeforeGo.ColumnSlice(i, 1); // gradient to bo - } - - grdToGo.AssignElementProductOf(outGrd, go); // error to tanh - GradientOfTanh(tanhState, grdToGo, grdToCell, tmpMat); // error to memory cell - grdToCell += stateError; // add error to memory cell from t+1 -#ifdef DEBUG_DECODER - fprintf(stderr, "previous state[0] = %.4e norm = %.4e\n", prevState(0, 0), prevState.FrobeniusNorm()); - fprintf(stderr, "state error = %.4e\n", grdToCell(0, 0)); - fprintf(stderr, "state error norm = %.4e\n", grdToCell.FrobeniusNorm()); -#endif - // backpropagate error to the memory cells - grdToPrevState.AssignElementProductOf(gf, grdToCell); // error to previous memory cell - // be careful; need to double-check whether errors are missing - - Matrix grdBeforeForget(tmpMat.GetDeviceId()); - tmpMat.AssignElementProductOf(prevState, grdToCell); // error to f_t - gradientOfSigmoid.AssignSigmoidDerivativeOf(gf); - grdBeforeForget.AssignElementProductOf(gradientOfSigmoid, tmpMat); // error before forget gate -#ifdef DEBUG_DECODER - fprintf(stderr, "forget gate error = %.4e\n", grdBeforeForget(0, 0)); -#endif - - Matrix::MultiplyAndAdd(Whf, true, grdBeforeForget, false, grdToPrevOutput); // error to previous output - tmpMat = grdBeforeForget; - tmpMat.ColumnElementMultiplyWith(Wcf); - grdToPrevState += tmpMat; // error to previous state - - Matrix::MultiplyAndAdd(Wxf, true, grdBeforeForget, false, grdToObs); // error to observation - - Matrix::MultiplyAndAdd(grdBeforeForget, false, prevOutput, true, grdToWhf); // gradient to Whf - tmpMat.AssignInnerProductOf(grdBeforeForget, prevState, false); - grdToWcf += tmpMat; // gradient to Wcf - - Matrix::MultiplyAndAdd(grdBeforeForget, false, obs, true, grdToWxf); // gradient to Wxf - for (size_t i = 0; i < grdBeforeForget.GetNumCols(); i++) - grdTobf += grdBeforeForget.ColumnSlice(i, 1); // gradient to bf - - // backpropagate error to the input gate - tmpMat.AssignElementProductOf(tanhBeforeApplyingInputGating, grdToCell); - gradientOfSigmoid.AssignSigmoidDerivativeOf(gi); - grdBeforeInputGate.AssignElementProductOf(gradientOfSigmoid, tmpMat); // error before input gate -#ifdef DEBUG_DECODER - fprintf(stderr, "input gate error = %.4e\n", grdBeforeInputGate(0, 0)); -#endif - - Matrix::MultiplyAndAdd(Whi, true, grdBeforeInputGate, false, grdToPrevOutput); // error to previous output - tmpMat = grdBeforeInputGate; - tmpMat.ColumnElementMultiplyWith(Wci); - grdToPrevState += tmpMat; // error to previous state - -#ifdef DEBUG_DECODER - fprintf(stderr, "to previous state error = %.4e\n", grdToPrevState(0, 0)); - fprintf(stderr, "to previous state error norm = %.4e\n", grdToPrevState.FrobeniusNorm()); -#endif - Matrix::MultiplyAndAdd(Wxi, true, grdBeforeInputGate, false, grdToObs); // error to observation - - 
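// [Editorial note, not part of the patch] A hedged summary of the per-gate backward pattern the surrounding code repeats for each gate g in {i, f, o}: with pre-activation z_g = W_xg * x + W_hg * h_prev + w_cg .* c + b_g and g = sigmoid(z_g) (c is c_t for the output gate, c_{t-1} for the input and forget gates), the code first forms the pre-activation error
//   dz_g = dE/dg .* g .* (1 - g)                       // via AssignSigmoidDerivativeOf
// and then accumulates
//   dW_xg += dz_g * x^T;   dW_hg += dz_g * h_prev^T;   db_g += sum of dz_g over columns;
//   dw_cg += row-wise inner product of dz_g and c;     dx += W_xg^T * dz_g;   dh_prev += W_hg^T * dz_g.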
Matrix::MultiplyAndAdd(grdBeforeInputGate, false, prevOutput, true, grdToWhi); // gradient to Whi - tmpMat.AssignInnerProductOf(grdBeforeInputGate, prevState, false); - grdToWci += tmpMat; // gradient to Wci - Matrix::MultiplyAndAdd(grdBeforeInputGate, false, obs, true, grdToWxi); // gradient to Wxi - for (size_t i = 0; i < grdBeforeInputGate.GetNumCols(); i++) - grdTobi += grdBeforeInputGate.ColumnSlice(i, 1); // gradient to bi - - // backpropagate error to the inputs - Matrix grdTmp2(tmpMat.GetDeviceId()); - Matrix grdBeforeTanhInputGate(tmpMat.GetDeviceId()); - grdTmp2.AssignElementProductOf(gi, grdToCell); - grdBeforeTanhInputGate.Resize(tanhBeforeApplyingInputGating.GetNumRows(), tanhBeforeApplyingInputGating.GetNumCols()); - GradientOfTanh(tanhBeforeApplyingInputGating, grdTmp2, grdBeforeTanhInputGate, tmpMat); // error to memory cell - Matrix::MultiplyAndAdd(Wxc, true, grdBeforeTanhInputGate, false, grdToObs); // error to observation -#ifdef DEBUG_DECODER - fprintf(stderr, "to observation error = %.4e\n", grdToObs(0, 0)); -#endif - - Matrix::MultiplyAndAdd(Whc, true, grdBeforeTanhInputGate, false, grdToPrevOutput); // error to previous output - Matrix::MultiplyAndAdd(grdBeforeTanhInputGate, false, obs, true, grdToWxc); // gradient to Wxc - - Matrix::MultiplyAndAdd(grdBeforeTanhInputGate, false, prevOutput, true, grdToWhc); // gradient to Whc - for (size_t i = 0; i < grdBeforeTanhInputGate.GetNumCols(); i++) - grdTobc += grdBeforeTanhInputGate.ColumnSlice(i, 1); // gradient to bc - - } - - /** - get the segmentation information: SENTENCE_BEGIN, ((int) MinibatchPackingFlags::None), ((int) MinibatchPackingFlags::NoInput) - for time t and stream streamid - */ - int GetSegInfo(size_t t, size_t streamid) - { - if (streamid >= GetNumParallelSequences()) - LogicError("GetSegInfo: stream id %d is larger than the number of streams %d", (int)streamid, (int)GetNumParallelSequences()); - - size_t nT = Input(0)->GetNumCols(); - if (t >= nT) - LogicError("GetSegInfo: time %d is larger than the total number of observations %d", (int)t, (int)nT); - - int utt_t = (int)t / GetNumParallelSequences(); - auto thisCol = m_pMBLayout->GetFrame(utt_t).first; - thisCol.Reshape(1, GetNumParallelSequences()); - return (int) thisCol.ColumnSlice(streamid, 1).Get00Element(); - } - - /** - save the last hidden layer activity and output - */ - void SaveLastStateActivity() - { - size_t nT = Input(0)->GetNumCols(); - size_t outputDim = Input(1)->GetNumRows(); - - // save the hidden activities and output for the next minibatch - mLastOutput.Resize(outputDim, GetNumParallelSequences()); - mLastState.Resize(outputDim, GetNumParallelSequences()); - - for (size_t i = 0; i < GetNumParallelSequences(); i++) - { - for (int t = nT - GetNumParallelSequences() + i; t >= 0; t -= GetNumParallelSequences()) - { - if (GetSegInfo(t, i) == ((int) MinibatchPackingFlags::None)) - { - mLastOutput.ColumnSlice(i, 1).SetValue(Output().ColumnSlice(t, 1)); - mLastState.ColumnSlice(i, 1).SetValue(m_State.ColumnSlice(t, 1)); - break; - } - } - } - } - - virtual void /*ComputationNodeNonLooping::*/ForwardPropNonLooping() override - { - size_t nT = Input(0)->GetNumCols(); - size_t outputDim = Input(1)->GetNumRows(); - - { - SetDims(outputDim, nT); - Output().SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. - m_State.Resize(outputDim, nT); - m_State.SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. 
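// [Editorial note, not part of the patch] For reference, the per-frame loop below evaluates, via ForwardPropS, the peephole LSTM recurrence
//   i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + w_ci .* c_{t-1} + b_i)
//   f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + w_cf .* c_{t-1} + b_f)
//   c_t = f_t .* c_{t-1} + i_t .* tanh(W_xc x_t + W_hc h_{t-1} + b_c)
//   o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + w_co .* c_t + b_o)
//   h_t = o_t .* tanh(c_t)
// where h_{t-1} and c_{t-1} are provided by PrepareHistory, honoring sequence boundaries.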
- m_Gi.Resize(outputDim, nT); - m_Gi.SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. - m_Gf.Resize(outputDim, nT); - m_Gf.SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. - m_Go.Resize(outputDim, nT); - m_Go.SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. - tanhState.Resize(outputDim, nT); - tanhState.SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. - tanhObs.Resize(outputDim, nT); - tanhObs.SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. - - if (m_PastState.IsEmpty() || m_PastState.GetNumCols() != GetNumParallelSequences()) - { - m_PastState.Resize(outputDim, GetNumParallelSequences()); - m_PastState.SetValue(m_DefaultState); - } - if (m_PastOutput.IsEmpty() || m_PastOutput.GetNumCols() != GetNumParallelSequences()) - { - m_PastOutput.Resize(outputDim, GetNumParallelSequences()); - } - -#ifdef DEBUG_DECODER - if (m_PastOutput.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls past output norm = %.8e\n", this->NodeName().c_str(), m_PastOutput.FrobeniusNorm()); - if (m_PastState.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls past state norm = %.8e\n", this->NodeName().c_str(), m_PastState.FrobeniusNorm()); -#endif - - for (size_t timeIdxInSeq = 0; timeIdxInSeq < nT; timeIdxInSeq += GetNumParallelSequences()) - { - FrameRange frameRange(m_pMBLayout, timeIdxInSeq); - Matrix sliceObs = Input(0)->OutputFor(frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceOutput = OutputFor(frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceState = DataFor(m_State, frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - - Matrix sliceGi = DataFor(m_Gi, frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceGf = DataFor(m_Gf, frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceGo = DataFor(m_Go, frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - - Matrix sliceTanhState = DataFor(tanhState, frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - Matrix sliceTanhInput = DataFor(tanhObs, frameRange/*TODO: delete this:*/.Check(frameRange.t(), GetNumParallelSequences(), m_pMBLayout)); - - PrepareHistory(timeIdxInSeq, mSlicePrevOutput, mSlicePrevState, Output(), m_State, m_PastOutput, m_PastState, GetNumParallelSequences(), m_DefaultState, &m_pMBLayout->GetM()); - - ForwardPropS(Input(1)->Output(), Input(2)->Output(), Input(3)->Output(), Input(4)->Output(), - sliceObs, mSlicePrevOutput, mSlicePrevState, sliceOutput, sliceState, sliceGi, sliceGf, sliceGo, sliceTanhState, sliceTanhInput, m_tempMatrix); - } - - // save the hidden activities and output for the next minibatch - SaveLastStateActivity(); - -#ifdef DEBUG_DECODER - if (mLastOutput.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls last output norm = %.8e\n", this->NodeName().c_str(), mLastOutput.FrobeniusNorm()); - if (mLastState.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls last state norm = %.8e\n", this->NodeName().c_str(), 
mLastState.FrobeniusNorm()); -#endif - -#ifdef DEBUG_DECODER - ElemType tmpnorm = Output().FrobeniusNorm(); - if (ISCLOSE(tmpnorm, 0.834251, 0.002)) - fprintf(stderr, "check!"); - fprintf(stderr, "LSTM function norm = %.8e\n", tmpnorm); - for (size_t i = 0; i < 5; i++) - fprintf(stderr, "LSTM input[%d] norm = %.8e ", i, Input(i)->Output().FrobeniusNorm()); - fprintf(stderr, "\n"); -#endif - - m_GradientComputed = false; - } - } - - /** - Prepare history for LSTMNode - - This function returns the state and output from the previous time step. For a recurrent network, the initial state needs to be set at the beginning of a sentence, which is carried over from sentenceBegin. At a sentence beginning, the state activity is set to an initial value. sentenceBegin has elements ((int) MinibatchPackingFlags::SequenceStart), ((int) MinibatchPackingFlags::None), and ((int) MinibatchPackingFlags::NoInput), which are 0, 1, and -1, respectively. - To compute the initial value, we use - prevState = sentenceBegin * delayedActivation + ~sentenceBegin * initialStateValue - and ~sentenceBegin is computed as -1*(sentenceBegin - 1), assuming that sentenceBegin is either 0 or 1. For example, when sentenceBegin == 1, ~sentenceBegin == 0. - The previous-time output doesn't have an initial value, so it is computed as - prevOutput = sentenceBegin * pastOutput - - */ - // prepare prevstate and prevoutput - static void WINAPI PrepareHistory( - size_t timeIdxInSeq, - Matrix & slicePrevOutput, - Matrix & slicePrevState, - const Matrix & output, - const Matrix & state, - const Matrix & pastOutput, - const Matrix & pastState, - size_t nsamples, const ElemType & initStateValue, const Matrix* sentenceBegin) - { - size_t nRow = pastOutput.GetNumRows(); - size_t nStream = sentenceBegin->GetNumRows(); - - assert(nStream == nsamples); - - int utt_t = (int)floor(timeIdxInSeq / nsamples); - if (slicePrevOutput.IsEmpty() || slicePrevOutput.GetNumRows() != nRow || slicePrevOutput.GetNumCols() != nsamples) - slicePrevOutput.Resize(nRow, nsamples); - if (slicePrevState.IsEmpty() || slicePrevState.GetNumRows() != nRow || slicePrevState.GetNumCols() != nsamples) - slicePrevState.Resize(nRow, nsamples); - - if (sentenceBegin->GetNumRows() != nsamples) - LogicError("Number of rows should be the same as the number of data streams"); - - Matrix colBegin(sentenceBegin->GetDeviceId()); - colBegin.SetValue(sentenceBegin->ColumnSlice(utt_t, 1)); - Matrix colSeg(colBegin.GetDeviceId()); - colSeg.Resize(nStream, nStream); - // will reset to 0 if the sentence beginning at a position is 0 - // will keep the output if it is not the sentence beginning - colBegin.InplaceTruncateBottom(((int) MinibatchPackingFlags::SequenceStart)); - colBegin.InplaceTruncateTop(((int) MinibatchPackingFlags::None)); -#if 1 - initStateValue; pastState; pastOutput; state; output; - LogicError("PrepareHistory: finish this"); -#else - // BUGBUG: we need to upcast float to double here - colSeg.SetDiagonalValue(colBegin); - - Matrix newPrevOutput(colBegin.GetDeviceId()); - Matrix newPrevState(colBegin.GetDeviceId()); - if (utt_t == 0) - { - // this is the beginning of this minibatch - Matrix::Multiply(pastOutput.ColumnSlice(0, nsamples), false, colSeg, false, newPrevOutput); - Matrix::Multiply(pastState.ColumnSlice(0, nsamples), false, colSeg, false, newPrevState); - } - else - { - // this is in the minibatch - FrameRange frameRange(timeIdxInSeq, nsamples); - Matrix::Multiply(DataFor(output, frameRange/*TODO: delete the next two parameters*/, frameRange.t() - nsamples, 
nsamples), false, colSeg, false, newPrevOutput); - Matrix::Multiply(DataFor(state, frameRange/*TODO: delete the next two parameters*/, frameRange.t() - nsamples, nsamples), false, colSeg, false, newPrevState); - } - - Base::SetToInitStateValueForResetSeg(sentenceBegin->ColumnSlice(utt_t, 1), nStream, initStateValue, newPrevState); - - slicePrevOutput.ColumnSlice(0, nsamples).SetValue(newPrevOutput); - slicePrevState.ColumnSlice(0, nsamples).SetValue(newPrevState); -#endif - } - - // prepare the errors at this time step for back propagation - void PrepareThisErrorsBeforeBackProp( - size_t timeIdxInSeq, - size_t nT, // number of columns - Matrix & error, - Matrix & stateError, - const Matrix& grdToPrevOutput, - const Matrix& grdToPrevState, - const Matrix& obs_error_from_future_minibatch, - const Matrix& state_error_from_future_minibatch, - size_t nsamples, const Matrix* sentenceBegin) - { - int utt_t = (int)floor(timeIdxInSeq / nsamples); - int total_utt_t = (int)floor(nT / nsamples); - - error += grdToPrevOutput; - stateError = grdToPrevState; - - if (m_use_errors_from_future_minibatch) - { - for (size_t utt_id = 0; utt_id < nsamples; utt_id++) - { - // if using errors from the future minibatch - if ((GetSegInfo(timeIdxInSeq, utt_id) == ((int) MinibatchPackingFlags::None) && utt_t == total_utt_t - 1) // last time - || (utt_t < total_utt_t - 1 && GetSegInfo(timeIdxInSeq, utt_id) == ((int) MinibatchPackingFlags::None) && GetSegInfo(timeIdxInSeq + nsamples, utt_id) == ((int) MinibatchPackingFlags::NoInput)) // future observation is no observation - ) - { - error.ColumnSlice(utt_id, 1) += obs_error_from_future_minibatch.ColumnSlice(utt_id, 1); - stateError.ColumnSlice(utt_id, 1) += state_error_from_future_minibatch.ColumnSlice(utt_id, 1); - } - } - } - - -#if 1 - sentenceBegin; - LogicError("PrepareThisErrorsBeforeBackProp: finish this"); -#else - Matrix colBegin(sentenceBegin->GetDeviceId()); - colBegin.SetValue(sentenceBegin->ColumnSlice(utt_t, 1)); - colBegin.InplaceTruncateBottom(((int) MinibatchPackingFlags::NoInput)); - colBegin.InplaceTruncateTop(((int) MinibatchPackingFlags::SequenceStart)); - colBegin += fabs((ElemType)((int) MinibatchPackingFlags::NoInput)); // raise this so that -1 -> 0 - Matrix colSeg(colBegin.GetDeviceId()); - colSeg.Resize(nsamples, nsamples); - colSeg.SetDiagonalValue(colBegin); - - // multiply the errors with the mask - Matrix newOutputError(colBegin.GetDeviceId()); - Matrix newStateError(colBegin.GetDeviceId()); - - Matrix::Multiply(error, false, colSeg, false, newOutputError); - Matrix::Multiply(stateError, false, colSeg, false, newStateError); - - error.ColumnSlice(0, nsamples).SetValue(newOutputError); - stateError.ColumnSlice(0, nsamples).SetValue(newStateError); -#endif - } - - // mask the errors at sentence boundaries - static void WINAPI PrepareErrors( - size_t timeIdxInSeq, - Matrix & errors, - Matrix & stateError, - size_t nsamples, const Matrix* sentenceBegin) - { - int utt_t = (int)floor(timeIdxInSeq / nsamples); - Matrix colBegin(sentenceBegin->GetDeviceId()); -#if 1 - errors; stateError; utt_t; - LogicError("PrepareErrors: finish this"); -#else - colBegin.SetValue(sentenceBegin->ColumnSlice(utt_t, 1)); - // will reset to 0 if the sentence beginning at a position is 0 - // will keep the output if it is not the sentence beginning - colBegin.InplaceTruncateBottom(((int) MinibatchPackingFlags::SequenceStart)); - colBegin.InplaceTruncateTop(((int) MinibatchPackingFlags::None)); - - Matrix colSeg(colBegin.GetDeviceId()); - colSeg.Resize(nsamples, nsamples); - 
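// [Editorial note, not part of the patch] colSeg becomes an nsamples x nsamples diagonal 0/1 mask: right-multiplying (errors * colSeg) scales column j of the errors by the j-th flag, so the columns of streams sitting at a sequence boundary (flag 0) are zeroed while all other columns pass through unchanged. For example, with nsamples = 2 and stream 0 at a boundary, colSeg = diag(0, 1) and column 0 of the error matrix is reset to 0.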
colSeg.SetDiagonalValue(colBegin); - - // times the errors with the mask - Matrix newOutputError(colBegin.GetDeviceId()); - Matrix newStateError(colBegin.GetDeviceId()); - - Matrix::Multiply(errors, false, colSeg, false, newOutputError); - Matrix::Multiply(stateError, false, colSeg, false, newStateError); - - errors.ColumnSlice(0, nsamples).SetValue(newOutputError); - stateError.ColumnSlice(0, nsamples).SetValue(newStateError); -#endif - } - - /*TODO: merge with call site*/void ForwardPropS( - const Matrix& mInputGate, - const Matrix &mForgetGate, const Matrix &mOutputGate, - const Matrix &mCellWgt, - const Matrix &obs, - const Matrix& prevOutput, - const Matrix& prevState, - Matrix &output, - Matrix &state, - Matrix &gi, - Matrix &gf, - Matrix &go, - Matrix &tanhState, - Matrix &tanhObs, - Matrix &tmp) - { - int inputDim = obs.GetNumRows(); - int outputDim = mOutputGate.GetNumRows(); - - // for input gate - Matrix::Multiply(mInputGate.ColumnSlice(1, inputDim), false, obs, false, gi); - Matrix::MultiplyAndAdd(mInputGate.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, gi); - gi += mInputGate.ColumnSlice(0, 1); - tmp = prevState; - tmp.ColumnElementMultiplyWith(mInputGate.ColumnSlice(1 + inputDim + outputDim, 1)); - gi += tmp; - gi.AssignSigmoidOf(gi); - - // for forget gate - Matrix::Multiply(mForgetGate.ColumnSlice(1, inputDim), false, obs, false, gf); - Matrix::MultiplyAndAdd(mForgetGate.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, gf); - gf += mForgetGate.ColumnSlice(0, 1); - tmp = prevState; - tmp.ColumnElementMultiplyWith(mForgetGate.ColumnSlice(1 + inputDim + outputDim, 1)); - gf += tmp; - gf.AssignSigmoidOf(gf); - - // for cell state - Matrix::Multiply(mCellWgt.ColumnSlice(1, inputDim), false, obs, false, state); - Matrix::MultiplyAndAdd(mCellWgt.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, state); - state += mCellWgt.ColumnSlice(0, 1); -#ifdef DEBUG_DECODER -// fprintf(stderr, "W_xc norm = %.8e\n", mCellWgt.ColumnSlice(1, inputDim).FrobeniusNorm()); -// fprintf(stderr, "W_hc norm = %.8e\n", mCellWgt.ColumnSlice(1 + inputDim, outputDim).FrobeniusNorm()); -// fprintf(stderr, "b_c norm = %.8e\n", mCellWgt.ColumnSlice(0, 1).FrobeniusNorm()); -#endif - tanhObs.AssignTanhOf(state); - state.AssignElementProductOf(gi, tanhObs); - state.AddElementProductOf(gf, prevState); - - // for output gate - Matrix::Multiply(mOutputGate.ColumnSlice(1, inputDim), false, obs, false, go); - Matrix::MultiplyAndAdd(mOutputGate.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, go); - go += mOutputGate.ColumnSlice(0, 1); - tmp = state; - tmp.ColumnElementMultiplyWith(mOutputGate.ColumnSlice(1 + inputDim + outputDim, 1)); - go += tmp; - go.AssignSigmoidOf(go); - - // to return output - tanhState.AssignTanhOf(state); - output.AssignElementProductOf(go, tanhState); - } - - - // input(0) : child with dimension [inputdim x T] - // input(1) : input gate [outputdim x [inputdim + outputdim + 2]] bi, Wxi, Whi, Wci - // input(2) : forget gate [outputdim x [inputdim + outputdim + 2]] for bf, Wxf, Whf, Wcf - // input(3) : output gate [outputdim x [inputdim + outputdim + 2]] for bo, Wxo, Who, and Wco - // input(4) : memory cell weight [outputdim x [inputdim + outputdim + 1]] for bc, Wxc, and Whc - // output : dimension [outputdim x T] - virtual void /*ComputationNodeBase::*/Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - - InferMBLayoutFromInputsForStandardCase(); - InferImageDimsFromInputs(); - - if 
(Input(0)->Output().GetMatrixType() == SPARSE) - LogicError("LSTMNode: input to LSTM has to be a dense matrix. Consider adding a projection layer using a lookup table before the LSTM node. "); - -#if 0 - // TODO: use dynamic_pointer_cast instead - if (Input(1)->OperationName() != OperationNameOf(LearnableParameter) || - Input(2)->OperationName() != OperationNameOf(LearnableParameter) || - Input(3)->OperationName() != OperationNameOf(LearnableParameter) || - Input(4)->OperationName() != OperationNameOf(LearnableParameter)) - LogicError("LSTM validation: need to have learnable parameters "); -#endif - - //if (Input(0)->GetNumRows() == 0) - // LogicError("LSTM validation: input size is zero!"); - - //if (Input(1)->GetNumRows() == 0 || - // Input(2)->GetNumRows() == 0 || - // Input(3)->GetNumRows() == 0 || - // Input(4)->GetNumRows() == 0) - // LogicError("LSTM validation : parameter size is zero!"); - - size_t nindim = Input(0)->GetNumRows(); - size_t noutdim = Input(1)->GetNumRows(); - size_t nT = Input(0)->GetNumCols(); - size_t nCol = nindim + noutdim + 2; - if (isFinalValidationPass) - { - if (Input(1)->GetNumCols() != nCol) - { - LogicError("LSTM validation : dimension mismatch between child and inputGate"); - } - if (Input(2)->GetNumCols() != nCol) - { - LogicError("LSTM validation : dimension mismatch between child and forgetGate"); - } - if (Input(3)->GetNumCols() != nCol) - { - LogicError("LSTM validation : dimension mismatch between child and outputGate"); - } - - if (noutdim != Input(2)->GetNumRows() || - noutdim != Input(3)->GetNumRows() || - noutdim != Input(4)->GetNumRows()) - { - LogicError("LSTM validation: output dimension mismatch!"); - } - } - - SetDims(noutdim, nT); - Output().SetValue(NAN); // set to this extreme value so that, if anything goes wrong in a later step, problems can be easily spotted. - } - - bool UnitTest() - { - { - size_t nT = 3; - size_t nInput = 2; - size_t nHidden = 3; - size_t nOutput = 3; - - // backup - Matrix f0(m_deviceId), f1(m_deviceId), f2(m_deviceId), f3(m_deviceId), f4(m_deviceId), func(m_deviceId), f5(m_deviceId); - Matrix target(m_deviceId); - Matrix giWeight, ghWeight, goWeight; - ElemType initStateValue = m_DefaultState; - auto pMBLayout = make_shared(); - pMBLayout->Init(1, nT); - //Matrix & boundary = pMBLayout->m_sentenceBoundaryFlags; - //vector & minibatchPackingFlags = pMBLayout->m_minibatchPackingFlags; - //boundary.ColumnSlice(0, 1).SetValue(((int) MinibatchPackingFlags::SequenceStart)); - //minibatchPackingFlags[1] = MinibatchPackingFlags::SequenceStart; - pMBLayout->Set(0, 1, MinibatchPackingFlags::SequenceStart); // TODO: strange--start at frame[1] instead of [0]? 
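// [Editorial note, not part of the patch] The gate parameters set up below use the packed column layout that Validate() checks above (nCol = inputDim + outputDim + 2), matching the ColumnSlice calls in ComputeInputGradientWrtGates and ForwardPropS:
//   ColumnSlice(0, 1)                        : bias b
//   ColumnSlice(1, inputDim)                 : W_x (input weights, applied to x_t)
//   ColumnSlice(1 + inputDim, outputDim)     : W_h (recurrent weights, applied to h_{t-1})
//   ColumnSlice(1 + inputDim + outputDim, 1) : peephole vector w_c (absent from the memory cell weight, which has nCol - 1 columns)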
- Base::LinkToMBLayout(pMBLayout); - - f0 = Input(0)->Output(); - f1 = Input(1)->Output(); - f2 = Input(2)->Output(); - f3 = Input(3)->Output(); - f4 = Input(4)->Output(); - func = Output(); - - target.Resize(nOutput, nT); - for (size_t i = 0; i < nT; i++) - target(0, i) = 1; - - Input(0)->SetDims(nInput, nT); - Input(0)->Output().SetValue(ConstOnes(nInput, nT, m_deviceId)); - Input(0)->Output().SetValue((ElemType)0.1); - Input(1)->SetDims(nHidden, nInput + nOutput + 2); - Input(1)->Output().SetValue((ElemType)0.1); - Input(2)->SetDims(nHidden, nInput + nHidden + 2); - Input(2)->Output().SetValue((ElemType)0.1); - Input(3)->SetDims(nOutput, nInput + nHidden + 2); - Input(3)->Output().SetValue((ElemType)0.1); - Input(4)->SetDims(nOutput, nHidden + nInput + 1); - Input(4)->Output().SetValue((ElemType)0.1); - SetDims(nOutput, nT); - - m_DefaultState = 0.0; - ForwardProp(FrameRange(m_pMBLayout)); - - // check with expected values - if (!ISCLOSE(Output()(0, 0), 0.0335975, EPSILON) || - !ISCLOSE(Output()(0, 1), 0.05485132, EPSILON) || - !ISCLOSE(Output()(0, 2), 0.06838435, EPSILON) || - !(Output()(0, 0) == Output()(1, 0))) - throw("LSTMNode forward computation error"); - - - Output().TransferToDeviceIfNotThere( m_deviceId, true); - - GradientValues().Resize(nOutput, nT); - GradientValues().SetValue(1.0); - for (size_t i = 0; i < 5; i++) - { - Input(i)->GradientValues().Resize(Input(i)->GetNumRows(), Input(i)->GetNumCols()); - Input(i)->GradientValues().SetValue(0); - } - for (size_t i = 0; i < 5; i++) - BackpropTo(i, FrameRange(m_pMBLayout)); - - // check with expected values - if (!ISCLOSE(Input(1)->GradientValues()(0, 0), 0.07843818, EPSILON) // bi - || !ISCLOSE(Input(1)->GradientValues()(0, 1), 0.00784382, EPSILON) // Wxi - || !ISCLOSE(Input(1)->GradientValues()(0, 3), 0.00192997, EPSILON) // Whi - || !ISCLOSE(Input(1)->GradientValues()(0, 6), 0.00362767, EPSILON) // Wci - ) - throw("LSTMNode gradient error on input gates"); - if (!ISCLOSE(Input(2)->GradientValues()(0, 0), 0.02738655, EPSILON) // bf - || !ISCLOSE(Input(2)->GradientValues()(0, 1), 0.00273866, EPSILON) // Wxf - || !ISCLOSE(Input(2)->GradientValues()(0, 3), 0.00120922, EPSILON) // Whf - || !ISCLOSE(Input(2)->GradientValues()(0, 6), 0.00227184, EPSILON) // Wcf - ) - throw("LSTMNode gradient error on forget gates"); - if (!ISCLOSE(Input(3)->GradientValues()(0, 0), 0.07801557, EPSILON) // bo - || !ISCLOSE(Input(3)->GradientValues()(0, 1), 0.00780156, EPSILON) // Wxo - || !ISCLOSE(Input(3)->GradientValues()(0, 3), 0.00268089, EPSILON) // Who - || !ISCLOSE(Input(3)->GradientValues()(0, 6), 0.00809852, EPSILON) // Wco - ) - throw("LSTMNode gradient error on output gates"); - if (!ISCLOSE(Input(4)->GradientValues()(0, 0), 1.3075038, EPSILON) // bc - || !ISCLOSE(Input(4)->GradientValues()(0, 1), 0.13075038, EPSILON) // Wxc - || !ISCLOSE(Input(4)->GradientValues()(0, 3), 0.03080355, EPSILON) // Whc - ) - throw("LSTMNode gradient error on memory cells"); - - for (size_t i = 0; i < 5; i++) - { - - Input(i)->GradientValues().TransferToDeviceIfNotThere( m_deviceId, true); - } - m_DefaultState = initStateValue; - } - - fprintf(stderr, "LSTMNode unit test passed!\n"); - return true; - } - - virtual void InferImageDimsFromInputs() - { - InferImageDimsFromInput(1, false); - } - - virtual void DumpNodeInfo(const bool printValues, File& fstream) const override - { - Base::DumpNodeInfo(printValues, fstream); - fstream << L"Input[Width:" << m_inputDim << L"] \n" ; - fstream << L"Hidden[Width:" << m_outputDim << L"] Output[Width:" << m_outputDim << 
L"] \n"; - } - public: - bool GetHistory(Matrix& hist, bool bLastTime) - { - size_t tRow = m_PastOutput.GetNumRows(); - size_t tCol = m_PastOutput.GetNumCols(); - size_t rCol = m_PastState.GetNumCols(); - - DEVICEID_TYPE device = hist.GetDeviceId(); - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - hist.Resize(tRow, tCol + rCol); - - if (bLastTime) - { - hist.ColumnSlice(0, tCol).SetValue(mLastOutput); - hist.ColumnSlice(tCol, rCol).SetValue(mLastState); - } - else{ - hist.ColumnSlice(0, tCol).SetValue(m_PastOutput); - hist.ColumnSlice(tCol, rCol).SetValue(m_PastState); - } - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - return true; - } - - void SetHistory(const Matrix& hist) - { - size_t tRow = hist.GetNumRows(); - size_t tCol = hist.GetNumCols(); - size_t eCols = tCol / 2; - - DEVICEID_TYPE device = hist.GetDeviceId(); - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - - m_PastOutput.Resize(tRow, eCols); - m_PastState.Resize(tRow, eCols); - m_PastOutput.SetValue(hist.ColumnSlice(0, eCols)); - m_PastState.SetValue(hist.ColumnSlice(eCols, eCols)); - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - } - - virtual void GetErrorsToPreviousMinibatch(Matrix& hist) - { - size_t tRow = m_obs_error_from_future_minibatch.GetNumRows(); - size_t tCol = m_obs_error_from_future_minibatch.GetNumCols(); - size_t rCol = m_state_error_from_future_minibatch.GetNumCols(); - - DEVICEID_TYPE device = hist.GetDeviceId(); - - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - hist.Resize(tRow, tCol + rCol); - - hist.ColumnSlice(0, tCol).SetValue(m_obs_error_from_future_minibatch); - hist.ColumnSlice(tCol, rCol).SetValue(m_state_error_from_future_minibatch); - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - } - - virtual void SetErrorsFromFutureMinibatch(Matrix& hist) - { - size_t tCol = hist.GetNumCols(); - size_t rCol = tCol / 2; - - DEVICEID_TYPE device = hist.GetDeviceId(); - - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - - m_obs_error_from_future_minibatch.SetValue(hist.ColumnSlice(0, rCol)); - m_state_error_from_future_minibatch.SetValue(hist.ColumnSlice(rCol, rCol)); - - m_use_errors_from_future_minibatch = true; - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - } - - protected: - size_t m_inputDim; - size_t m_outputDim; - - Matrix m_State; // hidden state activity - Matrix m_PastState; // state activity in the previous minibatch - Matrix m_PastOutput; // output in the previou minibatch - - Matrix mLastState; // last state activity - Matrix mLastOutput; // last output - - Matrix m_Gi; // input gate activity - Matrix m_Gf; // forget gate activity - Matrix m_Go; // output gate activity - - Matrix grdToObs, grdToInputGate, grdToForgetGate, grdToOutputGate, grdToCellWgt; - Matrix tanhState, tanhObs; - - Matrix m_tempMatrix; // temp matrix for speed-up - - bool m_GradientComputed; // true if LSTM node has computed gradients, set to false if forward computation is just finished - - Matrix mSlicePrevOutput, mSlicePrevState; - - Matrix grdBeforeInputGate, grdBeforeForget, grdBeforeGo, grdToCell, grdBeforeTanhInputGate; - - public: - // errors from future minibatch - Matrix m_obs_error_from_future_minibatch; - Matrix m_state_error_from_future_minibatch; - bool m_use_errors_from_future_minibatch; - - ElemType m_DefaultState; - - }; - - template class LSTMNode; - template class LSTMNode; - }}} diff --git a/Makefile b/Makefile index f6deaad13..e4023d77a 100644 --- a/Makefile +++ b/Makefile @@ -432,7 +432,6 
@@ CNTK_SRC =\ MachineLearning/CNTKComputationNetworkLib/ComputationNetworkEditing.cpp \ MachineLearning/CNTKComputationNetworkLib/ComputationNetworkBuilder.cpp \ MachineLearning/CNTKComputationNetworkLib/ComputationNetworkScripting.cpp \ - MachineLearning/CNTKComputationNetworkLib/NetworkBuilderFromConfig.cpp \ MachineLearning/CNTKSGDLib/Profiler.cpp \ MachineLearning/CNTKSGDLib/SGD.cpp \ MachineLearning/CNTKActionsLib/TrainActions.cpp \
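Editorial addendum (not part of the patch): for reference, below is a minimal, self-contained C++ sketch of the single-step recurrence that the removed LSTMNode::ForwardPropS computed. It is illustrative only; it is not CNTK code, and all names (LstmStep, Affine, Vec, Mat) are invented for this example. It assumes h and c are pre-sized to the output dimension.

#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>; // row-major: Mat[row][col]

static double Sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// Returns y + W * x; passing the bias vector as 'y' seeds the accumulation.
static Vec Affine(const Mat& W, const Vec& x, Vec y)
{
    for (std::size_t r = 0; r < W.size(); r++)
        for (std::size_t c = 0; c < x.size(); c++)
            y[r] += W[r][c] * x[c];
    return y;
}

// One LSTM step: consumes x_t, h_{t-1} (hPrev), and c_{t-1} (c, updated in
// place) and writes h_t into h. Wx*: [out x in], Wh*: [out x out], wc*:
// peephole vectors, b*: biases.
void LstmStep(const Vec& x, const Vec& hPrev, Vec& c, Vec& h,
              const Mat& Wxi, const Mat& Whi, const Vec& wci, const Vec& bi,
              const Mat& Wxf, const Mat& Whf, const Vec& wcf, const Vec& bf,
              const Mat& Wxc, const Mat& Whc, const Vec& bc,
              const Mat& Wxo, const Mat& Who, const Vec& wco, const Vec& bo)
{
    Vec zi = Affine(Whi, hPrev, Affine(Wxi, x, bi)); // input-gate pre-activation
    Vec zf = Affine(Whf, hPrev, Affine(Wxf, x, bf)); // forget-gate pre-activation
    Vec zc = Affine(Whc, hPrev, Affine(Wxc, x, bc)); // cell-input pre-activation
    for (std::size_t k = 0; k < c.size(); k++)
    {
        double i = Sigmoid(zi[k] + wci[k] * c[k]);   // peephole on c_{t-1}
        double f = Sigmoid(zf[k] + wcf[k] * c[k]);
        c[k] = f * c[k] + i * std::tanh(zc[k]);      // new cell state c_t
    }
    Vec zo = Affine(Who, hPrev, Affine(Wxo, x, bo)); // output-gate pre-activation
    for (std::size_t k = 0; k < c.size(); k++)
    {
        double o = Sigmoid(zo[k] + wco[k] * c[k]);   // peephole on c_t
        h[k] = o * std::tanh(c[k]);                  // h_t
    }
}

The update order mirrors ForwardPropS: the input and forget gates peek at c_{t-1}, while the output gate peeks at the freshly updated c_t.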