merge linear and loss
Parent: d0675c2731
Commit: 9bd990319a
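This commit merges the joint network's output linear layer into the RNNT loss node ("merge linear and loss"): the RNNT criterion now takes the projection weight W and bias b as two extra inputs (six in total), computes the output density LogSoftmax(W^T * mergedinput + b) inside the node's forward pass, and backpropagates through the projection itself via the new BackpropToW, BackpropToB, and BackpropToX methods. The hunks below touch, in order: the BrainScript library definition, the V2 C++ API factory, the network builder (source and header), the RNNTNode implementation, and the forward-backward kernel wrapper.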
@@ -596,7 +596,7 @@ EditDistanceError(leftInput, rightInput, subPen=1.0, delPen=1.0, insPen=1.0, squ
 RNNTError(leftInput, rightInput, mergedinput, tokensToIgnore=[||], tag='') = new ComputationNode [ operation = 'RNNTError' ; inputs = _AsNodes (leftInput : rightInput: mergedinput) /*plus the function args*/ ]
 LatticeSequenceWithSoftmax(labels, evaluation, scaledLogLikelihood, lattice, symListPath, phonePath, stateListPath, transProbPath, latticeConfigPath = "LatticeNode.config", hSmoothingWeight = 0.95, frameDropThresh = 1e-10, doReferenceAlign = false, seqGammarUsesMBR = false, seqGammarAMF = 14.0, seqGammarLMF = 14.0, seqGammarBMMIFactor = 0.0, seqGammarWordPen = 0.0, tag='') = new ComputationNode [ operation = 'LatticeSequenceWithSoftmax' ; inputs = _AsNodes (labels : evaluation : scaledLogLikelihood : lattice) /*plus the function args*/ ]
 ForwardBackward(graph, features, blankTokenId, delayConstraint=-1, tag='') = new ComputationNode [ operation = 'ForwardBackward' ; inputs = _AsNodes (graph : features) /*plus the function args*/ ]
-RNNT(graph, transcription, prediction,mergedinput, blankTokenId,delayConstraint=-1, tag='') = new ComputationNode [ operation = 'RNNT' ; inputs = _AsNodes (graph : transcription: prediction: mergedinput) /*plus the function args*/ ]
+RNNT(graph, transcription, prediction,mergedinput, W, b, blankTokenId,delayConstraint=-1, tag='') = new ComputationNode [ operation = 'RNNT' ; inputs = _AsNodes (graph : transcription: prediction: mergedinput: W: b) /*plus the function args*/ ]
 LabelsToGraph(labels, tag='') = new ComputationNode [ operation = 'LabelsToGraph' ; inputs = _AsNodes (labels) /*plus the function args*/ ]
 StopGradient(input, tag='') = new ComputationNode [ operation = 'StopGradient' ; inputs = _AsNodes (input) /*plus the function args*/ ]
 Slice(beginIndex, endIndex, input, axis=1, tag='') =
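In the BrainScript library definition, RNNT gains the two parameter inputs W and b, forwarded to the node after mergedinput; scripts that build this criterion presumably now supply the joint projection's parameters to the loss instead of applying a Times/Plus layer outside it.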
@@ -2243,13 +2243,13 @@ namespace CNTK
 
        return BinaryOp(PrimitiveOpType::ForwardBackward, graph, features, std::move(additionalProperties), name);
    }
-   FunctionPtr RNNT(const Variable& graph, const Variable& transcription, const Variable& prediction, const Variable& mergedinput, size_t blankTokenId, int delayConstraint, const std::wstring& name)
+   FunctionPtr RNNT(const Variable& graph, const Variable& transcription, const Variable& prediction, const Variable& mergedinput, const Variable& W, const Variable& b, size_t blankTokenId, int delayConstraint, const std::wstring& name)
    {
        auto additionalProperties = Dictionary();
        additionalProperties[PrimitiveFunctionAttribute::AttributeNameBlankTokenId] = blankTokenId;
        additionalProperties[PrimitiveFunctionAttribute::AttributeNameDelayConstraint] = delayConstraint;
 
-       std::vector<Variable> operands = {graph, transcription, prediction, mergedinput};
+       std::vector<Variable> operands = {graph, transcription, prediction, mergedinput, W, b};
        return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::RNNT, operands, std::move(additionalProperties), name), name);
 
        //return BinaryOp(PrimitiveOpType::RNNT, graph,transcription, prediction, std::move(additionalProperties), name);
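For V2 API callers the change is mechanical: the factory function takes the projection parameters as two extra operands. A minimal caller-side sketch, assuming the fork exposes this RNNT overload through CNTKLibrary.h; the helper name, variable names, and shapes are illustrative, not part of this commit:

#include "CNTKLibrary.h"
using namespace CNTK;

// Hypothetical helper: wires the joint projection's parameters into the merged
// RNNT criterion. W is applied transposed inside the node, so it is laid out
// jointDim x vocabSize; b is a vocabSize-dimensional bias.
FunctionPtr BuildRnntCriterion(const Variable& graph, const Variable& transcription,
                               const Variable& prediction, const Variable& mergedinput,
                               const Parameter& W, const Parameter& b, size_t blankTokenId)
{
    // The node now computes LogSoftmax(W^T * mergedinput + b) internally,
    // so no Times/Plus/LogSoftmax is built outside the loss.
    return RNNT(graph, transcription, prediction, mergedinput, W, b,
                blankTokenId, /*delayConstraint=*/ -1, L"rnntLoss");
}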
@@ -545,9 +545,10 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Forwa
     return net.AddNodeToNetAndAttachInputs(New<ForwardBackwardNode<ElemType>>(net.GetDeviceId(), nodeName, blankTokenId, delayConstraint), { graph, features });
 }
 template <class ElemType>
-shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput, int blankTokenId, int delayConstraint, const std::wstring nodeName)
+shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput,
+                                                                                const ComputationNodePtr W, const ComputationNodePtr b, int blankTokenId, int delayConstraint, const std::wstring nodeName)
 {
-    return net.AddNodeToNetAndAttachInputs(New<RNNTNode<ElemType>>(net.GetDeviceId(), nodeName, blankTokenId, delayConstraint), { graph, transcription,prediction, mergedinput});
+    return net.AddNodeToNetAndAttachInputs(New<RNNTNode<ElemType>>(net.GetDeviceId(), nodeName, blankTokenId, delayConstraint), { graph, transcription,prediction, mergedinput, W, b});
 }
 
 template <class ElemType>
@@ -188,7 +188,7 @@ public:
     ComputationNodePtr RandomSampleInclusionFrequency(const ComputationNodePtr a, const std::wstring nodeName = L"");
     ComputationNodePtr RectifiedLinear(const ComputationNodePtr a, const std::wstring nodeName = L"");
     ComputationNodePtr Reshape(const ComputationNodePtr a, const TensorShape& imageLayout, const std::wstring nodeName = L"");
-    ComputationNodePtr RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput, int blankTokenId, int delayConstraint, const std::wstring nodeName = L"");
+    ComputationNodePtr RNNT(const ComputationNodePtr graph, const ComputationNodePtr transcription, const ComputationNodePtr prediction, const ComputationNodePtr mergedinput, const ComputationNodePtr W, const ComputationNodePtr b, int blankTokenId, int delayConstraint, const std::wstring nodeName = L"");
     ComputationNodePtr RowRepeat(const ComputationNodePtr a, const size_t num_repeat, const std::wstring nodeName = L"");
     ComputationNodePtr RowSlice(const ComputationNodePtr a, const size_t start_index, const size_t num_rows, const std::wstring nodeName = L"");
     ComputationNodePtr RowStack(const std::vector<ComputationNodePtr> pinputs, const std::wstring nodeName = L"");
@@ -1576,7 +1576,7 @@ template class CustomProxyOpNode<float>;
 // -----------------------------------------------------------------------
 
 template <class ElemType>
-class RNNTNode : public ComputationNodeNonLooping<ElemType>, public NumInputs<4>
+class RNNTNode : public ComputationNodeNonLooping<ElemType>, public NumInputs<6>
 {
     typedef ComputationNodeNonLooping<ElemType> Base;
     UsingComputationNodeMembersBoilerplate;
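With NumInputs<6>, the node's inputs are, in order: 0 = label graph, 1 = transcription (encoder) activations, 2 = prediction-network activations, 3 = mergedinput (the pre-projection joint activations), 4 = the projection weight W, and 5 = the bias b, matching the operand order established in the builder above.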
@@ -1611,7 +1611,16 @@ public:
         }
         else if (inputIndex == 3)
         {
-            BackpropToMerge(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative);
+            BackpropToX(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative, InputRef(4).Value());
+
+        }
+        else if (inputIndex == 5)
+        {
+            BackpropToB(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative);
+        }
+        else if (inputIndex == 4)
+        {
+            BackpropToW(InputRef(inputIndex).Gradient(), Gradient(), *m_derivative, InputRef(3).Value());
         }
         else
             RuntimeError("RNNTNode criterion expects only two inputs: labels and network output.");
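Backprop now dispatches per trainable operand: input 3 (mergedinput) goes through BackpropToX, which needs W (input 4); input 4 (W) goes through BackpropToW, which needs mergedinput (input 3); and input 5 (b) goes through BackpropToB. Note that the RuntimeError text in the fall-through branch ("expects only two inputs") is stale now that the node takes six.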
@@ -1634,7 +1643,7 @@ public:
 #endif
     }
 
-    void BackpropToMerge(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
+    void BackpropToB(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
                          Matrix<ElemType>& RNNTDerivative)
     {
 #if DUMPOUTPUT
@@ -1646,7 +1655,54 @@ public:
         //m_tmpMatrix->AssignUserOp2(RNNTDerivative, InputRef(2).Value().GetNumCols(), InputRef(1).Value().GetNumCols(), InputRef(0).GetMBLayout()->GetNumParallelSequences(), 0);
         //m_tmpMatrix->TransferFromDeviceToDevice(CPUDEVICE, InputRef(0).Value().GetDeviceId());
         // inputGradientValues+= gradientValues*(softmaxOfRight - CTCposterior)
-        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, inputGradientValues);
+        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, *m_tmpMatrix);
+        Matrix<ElemType>::VectorSum(*m_tmpMatrix, inputGradientValues, false);
+        //inputGradientValues.Print("gradient");
+        /*printf("back to F\n");
+        if (gradientValues.GetDeviceId() != CPUDEVICE)
+            printf("gradientValues after F is in GPU\n");*/
+#if DUMPOUTPUT
+        inputGradientValues.Print("RNNTNode Partial-Right");
+#endif
+    }
+
+    void BackpropToW(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
+                     Matrix<ElemType>& RNNTDerivative, Matrix<ElemType>& inputValue)
+    {
+#if DUMPOUTPUT
+        inputFunctionValues.Print("RNNTNode Partial-inputFunctionValues");
+        gradientValues.Print("RNNTNode Partial-gradientValues");
+        inputGradientValues.Print("RNNTNode Partial-Right-in");
+#endif
+        //sum u for RNNT Derivative
+        //m_tmpMatrix->AssignUserOp2(RNNTDerivative, InputRef(2).Value().GetNumCols(), InputRef(1).Value().GetNumCols(), InputRef(0).GetMBLayout()->GetNumParallelSequences(), 0);
+        //m_tmpMatrix->TransferFromDeviceToDevice(CPUDEVICE, InputRef(0).Value().GetDeviceId());
+        // inputGradientValues+= gradientValues*(softmaxOfRight - CTCposterior)
+        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, *m_tmpMatrix);
+        inputGradientValues.AssignProductOf(inputValue, false, *m_tmpMatrix, true);
+        //inputGradientValues.Print("gradient");
+        /*printf("back to F\n");
+        if (gradientValues.GetDeviceId() != CPUDEVICE)
+            printf("gradientValues after F is in GPU\n");*/
+#if DUMPOUTPUT
+        inputGradientValues.Print("RNNTNode Partial-Right");
+#endif
+    }
+
+    void BackpropToX(Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
+                     Matrix<ElemType>& RNNTDerivative, Matrix<ElemType>& inputValue)
+    {
+#if DUMPOUTPUT
+        inputFunctionValues.Print("RNNTNode Partial-inputFunctionValues");
+        gradientValues.Print("RNNTNode Partial-gradientValues");
+        inputGradientValues.Print("RNNTNode Partial-Right-in");
+#endif
+        //sum u for RNNT Derivative
+        //m_tmpMatrix->AssignUserOp2(RNNTDerivative, InputRef(2).Value().GetNumCols(), InputRef(1).Value().GetNumCols(), InputRef(0).GetMBLayout()->GetNumParallelSequences(), 0);
+        //m_tmpMatrix->TransferFromDeviceToDevice(CPUDEVICE, InputRef(0).Value().GetDeviceId());
+        // inputGradientValues+= gradientValues*(softmaxOfRight - CTCposterior)
+        Matrix<ElemType>::Scale(gradientValues.Get00Element(), RNNTDerivative, *m_tmpMatrix);
+        inputGradientValues.AssignProductOf(inputValue, false, *m_tmpMatrix, false);
         //inputGradientValues.Print("gradient");
         /*printf("back to F\n");
         if (gradientValues.GetDeviceId() != CPUDEVICE)
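Taken together, the three bodies compute the standard gradients of the linear layer applied in the forward pass (next hunk). Writing x for mergedinput, z = W^T x + b for the pre-softmax output density, and delta for the scaled loss derivative gradientValues.Get00Element() * RNNTDerivative held in m_tmpMatrix, with t ranging over the minibatch columns (a sketch of the algebra under those definitions):

\[
z = W^{\top} x + b
\quad\Longrightarrow\quad
\frac{\partial L}{\partial b} = \sum_{t} \delta_{:,t}, \qquad
\frac{\partial L}{\partial W} = x\,\delta^{\top}, \qquad
\frac{\partial L}{\partial x} = W\,\delta,
\]

which is exactly what VectorSum (row-wise sum into a column vector), AssignProductOf(inputValue, false, *m_tmpMatrix, true) with inputValue = x, and AssignProductOf(inputValue, false, *m_tmpMatrix, false) with inputValue = W compute, respectively. One caveat visible in the diff: the #if DUMPOUTPUT blocks of BackpropToW and BackpropToX print inputFunctionValues, which is not a parameter of either method, so these bodies would not compile with DUMPOUTPUT enabled.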
@@ -1703,11 +1759,15 @@ public:
 
         //m_RNNTDerivative->SwitchToMatrixType(m_outputLogDistribution->GetMatrixType(), m_outputLogDistribution->GetFormat(), false);
         //m_RNNTDerivative->Resize(m_outputLogDistribution->GetNumRows(), m_outputLogDistribution->GetNumCols());
+        m_outputDensity->AssignProductOf(InputRef(4).Value(), true, InputRef(3).Value(), false);
+        m_outputDensity->AssignSumOf(*m_outputDensity, InputRef(5).Value());
+        m_outputDensity->InplaceLogSoftmax(true);
         FrameRange fr(InputRef(0).GetMBLayout());
         InputRef(0).ValueFor(fr).VectorMax(*m_maxIndexes, *m_maxValues, true);
 
+
         // compute CTC score
-        m_GammaCal.twodimForwardBackward(Value(), InputRef(1).Value(), InputRef(2).Value(), InputRef(3).Value(), *m_maxIndexes, *m_derivative, InputRef(1).GetMBLayout(), InputRef(2).GetMBLayout(), m_blankTokenId);
+        m_GammaCal.twodimForwardBackward(Value(), InputRef(1).Value(), InputRef(2).Value(), *m_outputDensity, *m_maxIndexes, *m_derivative, InputRef(1).GetMBLayout(), InputRef(2).GetMBLayout(), m_blankTokenId);
 
 #if NANCHECK
         functionValues.HasNan("RNNTNode");
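ForwardPropNonLooping now materializes the density inside the node: the three added lines compute m_outputDensity = LogSoftmax(W^T * mergedinput + b) column-wise, and the forward-backward pass consumes *m_outputDensity instead of the raw input 3.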
@@ -1724,8 +1784,7 @@ public:
 
         if (isFinalValidationPass)
         {
-            if (!(Input(0)->GetSampleMatrixNumRows() == Input(3)->GetSampleMatrixNumRows() && // match vector dimension
-                  Input(0)->HasMBLayout() &&
+            if (!(Input(0)->HasMBLayout() &&
                   Input(0)->GetMBLayout() == Input(2)->GetMBLayout()))
             {
                 LogicError("The Matrix dimension in the RNNTNode operation does not match.");
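Validation drops the old requirement that inputs 0 and 3 have matching sample-matrix row counts: mergedinput is now pre-projection, so its row dimension is the joint hidden size rather than the label/vocabulary dimension, and only the minibatch-layout checks remain.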
@@ -1751,6 +1810,7 @@ public:
         node->m_derivativeForG->SetValue(*m_derivative);
         node->m_maxIndexes->SetValue(*m_maxIndexes);
         node->m_maxValues->SetValue(*m_maxValues);
+        node->m_outputDensity->SetValue(*m_outputDensity);
         node->m_delayConstraint = m_delayConstraint;
         //node->m_RNNTDerivative->SetValue(*m_RNNTDerivative);
         node->m_tmpMatrix->SetValue(*m_tmpMatrix);
@@ -1762,6 +1822,7 @@ public:
     {
         Base::RequestMatricesBeforeForwardProp(matrixPool);
         RequestMatrixFromPool(m_derivativeForG, matrixPool);
+        RequestMatrixFromPool(m_outputDensity, matrixPool);
         RequestMatrixFromPool(m_derivative, matrixPool);
         //RequestMatrixFromPool(m_outputDistribution, matrixPool);
         RequestMatrixFromPool(m_maxIndexes, matrixPool);
@@ -1774,6 +1835,7 @@ public:
     {
         Base::ReleaseMatricesAfterBackprop(matrixPool);
         ReleaseMatrixToPool(m_derivativeForG, matrixPool);
+        ReleaseMatrixToPool(m_outputDensity, matrixPool);
         ReleaseMatrixToPool(m_derivative, matrixPool);
         //ReleaseMatrixToPool(m_outputDistribution, matrixPool);
         ReleaseMatrixToPool(m_maxIndexes, matrixPool);
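m_outputDensity is a new pooled matrix member, so it follows the usual node lifecycle alongside m_derivative and friends: copied in CopyTo above, requested before forward prop, and released after backprop.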
@@ -540,20 +540,21 @@ public:
         //matrixOutputDistribution.Print("h");
         //log softmax of f+g
         //mergedinput.InplaceLogSoftmax(true);
-        Microsoft::MSR::CNTK::Matrix<ElemType> logsoftmax(m_deviceid_gpu);
+
+        /*Microsoft::MSR::CNTK::Matrix<ElemType> logsoftmax(m_deviceid_gpu);
         logsoftmax.SetValue(mergedinput);
 
-        logsoftmax.InplaceLogSoftmax(true);
+        logsoftmax.InplaceLogSoftmax(true);*/
         //matrixOutputDistribution.Print("prob");
         // forward backward to compute alpha, beta derivaitves
         Microsoft::MSR::CNTK::Matrix<ElemType> alpha(m_deviceid_gpu);
         Microsoft::MSR::CNTK::Matrix<ElemType> beta(m_deviceid_gpu);
         m_derivative.TransferToDeviceIfNotThere(m_deviceid_gpu);
-        m_derivative.AssignRNNTScore(logsoftmax, alpha, beta, matrixPhoneSeqs, matrixPhoneSeqs, uttFrameToChanInd, uttFrameBeginIdx, uttBeginForOutputditribution, uttPhoneToChanInd, uttPhoneBeginIdx,
+        m_derivative.AssignRNNTScore(mergedinput, alpha, beta, matrixPhoneSeqs, matrixPhoneSeqs, uttFrameToChanInd, uttFrameBeginIdx, uttBeginForOutputditribution, uttPhoneToChanInd, uttPhoneBeginIdx,
                                      uttFrameNum, uttPhoneNum, numParallelSequences, numPhoneParallelSequences, maxPhoneNum, maxFrameNum, totalScore, blankTokenId, -1,true);
 
-        logsoftmax.InplaceExp();
-        m_derivative.AssignElementProductOf(m_derivative, logsoftmax);
+        mergedinput.InplaceExp();
+        m_derivative.AssignElementProductOf(m_derivative, mergedinput);
         ElemType finalscore = 0;
         //m_derivative.Print("RNNT");
         finalscore = totalScore.Get00Element();
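In the forward-backward wrapper (twodimForwardBackward), the mergedinput argument now arrives already log-softmaxed (it is *m_outputDensity from the node), so the local logsoftmax working copy is commented out: AssignRNNTScore consumes the log-density directly, and InplaceExp then turns it into posteriors for the element-wise product that forms the derivative. Since InplaceExp now runs on the argument itself rather than a copy, the caller's matrix is mutated in place.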