diff --git a/MachineLearning/cn/ComputationNode.h b/MachineLearning/cn/ComputationNode.h index eb23b5b33..9e294e863 100644 --- a/MachineLearning/cn/ComputationNode.h +++ b/MachineLearning/cn/ComputationNode.h @@ -4376,11 +4376,11 @@ protected: \ if (inputIndex == 0) //left derivative { - ComputeInputPartialLeft(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(inputIndex)->GradientValues()); + ComputeInputPartialLeft(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), GradientValues(), Inputs(inputIndex)->GradientValues()); } else //right derivative { - ComputeInputPartialRight(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(inputIndex)->GradientValues()); + ComputeInputPartialRight(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), GradientValues(), Inputs(inputIndex)->GradientValues()); } } @@ -4393,31 +4393,32 @@ protected: \ Matrix sliceInput1Value = Inputs(1)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); Matrix sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); Matrix sliceInputGrad = Inputs(inputIndex)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); + Matrix sliceOutputGrad = this->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep); if (inputIndex == 0) //left derivative { - ComputeInputPartialLeft(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceInputGrad); + ComputeInputPartialLeft(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceOutputGrad, sliceInputGrad); } else //right derivative { - ComputeInputPartialRight(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceInputGrad); + ComputeInputPartialRight(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceOutputGrad, sliceInputGrad); } } static void WINAPI ComputeInputPartialLeft(const Matrix& invNorm0, const Matrix& invNorm1, const Matrix& functionValues, Matrix& temp, Matrix& rightTerm, Matrix& leftTerm, // the temporary variables - const Matrix& in0, const Matrix& in1, + const Matrix& in0, const Matrix& in1, const Matrix& gradientValues, Matrix& inputGradientValues) { - ComputeInputPartialS(0, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, inputGradientValues); + ComputeInputPartialS(0, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, gradientValues, inputGradientValues); } static void WINAPI ComputeInputPartialRight(const Matrix& invNorm0, const Matrix& invNorm1, const Matrix& functionValues, Matrix& temp, Matrix& rightTerm, Matrix& leftTerm, // the temporary variables - const Matrix& in0, const Matrix& in1, + const Matrix& in0, const Matrix& in1, const Matrix& gradientValues, Matrix& inputGradientValues) { - ComputeInputPartialS(1, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, inputGradientValues); + ComputeInputPartialS(1, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, gradientValues, inputGradientValues); } // functionValues, invNorm0, invNorm1 - output from the EvaluateNode() method @@ -4426,7 +4427,7 @@ protected: \ // inputGradientValues(x) - gradients to update, where x matches inputIndex static void WINAPI ComputeInputPartialS(const size_t inputIndex, const Matrix& invNorm0, const Matrix& invNorm1, const Matrix& functionValues, Matrix& temp, Matrix& rightTerm, Matrix& leftTerm, // the temporary variables - const Matrix& in0, const Matrix& in1, + const Matrix& in0, const Matrix& in1, const Matrix& gradientValues, Matrix& inputGradientValues) { if (inputIndex == 0) //left derivative @@ -4446,7 +4447,14 @@ protected: \ leftTerm.SetValue(inputIndex?in0:in1); leftTerm.RowElementMultiplyWith(temp); - Matrix::AddScaledDifference(1, leftTerm, rightTerm, inputGradientValues); + leftTerm -= rightTerm; + leftTerm.RowElementMultiplyWith(gradientValues); + inputGradientValues += leftTerm; + + //alternatively the above three lines can be replaced by + //leftTerm.RowElementMultiplyWith(gradientValues); + //rightTerm.RowElementMultiplyWith(gradientValues); + //Matrix::AddScaledDifference(1, leftTerm, rightTerm, inputGradientValues); } // GetTaskDescriptor - Get a task descriptor for this node @@ -4465,7 +4473,8 @@ protected: \ descriptor->MatrixParam(m_leftTerm, "leftTerm", paramOptionsInput | paramOptionsTemporary); descriptor->FunctionParam(0, paramOptionsInput); descriptor->FunctionParam(1, paramOptionsInput); - descriptor->GradientParam(inputIndex,paramOptionsInput | paramOptionsOutput | paramOptionsInitialize); + descriptor->GradientParam(-1, paramOptionsInput); + descriptor->GradientParam(inputIndex, paramOptionsInput | paramOptionsOutput | paramOptionsInitialize); descriptor->SetFunction(inputIndex?(FARPROC)ComputeInputPartialRight:(FARPROC)ComputeInputPartialLeft); break; case taskEvaluate: