Fixed the gradient computation bug in CosDistanceNode: the gradient propagated to each input is now scaled by the node's own gradient values.

This commit is contained in:
Dong Yu 2014-12-02 20:02:49 -08:00
Родитель 286b109087
Коммит 7a9bb9a64d
1 изменённых файлов: 20 добавлений и 11 удалений

Просмотреть файл

@@ -4376,11 +4376,11 @@ protected: \
if (inputIndex == 0) //left derivative
{
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(inputIndex)->GradientValues());
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), GradientValues(), Inputs(inputIndex)->GradientValues());
}
else //right derivative
{
ComputeInputPartialRight(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(inputIndex)->GradientValues());
ComputeInputPartialRight(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), GradientValues(), Inputs(inputIndex)->GradientValues());
}
}
@@ -4393,31 +4393,32 @@ protected: \
Matrix<ElemType> sliceInput1Value = Inputs(1)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
Matrix<ElemType> sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
Matrix<ElemType> sliceInputGrad = Inputs(inputIndex)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
Matrix<ElemType> sliceOutputGrad = this->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
if (inputIndex == 0) //left derivative
{
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceInputGrad);
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceOutputGrad, sliceInputGrad);
}
else //right derivative
{
ComputeInputPartialRight(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceInputGrad);
ComputeInputPartialRight(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceOutputGrad, sliceInputGrad);
}
}
// ComputeInputPartialLeft - left-input (inputIndex 0) wrapper around ComputeInputPartialS.
// Forwards every argument unchanged to ComputeInputPartialS with the input index fixed to 0.
// NOTE(review): this listing is a diff view whose +/- markers are missing; where two near-identical
// lines appear back to back, the first looks like the pre-change line and the second (the one that
// mentions gradientValues) like the post-change line - confirm against the real file.
// invNorm0, invNorm1, functionValues - read-only matrices, passed straight through
// temp, rightTerm, leftTerm - mutable scratch matrices, passed straight through
// in0, in1 - read-only input matrices, passed straight through
// gradientValues - read-only matrix, present only in the second variant of the signature/call
// inputGradientValues - mutable matrix forwarded as the last argument
static void WINAPI ComputeInputPartialLeft(const Matrix<ElemType>& invNorm0, const Matrix<ElemType>& invNorm1, const Matrix<ElemType>& functionValues,
Matrix<ElemType>& temp, Matrix<ElemType>& rightTerm, Matrix<ElemType>& leftTerm, // the temporary variables
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1,
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1, const Matrix<ElemType>& gradientValues,
Matrix<ElemType>& inputGradientValues)
{
// first call: variant without gradientValues; second call: variant that also forwards gradientValues
ComputeInputPartialS(0, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, inputGradientValues);
ComputeInputPartialS(0, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, gradientValues, inputGradientValues);
}
// ComputeInputPartialRight - right-input (inputIndex 1) wrapper around ComputeInputPartialS.
// Forwards every argument unchanged to ComputeInputPartialS with the input index fixed to 1.
// NOTE(review): this listing is a diff view whose +/- markers are missing; where two near-identical
// lines appear back to back, the first looks like the pre-change line and the second (the one that
// mentions gradientValues) like the post-change line - confirm against the real file.
// invNorm0, invNorm1, functionValues - read-only matrices, passed straight through
// temp, rightTerm, leftTerm - mutable scratch matrices, passed straight through
// in0, in1 - read-only input matrices, passed straight through
// gradientValues - read-only matrix, present only in the second variant of the signature/call
// inputGradientValues - mutable matrix forwarded as the last argument
static void WINAPI ComputeInputPartialRight(const Matrix<ElemType>& invNorm0, const Matrix<ElemType>& invNorm1, const Matrix<ElemType>& functionValues,
Matrix<ElemType>& temp, Matrix<ElemType>& rightTerm, Matrix<ElemType>& leftTerm, // the temporary variables
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1,
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1, const Matrix<ElemType>& gradientValues,
Matrix<ElemType>& inputGradientValues)
{
// first call: variant without gradientValues; second call: variant that also forwards gradientValues
ComputeInputPartialS(1, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, inputGradientValues);
ComputeInputPartialS(1, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, gradientValues, inputGradientValues);
}
// functionValues, invNorm0, invNorm1 - output from the EvaluateNode() method
@@ -4426,7 +4427,7 @@ protected: \
// inputGradientValues(x) - gradients to update, where x matches inputIndex
static void WINAPI ComputeInputPartialS(const size_t inputIndex, const Matrix<ElemType>& invNorm0, const Matrix<ElemType>& invNorm1, const Matrix<ElemType>& functionValues,
Matrix<ElemType>& temp, Matrix<ElemType>& rightTerm, Matrix<ElemType>& leftTerm, // the temporary variables
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1,
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1, const Matrix<ElemType>& gradientValues,
Matrix<ElemType>& inputGradientValues)
{
if (inputIndex == 0) //left derivative
@@ -4446,7 +4447,14 @@ protected: \
leftTerm.SetValue(inputIndex?in0:in1);
leftTerm.RowElementMultiplyWith(temp);
Matrix<ElemType>::AddScaledDifference(1, leftTerm, rightTerm, inputGradientValues);
leftTerm -= rightTerm;
leftTerm.RowElementMultiplyWith(gradientValues);
inputGradientValues += leftTerm;
//alternatively the above three lines can be replaced by
//leftTerm.RowElementMultiplyWith(gradientValues);
//rightTerm.RowElementMultiplyWith(gradientValues);
//Matrix<ElemType>::AddScaledDifference(1, leftTerm, rightTerm, inputGradientValues);
}
// GetTaskDescriptor - Get a task descriptor for this node
@@ -4465,7 +4473,8 @@ protected: \
descriptor->MatrixParam(m_leftTerm, "leftTerm", paramOptionsInput | paramOptionsTemporary);
descriptor->FunctionParam(0, paramOptionsInput);
descriptor->FunctionParam(1, paramOptionsInput);
descriptor->GradientParam(inputIndex,paramOptionsInput | paramOptionsOutput | paramOptionsInitialize);
descriptor->GradientParam(-1, paramOptionsInput);
descriptor->GradientParam(inputIndex, paramOptionsInput | paramOptionsOutput | paramOptionsInitialize);
descriptor->SetFunction(inputIndex?(FARPROC)ComputeInputPartialRight:(FARPROC)ComputeInputPartialLeft);
break;
case taskEvaluate: