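Fixed the gradient computation bug in CosDistanceNode: the input-gradient update now multiplies by the node's incoming gradient (GradientValues) before accumulating into the input's gradient.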
This commit is contained in:
Родитель
286b109087
Коммит
7a9bb9a64d
|
@ -4376,11 +4376,11 @@ protected: \
|
|||
|
||||
if (inputIndex == 0) //left derivative
|
||||
{
|
||||
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(inputIndex)->GradientValues());
|
||||
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), GradientValues(), Inputs(inputIndex)->GradientValues());
|
||||
}
|
||||
else //right derivative
|
||||
{
|
||||
ComputeInputPartialRight(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(inputIndex)->GradientValues());
|
||||
ComputeInputPartialRight(m_invNorm0, m_invNorm1, FunctionValues(), m_temp, m_rightTerm, m_leftTerm, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), GradientValues(), Inputs(inputIndex)->GradientValues());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4393,31 +4393,32 @@ protected: \
|
|||
Matrix<ElemType> sliceInput1Value = Inputs(1)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceInputGrad = Inputs(inputIndex)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputGrad = this->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
|
||||
if (inputIndex == 0) //left derivative
|
||||
{
|
||||
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceInputGrad);
|
||||
ComputeInputPartialLeft(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceOutputGrad, sliceInputGrad);
|
||||
}
|
||||
else //right derivative
|
||||
{
|
||||
ComputeInputPartialRight(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceInputGrad);
|
||||
ComputeInputPartialRight(m_invNorm0, m_invNorm1, sliceOutputValue, m_temp, m_rightTerm, m_leftTerm, sliceInput0Value, sliceInput1Value, sliceOutputGrad, sliceInputGrad);
|
||||
}
|
||||
}
|
||||
|
||||
static void WINAPI ComputeInputPartialLeft(const Matrix<ElemType>& invNorm0, const Matrix<ElemType>& invNorm1, const Matrix<ElemType>& functionValues,
|
||||
Matrix<ElemType>& temp, Matrix<ElemType>& rightTerm, Matrix<ElemType>& leftTerm, // the temporary variables
|
||||
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1,
|
||||
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1, const Matrix<ElemType>& gradientValues,
|
||||
Matrix<ElemType>& inputGradientValues)
|
||||
{
|
||||
ComputeInputPartialS(0, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, inputGradientValues);
|
||||
ComputeInputPartialS(0, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, gradientValues, inputGradientValues);
|
||||
}
|
||||
|
||||
static void WINAPI ComputeInputPartialRight(const Matrix<ElemType>& invNorm0, const Matrix<ElemType>& invNorm1, const Matrix<ElemType>& functionValues,
|
||||
Matrix<ElemType>& temp, Matrix<ElemType>& rightTerm, Matrix<ElemType>& leftTerm, // the temporary variables
|
||||
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1,
|
||||
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1, const Matrix<ElemType>& gradientValues,
|
||||
Matrix<ElemType>& inputGradientValues)
|
||||
{
|
||||
ComputeInputPartialS(1, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, inputGradientValues);
|
||||
ComputeInputPartialS(1, invNorm0, invNorm1, functionValues, temp, rightTerm, leftTerm, in0, in1, gradientValues, inputGradientValues);
|
||||
}
|
||||
|
||||
// functionValues, invNorm0, invNorm1 - output from the EvaluateNode() method
|
||||
|
@ -4426,7 +4427,7 @@ protected: \
|
|||
// inputGradientValues(x) - gradients to update, where x matches inputIndex
|
||||
static void WINAPI ComputeInputPartialS(const size_t inputIndex, const Matrix<ElemType>& invNorm0, const Matrix<ElemType>& invNorm1, const Matrix<ElemType>& functionValues,
|
||||
Matrix<ElemType>& temp, Matrix<ElemType>& rightTerm, Matrix<ElemType>& leftTerm, // the temporary variables
|
||||
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1,
|
||||
const Matrix<ElemType>& in0, const Matrix<ElemType>& in1, const Matrix<ElemType>& gradientValues,
|
||||
Matrix<ElemType>& inputGradientValues)
|
||||
{
|
||||
if (inputIndex == 0) //left derivative
|
||||
|
@ -4446,7 +4447,14 @@ protected: \
|
|||
leftTerm.SetValue(inputIndex?in0:in1);
|
||||
leftTerm.RowElementMultiplyWith(temp);
|
||||
|
||||
Matrix<ElemType>::AddScaledDifference(1, leftTerm, rightTerm, inputGradientValues);
|
||||
leftTerm -= rightTerm;
|
||||
leftTerm.RowElementMultiplyWith(gradientValues);
|
||||
inputGradientValues += leftTerm;
|
||||
|
||||
// Alternatively, the three lines above can be replaced by the following:
|
||||
//leftTerm.RowElementMultiplyWith(gradientValues);
|
||||
//rightTerm.RowElementMultiplyWith(gradientValues);
|
||||
//Matrix<ElemType>::AddScaledDifference(1, leftTerm, rightTerm, inputGradientValues);
|
||||
}
|
||||
|
||||
// GetTaskDescriptor - Get a task descriptor for this node
|
||||
|
@ -4465,7 +4473,8 @@ protected: \
|
|||
descriptor->MatrixParam(m_leftTerm, "leftTerm", paramOptionsInput | paramOptionsTemporary);
|
||||
descriptor->FunctionParam(0, paramOptionsInput);
|
||||
descriptor->FunctionParam(1, paramOptionsInput);
|
||||
descriptor->GradientParam(inputIndex,paramOptionsInput | paramOptionsOutput | paramOptionsInitialize);
|
||||
descriptor->GradientParam(-1, paramOptionsInput);
|
||||
descriptor->GradientParam(inputIndex, paramOptionsInput | paramOptionsOutput | paramOptionsInitialize);
|
||||
descriptor->SetFunction(inputIndex?(FARPROC)ComputeInputPartialRight:(FARPROC)ComputeInputPartialLeft);
|
||||
break;
|
||||
case taskEvaluate:
|
||||
|
|
Загрузка…
Ссылка в новой задаче