Rename EditDistanceNode to EditDistanceErrorNode. Use GetAllSequences in ED computation

This commit is contained in:
Vadim Mazalov 2017-01-05 13:28:08 -08:00
Родитель 7c8b4f3cb8
Коммит f9b9070b02
6 изменённых файлов: 116 добавлений и 120 удалений

Просмотреть файл

@ -160,7 +160,7 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
#endif
else if (EqualInsensitive(nodeType, OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode), L"CBCEWithSM")) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(ClassificationErrorNode), L"ErrorPrediction")) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(EditDistanceNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(EditDistanceErrorNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(EqualNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(GreaterEqualNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(GreaterNode))) ret = true;

Просмотреть файл

@ -446,7 +446,7 @@ bool ComputationNetwork::IsTypicalCriterionNode(ComputationNodeBasePtr nodePtr)
nodePtr->OperationName() == OperationNameOf(CrossEntropyNode) ||
nodePtr->OperationName() == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode) ||
nodePtr->OperationName() == OperationNameOf(ClassificationErrorNode) ||
nodePtr->OperationName() == OperationNameOf(EditDistanceNode) ||
nodePtr->OperationName() == OperationNameOf(EditDistanceErrorNode) ||
#ifdef COMING_SOON
nodePtr->OperationName() == OperationNameOf(CRFNode) ||
#endif

Просмотреть файл

@ -54,7 +54,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == OperationNameOf(DropoutNode)) return New<DropoutNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(DummyCriterionNode)) return New<DummyCriterionNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(DynamicAxisNode)) return New<DynamicAxisNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(EditDistanceNode)) return New<EditDistanceNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(EditDistanceErrorNode)) return New<EditDistanceErrorNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ElementTimesNode)) return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(EnvironmentInputNode)) return New<EnvironmentInputNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(EpochAccumulatorNode)) return New<EpochAccumulatorNode<ElemType>>(forward<_Types>(_Args)...);
@ -429,9 +429,9 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Class
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::EditDistance(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<int> samplesToIgnore, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::EditDistanceError(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<int> samplesToIgnore, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<EditDistanceNode<ElemType>>(net.GetDeviceId(), nodeName, subPen, delPen, insPen, squashInputs, samplesToIgnore), { a, b });
return net.AddNodeToNetAndAttachInputs(New<EditDistanceErrorNode<ElemType>>(net.GetDeviceId(), nodeName, subPen, delPen, insPen, squashInputs, samplesToIgnore), { a, b });
}
template <class ElemType>

Просмотреть файл

@ -129,7 +129,7 @@ public:
ComputationNodePtr Diagonal(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Dropout(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr DummyCriterion(const ComputationNodePtr objectives, const ComputationNodePtr derivatives, const ComputationNodePtr prediction, const std::wstring nodeName = L"");
ComputationNodePtr EditDistance(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<int> samplesToIgnore, const std::wstring nodeName = L"");
ComputationNodePtr EditDistanceError(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<int> samplesToIgnore, const std::wstring nodeName = L"");
ComputationNodePtr ElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr DynamicAxis(const ComputationNodePtr a, const std::wstring& nodeName = L"");
ComputationNodePtr ClassificationError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");

Просмотреть файл

@ -456,11 +456,22 @@ protected:
template class NDCG1EvalNode<float>;
template class NDCG1EvalNode<double>;
// Edit distance error evaluation node with the option of specifying penalty of substitution, deletion and insertion
// Using the classic DP algorithm as described in https://en.wikipedia.org/wiki/Edit_distance, adjusted to take into account the penalties
//
// The node has the option to squash sequences of repeating labels and ignore certain labels, e.g. if squashInputs is true and samplesToIgnore contain label '-' then
// given first input sequence as s1="a-ab-" and second as s2="-aa--abb" the edit distance error will be computed after squashing the sequences, i.e. between sequence s1' = "aab" and s2' = "aab"
//
// The returned error is computed as: EditDistance(s1,s2) * length(s1') / length(s1)
//
// Just like ClassificationError and other evaluation nodes, when used as an evaluation criterion, the SGD process will aggregate all values over an epoch and report the average, i.e. the error rate.
// Primary objective of this node is for error evaluation of CTC training, see formula (1) in "Connectionist Temporal Classification: Labelling Unsegmented
// equence Data with Recurrent Neural Networks", http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf
template<class ElemType>
class EditDistanceNode : public ComputationNodeNonLooping/*ComputationNode*/<ElemType>, public NumInputs<2>
class EditDistanceErrorNode : public ComputationNodeNonLooping/*ComputationNode*/<ElemType>, public NumInputs<2>
{
typedef ComputationNodeNonLooping<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"EditDistance"; }
static const std::wstring TypeName() { return L"EditDistanceError"; }
public:
// subPen - substitution penalty
@ -468,18 +479,18 @@ public:
// insPen - insertion penalty
// squashInputs - whether to merge sequences of identical samples.
// samplesToIgnore - list of samples to ignore during edit distance evaluation
EditDistanceNode(DEVICEID_TYPE deviceId, const wstring & name, float subPen, float delPen, float insPen, bool squashInputs, vector<int> samplesToIgnore)
EditDistanceErrorNode(DEVICEID_TYPE deviceId, const wstring & name, float subPen, float delPen, float insPen, bool squashInputs, vector<int> samplesToIgnore)
: Base(deviceId, name), m_SubPen(subPen), m_DelPen(delPen), m_InsPen(insPen), m_SquashInputs(squashInputs), m_SamplesToIgnore(samplesToIgnore)
{
}
EditDistanceNode(const ScriptableObjects::IConfigRecordPtr configp)
: EditDistanceNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"subPen"), configp->Get(L"delPen"), configp->Get(L"insPen"), configp->Get(L"squashInputs"), configp->Get(L"samplesToIgnore"))
EditDistanceErrorNode(const ScriptableObjects::IConfigRecordPtr configp)
: EditDistanceErrorNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"subPen"), configp->Get(L"delPen"), configp->Get(L"insPen"), configp->Get(L"squashInputs"), configp->Get(L"samplesToIgnore"))
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
EditDistanceNode(DEVICEID_TYPE deviceId, const wstring& name)
EditDistanceErrorNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name)
{
}
@ -491,13 +502,16 @@ public:
virtual void ForwardPropNonLooping() override
{
if (Input(0)->Is<SparseInputValue<ElemType>>() || Input(1)->Is<SparseInputValue<ElemType>>())
LogicError("EditDistanceError node was not tested for sparse inputs.");
FrameRange frameRange(Input(0)->GetMBLayout());
Input(0)->ValueFor(frameRange).VectorMax(*m_maxIndexes0, *m_maxValues, true);
Input(1)->ValueFor(frameRange).VectorMax(*m_maxIndexes1, *m_maxValues, true);
MaskMissingColumnsToZero(*m_maxIndexes0, Input(0)->GetMBLayout(), frameRange);
MaskMissingColumnsToZero(*m_maxIndexes1, Input(1)->GetMBLayout(), frameRange);
Value()(0, 0) = ComputeEditDistance(*m_maxIndexes0, *m_maxIndexes1, Input(1)->GetNumParallelSequences(), Input(0)->GetMBLayout(), m_SubPen, m_DelPen, m_InsPen, m_SquashInputs, m_SamplesToIgnore);
Value()(0, 0) = ComputeEditDistanceError(*m_maxIndexes0, *m_maxIndexes1, Input(1)->GetNumParallelSequences(), Input(0)->GetMBLayout(), m_SubPen, m_DelPen, m_InsPen, m_SquashInputs, m_SamplesToIgnore);
}
virtual void Validate(bool isFinalValidationPass) override
@ -522,7 +536,7 @@ public:
if (flags & CopyNodeFlags::copyNodeValue)
{
auto node = dynamic_pointer_cast<EditDistanceNode<ElemType>>(nodeP);
auto node = dynamic_pointer_cast<EditDistanceErrorNode<ElemType>>(nodeP);
node->m_maxIndexes0 = m_maxIndexes0;
node->m_maxIndexes1 = m_maxIndexes1;
node->m_maxValues = m_maxValues;
@ -561,7 +575,7 @@ public:
// insPen - insertion penalty
// squashInputs - whether to merge sequences of identical samples.
// samplesToIgnore - list of samples to ignore during edit distance evaluation
static ElemType ComputeEditDistance(Matrix<ElemType>& firstSeq, const Matrix<ElemType> & secondSeq, size_t numParallelSequences, MBLayoutPtr pMBLayout,
static ElemType ComputeEditDistanceError(Matrix<ElemType>& firstSeq, const Matrix<ElemType> & secondSeq, size_t numParallelSequences, MBLayoutPtr pMBLayout,
float subPen, float delPen, float insPen, bool squashInputs, const vector<int>& samplesToIgnore)
{
std::vector<int> firstSeqVec, secondSeqVec;
@ -580,111 +594,96 @@ public:
float del, ins, sub;
ElemType wrongSampleNum = 0.0;
size_t totalSampleNum = 0, totalRecNum = 0;
size_t mbsize = firstSeq.GetNumCols() / numParallelSequences;
size_t lastSentEnd = 0;
size_t frameNum = 0;
size_t j = 0;
for (size_t seqIndex = 0; seqIndex < numParallelSequences; seqIndex++)
size_t totalSampleNum = 0, totalframeNum = 0;
size_t sequenceStartFrame = 0;
for (const auto& sequence : pMBLayout->GetAllSequences())
{
lastSentEnd = 0;
j = 0;
while (j < mbsize)
if (sequence.seqId == GAP_SEQUENCE_ID)
continue;
auto seqIndex = sequence.s;
auto frameNum = min(sequence.tEnd, pMBLayout->GetNumTimeSteps()) - (size_t)(max(sequence.tBegin, (ptrdiff_t)0));
if (frameNum > 0)
{
frameNum = 0;
for (j = lastSentEnd; j < mbsize; j++)
totalframeNum += frameNum;
ExtractSampleSequence(firstSeq, numParallelSequences, sequenceStartFrame, frameNum, seqIndex, squashInputs, samplesToIgnore, firstSeqVec);
ExtractSampleSequence(secondSeq, numParallelSequences, sequenceStartFrame, frameNum, seqIndex, squashInputs, samplesToIgnore, secondSeqVec);
//calculate edit distance
size_t firstSize = firstSeqVec.size();
totalSampleNum += firstSize;
size_t secondSize = secondSeqVec.size();
grid.Resize(firstSize + 1, secondSize + 1);
insMatrix.Resize(firstSize + 1, secondSize + 1);
delMatrix.Resize(firstSize + 1, secondSize + 1);
subMatrix.Resize(firstSize + 1, secondSize + 1);
insMatrix.SetValue(0.0f);
delMatrix.SetValue(0.0f);
subMatrix.SetValue(0.0f);
for (size_t i = 0; i < firstSize + 1; i++)
{
if (pMBLayout->IsEnd(seqIndex, j))
{
frameNum = j - lastSentEnd + 1;
break;
}
grid(i, 0) = (float)(i * delPen);
delMatrix(i, 0) = (float)i;
}
if (frameNum > 0)
for (size_t j = 0; j < secondSize + 1; j++)
{
totalRecNum += frameNum;
ExtractSampleSequence(firstSeqVec, firstSeq, numParallelSequences, frameNum, lastSentEnd, seqIndex, squashInputs, samplesToIgnore);
ExtractSampleSequence(secondSeqVec, secondSeq, numParallelSequences, frameNum, lastSentEnd, seqIndex, squashInputs, samplesToIgnore);
//calculate edit distance
size_t firstSize = firstSeqVec.size();
totalSampleNum += firstSize;
size_t secondSize = secondSeqVec.size();
grid.Resize(firstSize + 1, secondSize + 1);
insMatrix.Resize(firstSize + 1, secondSize + 1);
delMatrix.Resize(firstSize + 1, secondSize + 1);
subMatrix.Resize(firstSize + 1, secondSize + 1);
insMatrix.SetValue(0.0f);
delMatrix.SetValue(0.0f);
subMatrix.SetValue(0.0f);
for (size_t i = 0; i < firstSize + 1; i++){
grid(i, 0) = (float)(i * delPen);
delMatrix(i, 0) = (float)i;
}
for (size_t j = 0; j < secondSize + 1; j++)
grid(0, j) = (float)(j * insPen);
insMatrix(0, j) = (float)j;
}
for (size_t i = 1; i < firstSize + 1; i++)
{
for (size_t j = 1; j < secondSize + 1; j++)
{
grid(0, j) = (float)(j * insPen);
insMatrix(0, j) = (float)j;
}
for (size_t i = 1; i < firstSize + 1; i++)
{
for (size_t j = 1; j < secondSize + 1; j++)
if (firstSeqVec[i - 1] == secondSeqVec[j - 1])
{
if (firstSeqVec[i - 1] == secondSeqVec[j - 1])
grid(i, j) = grid(i - 1, j - 1);
insMatrix(i, j) = insMatrix(i - 1, j - 1);
delMatrix(i, j) = delMatrix(i - 1, j - 1);
subMatrix(i, j) = subMatrix(i - 1, j - 1);
}
else
{
del = grid(i - 1, j) + delPen; //deletion
ins = grid(i, j - 1) + insPen; //insertion
sub = grid(i - 1, j - 1) + subPen; //substitution
if (sub <= del && sub <= ins)
{
grid(i, j) = grid(i - 1, j - 1);
insMatrix(i, j) = insMatrix(i - 1, j - 1);
delMatrix(i, j) = delMatrix(i - 1, j - 1);
subMatrix(i, j) = subMatrix(i - 1, j - 1);
subMatrix(i, j) = subMatrix(i - 1, j - 1) + 1.0f;
grid(i, j) = sub;
}
else if (del < ins)
{
insMatrix(i, j) = insMatrix(i - 1, j);
subMatrix(i, j) = subMatrix(i - 1, j);
delMatrix(i, j) = delMatrix(i - 1, j) + 1.0f;
grid(i, j) = del;
}
else
{
del = grid(i - 1, j) + delPen; //deletion
ins = grid(i, j - 1) + insPen; //insertion
sub = grid(i - 1, j - 1) + subPen; //substitution
if (sub <= del && sub <= ins)
{
insMatrix(i, j) = insMatrix(i - 1, j - 1);
delMatrix(i, j) = delMatrix(i - 1, j - 1);
subMatrix(i, j) = subMatrix(i - 1, j - 1) + 1.0f;
grid(i, j) = sub;
}
else if (del < ins)
{
insMatrix(i, j) = insMatrix(i - 1, j);
subMatrix(i, j) = subMatrix(i - 1, j);
delMatrix(i, j) = delMatrix(i - 1, j) + 1.0f;
grid(i, j) = del;
}
else
{
delMatrix(i, j) = delMatrix(i, j - 1);
subMatrix(i, j) = subMatrix(i, j - 1);
insMatrix(i, j) = insMatrix(i, j - 1) + 1.0f;
grid(i, j) = ins;
}
delMatrix(i, j) = delMatrix(i, j - 1);
subMatrix(i, j) = subMatrix(i, j - 1);
insMatrix(i, j) = insMatrix(i, j - 1) + 1.0f;
grid(i, j) = ins;
}
}
}
wrongSampleNum += insMatrix(firstSize, secondSize) + delMatrix(firstSize, secondSize) + subMatrix(firstSize, secondSize);
}
lastSentEnd += frameNum;
if (lastSentEnd < mbsize)
{
FrameRange fr(pMBLayout, lastSentEnd);
if (pMBLayout->IsGap(fr.Sequence(seqIndex)))
break;
}
wrongSampleNum += insMatrix(firstSize, secondSize) + delMatrix(firstSize, secondSize) + subMatrix(firstSize, secondSize);
}
sequenceStartFrame += frameNum;
}
return (ElemType)(wrongSampleNum * totalRecNum / totalSampleNum);
return (ElemType)(wrongSampleNum * totalframeNum / totalSampleNum);
}
private:
@ -696,48 +695,45 @@ private:
float m_InsPen;
std::vector<int> m_SamplesToIgnore;
// Extract a vector of samples from the matrix.
static void ExtractSampleSequence(std::vector<int>& outputSampleSeqVec, const Matrix<ElemType>& firstSeq, size_t numParallelSequences, size_t frameNum, size_t lastSentEnd, size_t seqIndex, bool squashInputs, const vector<int>& samplesToIgnore)
// Clear out_SampleSeqVec and extract a vector of samples from the matrix into out_SampleSeqVec.
static void ExtractSampleSequence(const Matrix<ElemType>& firstSeq, size_t numParallelSequences, size_t sequenceStartFrame, size_t frameNum, size_t seqIndex, bool squashInputs, const vector<int>& samplesToIgnore, std::vector<int>& out_SampleSeqVec)
{
outputSampleSeqVec.clear();
out_SampleSeqVec.clear();
if (frameNum == 0)
return;
// First element in the sequence
size_t lastId = (int)firstSeq(0, lastSentEnd * numParallelSequences + seqIndex);
if (std::find(samplesToIgnore.begin(), samplesToIgnore.end(), refId) == samplesToIgnore.end())
outputSampleSeqVec.push_back(refId);
// Get the first element in the sequence
size_t lastId = (int)firstSeq(0, sequenceStartFrame * numParallelSequences + seqIndex);
if (std::find(samplesToIgnore.begin(), samplesToIgnore.end(), lastId) == samplesToIgnore.end())
out_SampleSeqVec.push_back(lastId);
// Remaining elements
if (squashInputs)
{
//squash sequences of identical samples
for (size_t i = lastSentEnd+1; i < frameNum + lastSentEnd; i++)
for (size_t i = sequenceStartFrame+1; i < frameNum + sequenceStartFrame; i++)
{
auto refId = (int)firstSeq(0, i * numParallelSequences + seqIndex);
size_t refId = (int)firstSeq(0, i * numParallelSequences + seqIndex);
if (lastId != refId)
{
lastId = refId;
if (std::find(samplesToIgnore.begin(), samplesToIgnore.end(), refId) == samplesToIgnore.end())
outputSampleSeqVec.push_back(refId);
out_SampleSeqVec.push_back(refId);
}
}
}
else
{
for (size_t i = lastSentEnd+1; i < frameNum + lastSentEnd; i++)
for (size_t i = sequenceStartFrame+1; i < frameNum + sequenceStartFrame; i++)
{
auto refId = (int)firstSeq(0, i * numParallelSequences + seqIndex);
if (std::find(samplesToIgnore.begin(), samplesToIgnore.end(), refId) == samplesToIgnore.end())
outputSampleSeqVec.push_back(refId);
out_SampleSeqVec.push_back(refId);
}
}
}
};
template class EditDistanceNode<float>;
template class EditDistanceNode<double>;
template class EditDistanceErrorNode<float>;
template class EditDistanceErrorNode<double>;
#ifdef COMING_SOON

Просмотреть файл

@ -10,7 +10,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { namespace Test {
BOOST_AUTO_TEST_SUITE(EditDistanceTests)
BOOST_AUTO_TEST_CASE(ComputeEditDistanceTest)
BOOST_AUTO_TEST_CASE(ComputeEditDistanceErrorTest)
{
Matrix<float> firstSeq(CPUDEVICE);
Matrix<float> secondSeq(CPUDEVICE);
@ -26,7 +26,7 @@ BOOST_AUTO_TEST_CASE(ComputeEditDistanceTest)
MBLayoutPtr pMBLayout = make_shared<MBLayout>(1, seqSize, L"X");
pMBLayout->AddSequence(0, 0, 0, seqSize);
float ed = EditDistanceNode<float>::ComputeEditDistance(firstSeq, secondSeq, 1, pMBLayout, 1, 1, 1, true, samplesToIgnore);
float ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, 1, pMBLayout, 1, 1, 1, true, samplesToIgnore);
assert((int)ed == 2);
for (size_t i = 0; i < seqSize; i++)
@ -34,12 +34,12 @@ BOOST_AUTO_TEST_CASE(ComputeEditDistanceTest)
secondSeq(0, i) = (float)i;
}
ed = EditDistanceNode<float>::ComputeEditDistance(firstSeq, secondSeq, 1, pMBLayout, 1, 1, 1, true, samplesToIgnore);
ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, 1, pMBLayout, 1, 1, 1, true, samplesToIgnore);
assert((int)ed == 0);
secondSeq(0, seqSize-1) = (float)123;
ed = EditDistanceNode<float>::ComputeEditDistance(firstSeq, secondSeq, 1, pMBLayout, 1, 1, 1, true, samplesToIgnore);
ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, 1, pMBLayout, 1, 1, 1, true, samplesToIgnore);
assert((int)ed == 1);
}