merge linear and loss
Parent: d0675c2731
Commit: 9bd990319a
@@ -596,7 +596,7 @@ EditDistanceError(leftInput, rightInput, subPen=1.0, delPen=1.0, insPen=1.0, squ
 RNNTError(leftInput, rightInput, mergedinput, tokensToIgnore={}, tag='') = new ComputationNode [ operation = 'RNNTError' ; inputs = _AsNodes (leftInput : rightInput: mergedinput) /*plus the function args*/ ]
 LatticeSequenceWithSoftmax(labels, evaluation, scaledLogLikelihood, lattice, symListPath, phonePath, stateListPath, transProbPath, latticeConfigPath = "LatticeNode.config", hSmoothingWeight = 0.95, frameDropThresh = 1e-10, doReferenceAlign = false, seqGammarUsesMBR = false, seqGammarAMF = 14.0, seqGammarLMF = 14.0, seqGammarBMMIFactor = 0.0, seqGammarWordPen = 0.0, tag='') = new ComputationNode [ operation = 'LatticeSequenceWithSoftmax' ; inputs = _AsNodes (labels : evaluation : scaledLogLikelihood : lattice) /*plus the function args*/ ]
 ForwardBackward(graph, features, blankTokenId, delayConstraint=-1, tag='') = new ComputationNode [ operation = 'ForwardBackward' ; inputs = _AsNodes (graph : features) /*plus the function args*/ ]
-RNNT(graph, transcription, prediction, mergedinput, blankTokenId, delayConstraint=-1, tag='') = new ComputationNode [ operation = 'RNNT' ; inputs = _AsNodes (graph : transcription: prediction: mergedinput) /*plus the function args*/ ]
+RNNT(graph, transcription, prediction, mergedinput, W, b, blankTokenId, delayConstraint=-1, tag='') = new ComputationNode [ operation = 'RNNT' ; inputs = _AsNodes (graph : transcription: prediction: mergedinput: W: b) /*plus the function args*/ ]
 LabelsToGraph(labels, tag='') = new ComputationNode [ operation = 'LabelsToGraph' ; inputs = _AsNodes (labels) /*plus the function args*/ ]
 StopGradient(input, tag='') = new ComputationNode [ operation = 'StopGradient' ; inputs = _AsNodes (input) /*plus the function args*/ ]
 Slice(beginIndex, endIndex, input, axis=1, tag='') =

@@ -2243,13 +2243,13 @@ namespace CNTK

         return BinaryOp(PrimitiveOpType::ForwardBackward, graph, features, std::move(additionalProperties), name);
     }
-    FunctionPtr RNNT(const Variable& graph, const Variable& transcription, const Variable& prediction, const Variable& mergedinput, size_t blankTokenId, int delayConstraint, const std::wstring& name)
+    FunctionPtr RNNT(const Variable& graph, const Variable& transcription, const Variable& prediction, const Variable& mergedinput, const Variable& W, const Variable& b, size_t blankTokenId, int delayConstraint, const std::wstring& name)
     {
         auto additionalProperties = Dictionary();
         additionalProperties[PrimitiveFunctionAttribute::AttributeNameBlankTokenId] = blankTokenId;
         additionalProperties[PrimitiveFunctionAttribute::AttributeNameDelayConstraint] = delayConstraint;

-        std::vector<Variable> operands = {graph, transcription, prediction, mergedinput};
+        std::vector<Variable> operands = {graph, transcription, prediction, mergedinput, W, b};
         return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::RNNT, operands, std::move(additionalProperties), name), name);

         //return BinaryOp(PrimitiveOpType::RNNT, graph, transcription, prediction, std::move(additionalProperties), name);

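The V2 API change mirrors the node change: the RNNT factory now also takes the output-projection parameters W and b and packs them into the operands list. A minimal call-site sketch follows; it is not part of this commit, the helper name and argument values are hypothetical, and it assumes the new overload is exposed to callers the same way ForwardBackward is.

    // Hypothetical call site for the new overload (sketch only, not from this commit).
    CNTK::FunctionPtr MakeRnntLoss(const CNTK::Variable& graph,
                                   const CNTK::Variable& transcription,
                                   const CNTK::Variable& prediction,
                                   const CNTK::Variable& merged,   // pre-projection merged activations
                                   const CNTK::Variable& W,        // output-projection weight
                                   const CNTK::Variable& b)        // output-projection bias
    {
        // blankTokenId = 0 and delayConstraint = -1 are placeholder choices.
        return CNTK::RNNT(graph, transcription, prediction, merged, W, b,
                          /*blankTokenId=*/0, /*delayConstraint=*/-1, L"rnntLoss");
    }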
@@ -545,9 +545,10 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Forwa
     return net.AddNodeToNetAndAttachInputs(New<ForwardBackwardNode<ElemType>>(net.GetDeviceId(), nodeName, blankTokenId, delayConstraint), { graph, features });
 }
 template <class ElemType>
-shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput, int blankTokenId, int delayConstraint, const std::wstring nodeName)
+shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput,
+                                                                                const ComputationNodePtr W, const ComputationNodePtr b, int blankTokenId, int delayConstraint, const std::wstring nodeName)
 {
-    return net.AddNodeToNetAndAttachInputs(New<RNNTNode<ElemType>>(net.GetDeviceId(), nodeName, blankTokenId, delayConstraint), { graph, transcription, prediction, mergedinput });
+    return net.AddNodeToNetAndAttachInputs(New<RNNTNode<ElemType>>(net.GetDeviceId(), nodeName, blankTokenId, delayConstraint), { graph, transcription, prediction, mergedinput, W, b });
 }

 template <class ElemType>

@@ -188,7 +188,7 @@ public:
     ComputationNodePtr RandomSampleInclusionFrequency(const ComputationNodePtr a, const std::wstring nodeName = L"");
     ComputationNodePtr RectifiedLinear(const ComputationNodePtr a, const std::wstring nodeName = L"");
     ComputationNodePtr Reshape(const ComputationNodePtr a, const TensorShape& imageLayout, const std::wstring nodeName = L"");
-    ComputationNodePtr RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput, int blankTokenId, int delayConstraint, const std::wstring nodeName = L"");
+    ComputationNodePtr RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput, const ComputationNodePtr W, const ComputationNodePtr b, int blankTokenId, int delayConstraint, const std::wstring nodeName = L"");
     ComputationNodePtr RowRepeat(const ComputationNodePtr a, const size_t num_repeat, const std::wstring nodeName = L"");
     ComputationNodePtr RowSlice(const ComputationNodePtr a, const size_t start_index, const size_t num_rows, const std::wstring nodeName = L"");
     ComputationNodePtr RowStack(const std::vector<ComputationNodePtr> pinputs, const std::wstring nodeName = L"");

@@ -1576,7 +1576,7 @@ template class CustomProxyOpNode<float>;
 // -----------------------------------------------------------------------

 template <class ElemType>
-class RNNTNode : public ComputationNodeNonLooping<ElemType>, public NumInputs<4>
+class RNNTNode : public ComputationNodeNonLooping<ElemType>, public NumInputs<6>
 {
     typedef ComputationNodeNonLooping<ElemType> Base;
     UsingComputationNodeMembersBoilerplate;

@@ -1611,7 +1611,16 @@ public:
         }
         else if (inputIndex == 3)
         {
-            BackpropToMerge(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative);
+            BackpropToX(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative, InputRef(4).Value());
+
         }
+        else if (inputIndex == 5)
+        {
+            BackpropToB(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative);
+        }
+        else if (inputIndex == 4)
+        {
+            BackpropToW(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative, InputRef(3).Value());
+        }
         else
             RuntimeError("RNNTNode criterion expects only two inputs: labels and network output.");

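Taken together with the operands vector above and the NumInputs<6> base, the dispatch implies the following input layout for the merged node (a reader's summary, not code from the commit):

    enum RnntNodeInputs
    {
        InputGraph         = 0, // label graph
        InputTranscription = 1, // transcription (encoder) network output
        InputPrediction    = 2, // prediction network output
        InputMerged        = 3, // merged activations x; its gradient (BackpropToX) needs W = Input(4)
        InputW             = 4, // projection weight; its gradient (BackpropToW) needs x = Input(3)
        InputB             = 5  // projection bias; its gradient is a plain column sum (BackpropToB)
    };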
@@ -1634,7 +1643,7 @@ public:
 #endif
     }

-    void BackpropToMerge(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
+    void BackpropToB(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
                          Matrix<ElemType>& RNNTDerivative)
     {
 #if DUMPOUTPUT

@@ -1646,7 +1655,54 @@ public:
         //m_tmpMatrix->AssignUserOp2(RNNTDerivative, InputRef(2).Value().GetNumCols(), InputRef(1).Value().GetNumCols(), InputRef(0).GetMBLayout()->GetNumParallelSequences(), 0);
         //m_tmpMatrix->TransferFromDeviceToDevice(CPUDEVICE, InputRef(0).Value().GetDeviceId());
         // inputGradientValues += gradientValues*(softmaxOfRight - CTCposterior)
-        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, inputGradientValues);
+        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, *m_tmpMatrix);
+        Matrix<ElemType>::VectorSum(*m_tmpMatrix, inputGradientValues, false);
+        //inputGradientValues.Print("gradient");
+        /*printf("back to F\n");
+        if (gradientValues.GetDeviceId() != CPUDEVICE)
+            printf("gradientValues after F is in GPU\n");*/
+#if DUMPOUTPUT
+        inputGradientValues.Print("RNNTNode Partial-Right");
+#endif
+    }
+
+    void BackpropToW(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
+                     Matrix<ElemType>& RNNTDerivative, Matrix<ElemType>& inputValue)
+    {
+#if DUMPOUTPUT
+        inputFunctionValues.Print("RNNTNode Partial-inputFunctionValues");
+        gradientValues.Print("RNNTNode Partial-gradientValues");
+        inputGradientValues.Print("RNNTNode Partial-Right-in");
+#endif
+        //sum u for RNNT Derivative
+        //m_tmpMatrix->AssignUserOp2(RNNTDerivative, InputRef(2).Value().GetNumCols(), InputRef(1).Value().GetNumCols(), InputRef(0).GetMBLayout()->GetNumParallelSequences(), 0);
+        //m_tmpMatrix->TransferFromDeviceToDevice(CPUDEVICE, InputRef(0).Value().GetDeviceId());
+        // inputGradientValues += gradientValues*(softmaxOfRight - CTCposterior)
+        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, *m_tmpMatrix);
+        inputGradientValues.AssignProductOf(inputValue, false, *m_tmpMatrix, true);
+        //inputGradientValues.Print("gradient");
+        /*printf("back to F\n");
+        if (gradientValues.GetDeviceId() != CPUDEVICE)
+            printf("gradientValues after F is in GPU\n");*/
+#if DUMPOUTPUT
+        inputGradientValues.Print("RNNTNode Partial-Right");
+#endif
+    }
+
+    void BackpropToX(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
+                     Matrix<ElemType>& RNNTDerivative, Matrix<ElemType>& inputValue)
+    {
+#if DUMPOUTPUT
+        inputFunctionValues.Print("RNNTNode Partial-inputFunctionValues");
+        gradientValues.Print("RNNTNode Partial-gradientValues");
+        inputGradientValues.Print("RNNTNode Partial-Right-in");
+#endif
+        //sum u for RNNT Derivative
+        //m_tmpMatrix->AssignUserOp2(RNNTDerivative, InputRef(2).Value().GetNumCols(), InputRef(1).Value().GetNumCols(), InputRef(0).GetMBLayout()->GetNumParallelSequences(), 0);
+        //m_tmpMatrix->TransferFromDeviceToDevice(CPUDEVICE, InputRef(0).Value().GetDeviceId());
+        // inputGradientValues += gradientValues*(softmaxOfRight - CTCposterior)
+        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, *m_tmpMatrix);
+        inputGradientValues.AssignProductOf(inputValue, false, *m_tmpMatrix, false);
         //inputGradientValues.Print("gradient");
         /*printf("back to F\n");
         if (gradientValues.GetDeviceId() != CPUDEVICE)

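The three helpers are the usual gradients of the affine projection the node now owns: with x the merged input (d x N, columns are (t,u) positions, Input(3)), W the d x K projection weight (Input(4)), b the K-dimensional bias (Input(5)), and the scaled RNNTDerivative held in m_tmpMatrix treated as G = dL/dz for z = W^T x + b, they compute dW = x * G^T, db = row sums of G over the N columns, and dx = W * G. A standalone sketch of the same arithmetic on toy data (plain loops, not CNTK code):

    #include <vector>
    #include <cstdio>

    int main()
    {
        const int d = 3, K = 4, N = 2; // toy sizes
        // Column-major toy data: x is d x N, W is d x K, G = dL/dz is K x N.
        std::vector<double> x(d * N, 0.5), W(d * K, 0.1), G(K * N, 0.2);
        std::vector<double> dW(d * K, 0.0), db(K, 0.0), dx(d * N, 0.0);

        // dW[i][k] = sum_n x[i][n] * G[k][n]   (cf. BackpropToW: AssignProductOf(x, false, G, true))
        for (int i = 0; i < d; i++)
            for (int k = 0; k < K; k++)
                for (int n = 0; n < N; n++)
                    dW[i + k * d] += x[i + n * d] * G[k + n * K];

        // db[k] = sum_n G[k][n]                (cf. BackpropToB: VectorSum over columns)
        for (int k = 0; k < K; k++)
            for (int n = 0; n < N; n++)
                db[k] += G[k + n * K];

        // dx[i][n] = sum_k W[i][k] * G[k][n]   (cf. BackpropToX: AssignProductOf(W, false, G, false))
        for (int i = 0; i < d; i++)
            for (int n = 0; n < N; n++)
                for (int k = 0; k < K; k++)
                    dx[i + n * d] += W[i + k * d] * G[k + n * K];

        printf("dW[0]=%g db[0]=%g dx[0]=%g\n", dW[0], db[0], dx[0]);
        return 0;
    }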
@@ -1703,11 +1759,15 @@ public:
         //m_RNNTDerivative->SwitchToMatrixType(m_outputLogDistribution->GetMatrixType(), m_outputLogDistribution->GetFormat(), false);
         //m_RNNTDerivative->Resize(m_outputLogDistribution->GetNumRows(), m_outputLogDistribution->GetNumCols());

+        m_outputDensity->AssignProductOf(InputRef(4).Value(), true, InputRef(3).Value(), false);
+        m_outputDensity->AssignSumOf(*m_outputDensity, InputRef(5).Value());
+        m_outputDensity->InplaceLogSoftmax(true);
         FrameRange fr(InputRef(0).GetMBLayout());
         InputRef(0).ValueFor(fr).VectorMax(*m_maxIndexes, *m_maxValues, true);

+
         // compute CTC score
-        m_GammaCal.twodimForwardBackward(Value(), InputRef(1).Value(), InputRef(2).Value(), InputRef(3).Value(), *m_maxIndexes, *m_derivative, InputRef(1).GetMBLayout(), InputRef(2).GetMBLayout(), m_blankTokenId);
+        m_GammaCal.twodimForwardBackward(Value(), InputRef(1).Value(), InputRef(2).Value(), *m_outputDensity, *m_maxIndexes, *m_derivative, InputRef(1).GetMBLayout(), InputRef(2).GetMBLayout(), m_blankTokenId);

 #if NANCHECK
         functionValues.HasNan("RNNTNode");

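ForwardPropNonLooping now performs the projection itself: the added lines build the output density as a log-softmax of W^T * merged + b (inputs 4, 3 and 5), and that density replaces InputRef(3) in the twodimForwardBackward call. Presumably this is also why the Validate hunk below drops the row-count comparison between Input(0) and Input(3): input 3 now carries hidden-dimension activations rather than vocabulary-sized scores. A standalone sketch of the per-column computation (toy data, not CNTK code):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Numerically stable log-softmax over a vector of logits.
    std::vector<double> LogSoftmax(const std::vector<double>& z)
    {
        double zmax = z[0];
        for (double v : z) zmax = std::max(zmax, v);
        double sum = 0.0;
        for (double v : z) sum += std::exp(v - zmax);
        const double logZ = zmax + std::log(sum);
        std::vector<double> out(z.size());
        for (size_t k = 0; k < z.size(); k++) out[k] = z[k] - logZ;
        return out;
    }

    int main()
    {
        const int d = 3, K = 4;                       // toy sizes
        std::vector<double> h(d, 0.5);                // one column of the merged input
        std::vector<double> W(d * K, 0.1), b(K, 0.0); // projection parameters (column-major W)

        std::vector<double> z(K, 0.0);
        for (int k = 0; k < K; k++)                   // z = W^T h + b
        {
            for (int i = 0; i < d; i++) z[k] += W[i + k * d] * h[i];
            z[k] += b[k];
        }
        std::vector<double> logp = LogSoftmax(z);     // per-column log-probabilities (the "density")
        printf("logp[0]=%g\n", logp[0]);
        return 0;
    }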
@@ -1724,8 +1784,7 @@ public:

         if (isFinalValidationPass)
         {
-            if (!(Input(0)->GetSampleMatrixNumRows() == Input(3)->GetSampleMatrixNumRows() && // match vector dimension
-                  Input(0)->HasMBLayout() &&
+            if (!(Input(0)->HasMBLayout() &&
                   Input(0)->GetMBLayout() == Input(2)->GetMBLayout()))
             {
                 LogicError("The Matrix dimension in the RNNTNode operation does not match.");

@@ -1751,6 +1810,7 @@ public:
             node->m_derivativeForG->SetValue(*m_derivative);
             node->m_maxIndexes->SetValue(*m_maxIndexes);
             node->m_maxValues->SetValue(*m_maxValues);
+            node->m_outputDensity->SetValue(*m_outputDensity);
             node->m_delayConstraint = m_delayConstraint;
             //node->m_RNNTDerivative->SetValue(*m_RNNTDerivative);
             node->m_tmpMatrix->SetValue(*m_tmpMatrix);

@@ -1762,6 +1822,7 @@ public:
     {
         Base::RequestMatricesBeforeForwardProp(matrixPool);
         RequestMatrixFromPool(m_derivativeForG, matrixPool);
+        RequestMatrixFromPool(m_outputDensity, matrixPool);
         RequestMatrixFromPool(m_derivative, matrixPool);
         //RequestMatrixFromPool(m_outputDistribution, matrixPool);
         RequestMatrixFromPool(m_maxIndexes, matrixPool);

@@ -1774,6 +1835,7 @@ public:
     {
         Base::ReleaseMatricesAfterBackprop(matrixPool);
         ReleaseMatrixToPool(m_derivativeForG, matrixPool);
+        ReleaseMatrixToPool(m_outputDensity, matrixPool);
         ReleaseMatrixToPool(m_derivative, matrixPool);
         //ReleaseMatrixToPool(m_outputDistribution, matrixPool);
         ReleaseMatrixToPool(m_maxIndexes, matrixPool);

@@ -540,20 +540,21 @@ public:
         //matrixOutputDistribution.Print("h");
         //log softmax of f+g
         //mergedinput.InplaceLogSoftmax(true);
-        Microsoft::MSR::CNTK::Matrix<ElemType> logsoftmax(m_deviceid_gpu);
+
+        /*Microsoft::MSR::CNTK::Matrix<ElemType> logsoftmax(m_deviceid_gpu);
         logsoftmax.SetValue(mergedinput);

-        logsoftmax.InplaceLogSoftmax(true);
+        logsoftmax.InplaceLogSoftmax(true);*/
         //matrixOutputDistribution.Print("prob");
         // forward backward to compute alpha, beta derivaitves
         Microsoft::MSR::CNTK::Matrix<ElemType> alpha(m_deviceid_gpu);
         Microsoft::MSR::CNTK::Matrix<ElemType> beta(m_deviceid_gpu);
         m_derivative.TransferToDeviceIfNotThere(m_deviceid_gpu);
-        m_derivative.AssignRNNTScore(logsoftmax, alpha, beta, matrixPhoneSeqs, matrixPhoneSeqs, uttFrameToChanInd, uttFrameBeginIdx, uttBeginForOutputditribution, uttPhoneToChanInd, uttPhoneBeginIdx,
+        m_derivative.AssignRNNTScore(mergedinput, alpha, beta, matrixPhoneSeqs, matrixPhoneSeqs, uttFrameToChanInd, uttFrameBeginIdx, uttBeginForOutputditribution, uttPhoneToChanInd, uttPhoneBeginIdx,
                                      uttFrameNum, uttPhoneNum, numParallelSequences, numPhoneParallelSequences, maxPhoneNum, maxFrameNum, totalScore, blankTokenId, -1, true);

-        logsoftmax.InplaceExp();
-        m_derivative.AssignElementProductOf(m_derivative, logsoftmax);
+        mergedinput.InplaceExp();
+        m_derivative.AssignElementProductOf(m_derivative, mergedinput);
         ElemType finalscore = 0;
         //m_derivative.Print("RNNT");
         finalscore = totalScore.Get00Element();

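With the merge, AssignRNNTScore consumes the already log-softmaxed merged input directly instead of a local copy, fills alpha and beta, and returns the total path score; the derivative is then scaled elementwise by exp of the log-probabilities. The alpha/beta quantities of an RNN-T criterion follow the standard two-dimensional forward-backward recursion (Graves, 2012); below is a standalone sketch of the forward half for a single utterance in the log domain. This is the textbook recursion, not the CNTK kernel itself.

    #include <cmath>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // log(exp(a) + exp(b)) without overflow.
    double LogAdd(double a, double b)
    {
        if (a < b) std::swap(a, b);
        return a + std::log1p(std::exp(b - a));
    }

    int main()
    {
        const int T = 4, U = 2, K = 3, blank = 0;     // toy sizes
        std::vector<int> y = {1, 2};                  // label sequence of length U
        // logp[t][u][k]: log P(k | t, u), e.g. the log-softmax density produced by the node.
        std::vector<double> logp(T * (U + 1) * K, std::log(1.0 / K)); // uniform toy density
        auto lp = [&](int t, int u, int k) { return logp[(t * (U + 1) + u) * K + k]; };

        const double NEG_INF = -1e30;
        std::vector<double> alpha(T * (U + 1), NEG_INF);
        auto a = [&](int t, int u) -> double& { return alpha[t * (U + 1) + u]; };

        a(0, 0) = 0.0;
        for (int t = 0; t < T; t++)
            for (int u = 0; u <= U; u++)
            {
                if (t > 0)          // emit blank: advance in time
                    a(t, u) = LogAdd(a(t, u), a(t - 1, u) + lp(t - 1, u, blank));
                if (u > 0)          // emit label y[u-1]: advance in the label dimension
                    a(t, u) = LogAdd(a(t, u), a(t, u - 1) + lp(t, u - 1, y[u - 1]));
            }

        double logLik = a(T - 1, U) + lp(T - 1, U, blank); // final blank closes the path
        printf("RNN-T loss = %g\n", -logLik);
        return 0;
    }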