Exposing EditDistanceError node in BrainScript

This commit is contained in:
Vadim Mazalov 2017-02-15 22:48:43 -08:00
Родитель a1e25e1073
Коммит ba3f24b74e
15 изменённых файлов: 137 добавлений и 114 удалений

Просмотреть файл

@ -348,11 +348,13 @@ Dropout = CNTK2.Dropout
ElementTimes = CNTK2.ElementTimes
ElementDivide = CNTK2.ElementDivide
ClassificationError = CNTK2.ClassificationError
EditDistanceError = CNTK2.EditDistanceError
Exp = CNTK2.Exp
Floor = CNTK2.Floor
Log = CNTK2.Log
Minus = CNTK2.Minus
Pass = CNTK2.Pass
LabelsToGraph = CNTK2.LabelsToGraph
Plus = CNTK2.Plus
RectifiedLinear = CNTK2.ReLU # deprecated
ReLU = CNTK2.ReLU
@ -518,6 +520,7 @@ CNTK2 = [
// 13. Others
Pass(_, tag='') = new ComputationNode [ operation = 'Pass' ; inputs = _AsNodes (_) /*plus the function args*/ ]
Identity = Pass
LabelsToGraph = Pass
// The value of GetRandomSample(weights /* vector of length nClasses */, numSamples, sampleWithReplacement) randomly samples numSamples using the specified sampling weights.
// The result is a sparse matrix of num samples one-hot vectors as columns.
@ -577,6 +580,7 @@ Shift(input, fromOffset, boundaryValue, boundaryMode=-1/*context*/, dim=-1, tag=
RowSlice(beginIndex, numRows, input, tag='') = Slice(beginIndex, beginIndex + numRows, input, axis = 1)
RowRepeat(input, numRepeats, tag='') = new ComputationNode [ operation = 'RowRepeat' ; inputs = _AsNodes (input) /*plus the function args*/ ]
RowStack(inputs, axis=1, tag='') = new ComputationNode [ operation = 'RowStack' /*plus the function args*/ ]
# EditDistanceError: edit-distance error between the sequences given in 'inputs'.
# subPen/delPen/insPen are the substitution/deletion/insertion penalties; squashInputs and
# tokensToIgnore are forwarded to the node as function args. The node inputs must be built from
# the 'inputs' parameter (the previous body referenced an undeclared identifier 'input').
EditDistanceError(inputs, subPen=0.0, delPen=0.0, insPen=0.0, squashInputs=false, tokensToIgnore={}) = new ComputationNode [ operation = 'EditDistanceError' ; inputs = _AsNodes (inputs) /*plus the function args*/ ]
Slice(beginIndex, endIndex, input, axis=1, tag='') =
if axis < 0 then [ # time axis: specify -1
beginFlags = if beginIndex > 0 then BS.Boolean.Not (BS.Loop.IsFirstN (beginIndex, input)) else BS.Loop.IsLastN (-beginIndex, input)

Просмотреть файл

@ -431,9 +431,9 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Class
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::EditDistanceError(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<size_t> samplesToIgnore, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::EditDistanceError(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<int> tokensToIgnore, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<EditDistanceErrorNode<ElemType>>(net.GetDeviceId(), subPen, delPen, insPen, squashInputs, samplesToIgnore, nodeName), { a, b });
return net.AddNodeToNetAndAttachInputs(New<EditDistanceErrorNode<ElemType>>(net.GetDeviceId(), nodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore), { a, b });
}
template <class ElemType>
@ -501,9 +501,9 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Seque
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ForwardBackward(const ComputationNodePtr label, const ComputationNodePtr prediction, int delayConstraint, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ForwardBackward(const ComputationNodePtr label, const ComputationNodePtr prediction, int blankTokenId, int delayConstraint, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<ForwardBackwardNode<ElemType>>(net.GetDeviceId(), nodeName, delayConstraint), { label, prediction });
return net.AddNodeToNetAndAttachInputs(New<ForwardBackwardNode<ElemType>>(net.GetDeviceId(), nodeName, blankTokenId, delayConstraint), { label, prediction });
}
template <class ElemType>

Просмотреть файл

@ -126,12 +126,12 @@ public:
ComputationNodePtr CosDistance(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr CrossEntropy(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName = L"");
ComputationNodePtr CrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName = L"");
ComputationNodePtr ForwardBackward(const ComputationNodePtr label, const ComputationNodePtr prediction, int delayConstraint, const std::wstring nodeName = L"");
ComputationNodePtr ForwardBackward(const ComputationNodePtr label, const ComputationNodePtr prediction, int blankTokenId, int delayConstraint, const std::wstring nodeName = L"");
ComputationNodePtr DiagTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr Diagonal(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Dropout(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr DummyCriterion(const ComputationNodePtr objectives, const ComputationNodePtr derivatives, const ComputationNodePtr prediction, const std::wstring nodeName = L"");
ComputationNodePtr EditDistanceError(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<size_t> samplesToIgnore, const std::wstring nodeName = L"");
ComputationNodePtr EditDistanceError(const ComputationNodePtr a, const ComputationNodePtr b, float subPen, float delPen, float insPen, bool squashInputs, vector<int> tokensToIgnore, const std::wstring nodeName = L"");
ComputationNodePtr ElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr DynamicAxis(const ComputationNodePtr a, const std::wstring& nodeName = L"");
ComputationNodePtr ClassificationError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");

Просмотреть файл

@ -461,7 +461,7 @@ template class NDCG1EvalNode<double>;
// Edit distance error evaluation node with the option of specifying penalty of substitution, deletion and insertion, as well as squashing the input sequences and ignoring certain samples.
// Using the classic DP algorithm as described in https://en.wikipedia.org/wiki/Edit_distance, adjusted to take into account the penalties.
//
// The node allows to squash sequences of repeating labels and ignore certain labels. For example, if squashInputs is true and samplesToIgnore contains label '-' then
// The node allows squashing sequences of repeating labels and ignoring certain labels. For example, if squashInputs is true and tokensToIgnore contains label '-' then
// given first input sequence as s1="a-ab-" and second as s2="-aa--abb" the edit distance will be computed against s1' = "aab" and s2' = "aab".
//
// The returned error is computed as: EditDistance(s1,s2) * length(s1') / length(s1)
@ -480,14 +480,14 @@ public:
// delPen - deletion penalty
// insPen - insertion penalty
// squashInputs - whether to merge sequences of identical samples.
// samplesToIgnore - list of samples to ignore during edit distance evaluation
EditDistanceErrorNode(DEVICEID_TYPE deviceId, const wstring & name, float subPen = 0.0f, float delPen = 0.0f, float insPen = 0.0f, bool squashInputs = false, vector<int> samplesToIgnore = {})
: Base(deviceId, name), m_SubPen(subPen), m_DelPen(delPen), m_InsPen(insPen), m_SquashInputs(squashInputs), m_SamplesToIgnore(samplesToIgnore)
// tokensToIgnore - list of tokens to ignore during edit distance evaluation
EditDistanceErrorNode(DEVICEID_TYPE deviceId, const wstring & name, float subPen = 0.0f, float delPen = 0.0f, float insPen = 0.0f, bool squashInputs = false, vector<int> tokensToIgnore = {})
: Base(deviceId, name), m_SubPen(subPen), m_DelPen(delPen), m_InsPen(insPen), m_SquashInputs(squashInputs), m_tokensToIgnore(tokensToIgnore)
{
}
EditDistanceErrorNode(const ScriptableObjects::IConfigRecordPtr configp)
: EditDistanceErrorNode(configp->Get(L"deviceId"), configp->Get(L"subPen"), configp->Get(L"delPen"), configp->Get(L"insPen"), configp->Get(L"squashInputs"), configp->Get(L"samplesToIgnore"), L"<placeholder>")
: EditDistanceErrorNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"subPen"), configp->Get(L"delPen"), configp->Get(L"insPen"), configp->Get(L"squashInputs"), configp->Get(L"tokensToIgnore"))
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
@ -510,7 +510,7 @@ public:
MaskMissingColumnsToZero(*m_maxIndexes0, Input(0)->GetMBLayout(), frameRange);
MaskMissingColumnsToZero(*m_maxIndexes1, Input(1)->GetMBLayout(), frameRange);
Value()(0, 0) = ComputeEditDistanceError(*m_maxIndexes0, *m_maxIndexes1, Input(0)->GetMBLayout(), m_subPen, m_delPen, m_insPen, m_squashInputs, m_SamplesToIgnore);
Value()(0, 0) = ComputeEditDistanceError(*m_maxIndexes0, *m_maxIndexes1, Input(0)->GetMBLayout(), m_SubPen, m_DelPen, m_InsPen, m_SquashInputs, m_tokensToIgnore);
}
virtual void Validate(bool isFinalValidationPass) override
@ -539,11 +539,11 @@ public:
node->m_maxIndexes0 = m_maxIndexes0;
node->m_maxIndexes1 = m_maxIndexes1;
node->m_maxValues = m_maxValues;
node->m_squashInputs = m_squashInputs;
node->m_subPen = m_subPen;
node->m_delPen = m_delPen;
node->m_insPen = m_insPen;
node->m_SamplesToIgnore = m_SamplesToIgnore;
node->m_SquashInputs = m_SquashInputs;
node->m_SubPen = m_SubPen;
node->m_DelPen = m_DelPen;
node->m_InsPen = m_InsPen;
node->m_tokensToIgnore = m_tokensToIgnore;
}
}
@ -573,9 +573,9 @@ public:
// delPen - deletion penalty
// insPen - insertion penalty
// squashInputs - whether to merge sequences of identical samples.
// samplesToIgnore - list of samples to ignore during edit distance evaluation
// tokensToIgnore - list of tokens to ignore during edit distance evaluation
static ElemType ComputeEditDistanceError(Matrix<ElemType>& firstSeq, const Matrix<ElemType> & secondSeq, MBLayoutPtr pMBLayout,
float subPen, float delPen, float insPen, bool squashInputs, const vector<size_t>& samplesToIgnore)
float subPen, float delPen, float insPen, bool squashInputs, const vector<int>& tokensToIgnore)
{
std::vector<int> firstSeqVec, secondSeqVec;
@ -609,8 +609,8 @@ public:
auto columnIndices = pMBLayout->GetColumnIndices(sequence);
ExtractSampleSequence(firstSeq, columnIndices, squashInputs, samplesToIgnore, firstSeqVec);
ExtractSampleSequence(secondSeq, columnIndices, squashInputs, samplesToIgnore, secondSeqVec);
ExtractSampleSequence(firstSeq, columnIndices, squashInputs, tokensToIgnore, firstSeqVec);
ExtractSampleSequence(secondSeq, columnIndices, squashInputs, tokensToIgnore, secondSeqVec);
//calculate edit distance
size_t firstSize = firstSeqVec.size();
@ -694,20 +694,20 @@ public:
private:
shared_ptr<Matrix<ElemType>> m_maxIndexes0, m_maxIndexes1;
shared_ptr<Matrix<ElemType>> m_maxValues;
bool m_squashInputs;
float m_subPen;
float m_delPen;
float m_insPen;
std::vector<size_t> m_SamplesToIgnore;
bool m_SquashInputs;
float m_SubPen;
float m_DelPen;
float m_InsPen;
std::vector<int> m_tokensToIgnore;
// Clear out_SampleSeqVec and extract a vector of samples from the matrix into out_SampleSeqVec.
static void ExtractSampleSequence(const Matrix<ElemType>& firstSeq, vector<size_t>& columnIndices, bool squashInputs, const vector<size_t>& samplesToIgnore, std::vector<int>& out_SampleSeqVec)
static void ExtractSampleSequence(const Matrix<ElemType>& firstSeq, vector<size_t>& columnIndices, bool squashInputs, const vector<int>& tokensToIgnore, std::vector<int>& out_SampleSeqVec)
{
out_SampleSeqVec.clear();
// Get the first element in the sequence
size_t lastId = (int)firstSeq(0, columnIndices[0]);
if (std::find(samplesToIgnore.begin(), samplesToIgnore.end(), lastId) == samplesToIgnore.end())
if (std::find(tokensToIgnore.begin(), tokensToIgnore.end(), lastId) == tokensToIgnore.end())
out_SampleSeqVec.push_back(lastId);
// Remaining elements
@ -720,7 +720,7 @@ private:
if (lastId != refId)
{
lastId = refId;
if (std::find(samplesToIgnore.begin(), samplesToIgnore.end(), refId) == samplesToIgnore.end())
if (std::find(tokensToIgnore.begin(), tokensToIgnore.end(), refId) == tokensToIgnore.end())
out_SampleSeqVec.push_back(refId);
}
}
@ -730,7 +730,7 @@ private:
for (size_t i = 1; i < columnIndices.size(); i++)
{
auto refId = (int)firstSeq(0, columnIndices[i]);
if (std::find(samplesToIgnore.begin(), samplesToIgnore.end(), refId) == samplesToIgnore.end())
if (std::find(tokensToIgnore.begin(), tokensToIgnore.end(), refId) == tokensToIgnore.end())
out_SampleSeqVec.push_back(refId);
}
}

Просмотреть файл

@ -784,8 +784,8 @@ class ForwardBackwardNode : public ComputationNodeNonLooping<ElemType>, public
}
public:
DeclareConstructorFromConfigWithNumInputs(ForwardBackwardNode);
ForwardBackwardNode(DEVICEID_TYPE deviceId, const wstring & name, int delayConstraint=3) :
Base(deviceId, name), m_delayConstraint(delayConstraint)
ForwardBackwardNode(DEVICEID_TYPE deviceId, const wstring & name, int blankTokenId=INT_MIN, int delayConstraint=3) :
Base(deviceId, name), m_blankTokenId(blankTokenId), m_delayConstraint(delayConstraint)
{
}
@ -857,7 +857,7 @@ public:
FrameRange fr(InputRef(0).GetMBLayout());
InputRef(0).ValueFor(fr).VectorMax(*m_maxIndexes, *m_maxValues, true);
// compute CTC score
m_GammaCal.doCTC(Value(), *m_logSoftmaxOfRight, *m_maxIndexes, *m_maxValues, *m_CTCposterior, InputRef(0).GetMBLayout(), m_delayConstraint);
m_GammaCal.doCTC(Value(), *m_logSoftmaxOfRight, *m_maxIndexes, *m_maxValues, *m_CTCposterior, InputRef(0).GetMBLayout(), m_blankTokenId, m_delayConstraint);
#if NANCHECK
functionValues.HasNan("ForwardBackwardNode");
@ -944,6 +944,7 @@ protected:
shared_ptr<Matrix<ElemType>> m_maxValues;
msra::lattices::GammaCalculation<ElemType> m_GammaCal;
int m_blankTokenId;
int m_delayConstraint;
};

Просмотреть файл

@ -5873,8 +5873,8 @@ void CPUMatrix<ElemType>::RCRFBackwardCompute(const CPUMatrix<ElemType>& alpha,
template<class ElemType>
CPUMatrix<ElemType>& CPUMatrix<ElemType>::AssignCTCScore(
const CPUMatrix<ElemType>& prob, CPUMatrix<ElemType>& alpha, CPUMatrix<ElemType>& beta,
const CPUMatrix<ElemType>& phoneSeq, const CPUMatrix<ElemType>& phoneBoundary, ElemType &totalScore, std::vector<size_t>& uttMap, std::vector<size_t> & uttBeginFrame, std::vector<size_t> & uttFrameNum,
std::vector<size_t> & uttPhoneNum, size_t samplesInRecurrentStep, const size_t maxFrameNum, int delayConstraint, const bool isColWise)
const CPUMatrix<ElemType>& phoneSeq, const CPUMatrix<ElemType>& phoneBoundary, ElemType &totalScore, const std::vector<size_t>& uttMap, const std::vector<size_t> & uttBeginFrame, const std::vector<size_t> & uttFrameNum,
const std::vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const int delayConstraint, const bool isColWise)
{
// Column wise representation of sequences in input matrices (each column is one sequence/utterance)
if (isColWise)
@ -5936,7 +5936,7 @@ CPUMatrix<ElemType>& CPUMatrix<ElemType>::AssignCTCScore(
y = alpha(s - 2, t - 1);
x = LogAddD(x, y);
}
if (senoneid != 65535)
if (senoneid != SIZE_MAX)
ascore = prob(senoneid, t);
else
ascore = 0;
@ -5977,7 +5977,7 @@ CPUMatrix<ElemType>& CPUMatrix<ElemType>::AssignCTCScore(
x = LogAddD(x, y);
}
if (senoneid != 65535)
if (senoneid != SIZE_MAX)
ascore = prob(senoneid, t);
else
ascore = 0;
@ -6008,7 +6008,7 @@ CPUMatrix<ElemType>& CPUMatrix<ElemType>::AssignCTCScore(
for (s = 1; s < senonenum - 1; s++)
{
senoneid = curPhoneSeq[s];
if (senoneid != 65535)
if (senoneid != SIZE_MAX)
{
ElemType logoccu = alpha(s, t) + beta(s, t) - prob(senoneid, t) - (float)Zt;
if (logoccu < LOG_OF_EPS_IN_LOG)

Просмотреть файл

@ -231,7 +231,7 @@ public:
// sequence training
CPUMatrix<ElemType>& DropFrame(const CPUMatrix<ElemType>& label, const CPUMatrix<ElemType>& gamma, const ElemType& threshhold);
CPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const CPUMatrix<ElemType>& label, const CPUMatrix<ElemType>& dnnoutput, const CPUMatrix<ElemType>& gamma, ElemType alpha);
CPUMatrix<ElemType>& CPUMatrix<ElemType>::AssignCTCScore(const CPUMatrix<ElemType>& prob, CPUMatrix<ElemType>& alpha, CPUMatrix<ElemType>& beta, const CPUMatrix<ElemType>& phoneSeq, const CPUMatrix<ElemType>& phoneBoundary, ElemType &totalScore, std::vector<size_t>& uttMap, std::vector<size_t> & uttBeginFrame, std::vector<size_t> & uttFrameNum, std::vector<size_t> & uttPhoneNum, size_t samplesInRecurrentStep, const size_t maxFrameNum, int delayConstraint, const bool isColWise);
// In-class declaration of the CTC scoring routine. It must be unqualified here: the
// `CPUMatrix<ElemType>::` prefix belongs only on the out-of-class definition, and an
// extra qualification on a member declaration is rejected by GCC/Clang.
CPUMatrix<ElemType>& AssignCTCScore(const CPUMatrix<ElemType>& prob, CPUMatrix<ElemType>& alpha, CPUMatrix<ElemType>& beta, const CPUMatrix<ElemType>& phoneSeq, const CPUMatrix<ElemType>& phoneBoundary, ElemType &totalScore, const std::vector<size_t>& uttMap, const std::vector<size_t> & uttBeginFrame, const std::vector<size_t> & uttFrameNum, const std::vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const int delayConstraint, const bool isColWise);
CPUMatrix<ElemType>& InplaceSqrt();
CPUMatrix<ElemType>& AssignSqrtOf(const CPUMatrix<ElemType>& a);

Просмотреть файл

@ -4286,7 +4286,7 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::GetARowByIndex(const GPUMatrix<ElemTyp
// uttBeginFrame(input): the position of the first frame of each utterance in the minibatch channel. We need this because each channel may contain more than one utterance.
// uttFrameNum (input): the frame number of each utterance. The size of this vector = the number of all utterances in this minibatch
// uttPhoneNum (input): the phone number of each utterance. The size of this vector = the number of all utterances in this minibatch
// numChannels (input): channel number in this minibatch
// numParallelSequences (input): number of parallel sequences (channels) in this minibatch
// maxFrameNum (input): the maximum channel frame number
// delayConstraint -- label output delay constraint introduced during training that allows to have shorter delay during inference.
// Alpha and Beta scores outside of the delay boundary are set to zero.
@ -4299,12 +4299,12 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignCTCScore(const GPUMatrix<ElemTyp
const GPUMatrix<ElemType> phoneSeq,
const GPUMatrix<ElemType> phoneBoundary,
ElemType &totalScore,
std::vector<size_t>& uttToChanInd,
std::vector<size_t> & uttBeginFrame,
std::vector<size_t> & uttFrameNum,
std::vector<size_t> & uttPhoneNum,
size_t numChannels,
const size_t maxFrameNum, int delayConstraint, const bool isColWise)
const std::vector<size_t>& uttToChanInd,
const std::vector<size_t> & uttBeginFrame,
const std::vector<size_t> & uttFrameNum,
const std::vector<size_t> & uttPhoneNum,
const size_t numParallelSequences,
const size_t maxFrameNum, const int delayConstraint, const bool isColWise)
{
if (isColWise)
{
@ -4345,21 +4345,21 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignCTCScore(const GPUMatrix<ElemTyp
for (long t = 0; t < maxFrameNum; t++)
{
_assignAlphaScore << <block_tail, thread_tail, 0, t_stream >> >(prob.Data(), alpha.Data(), phoneSeq.Data(), phoneBoundary.Data(), gpuUttToChanInd,
gpuFrameNum, gpuBeginFrame, gpuPhoneNum, numChannels, uttNum, t, maxPhoneNum, totalPhoneNum, delayConstraint);
gpuFrameNum, gpuBeginFrame, gpuPhoneNum, numParallelSequences, uttNum, t, maxPhoneNum, totalPhoneNum, delayConstraint);
}
for (long t = maxFrameNum - 1; t >= 0; t--)
{
_assignBetaScore << <block_tail, thread_tail, 0, t_stream >> >(prob.Data(), beta.Data(), phoneSeq.Data(), phoneBoundary.Data(), gpuUttToChanInd,
gpuFrameNum, gpuBeginFrame, gpuPhoneNum, numChannels, uttNum, t, maxPhoneNum, totalPhoneNum, delayConstraint);
gpuFrameNum, gpuBeginFrame, gpuPhoneNum, numParallelSequences, uttNum, t, maxPhoneNum, totalPhoneNum, delayConstraint);
}
_assignTotalScore << <uttNum, 1, 0, t_stream >> > (beta.Data(), gpuScores, uttNum, gpuUttToChanInd, gpuBeginFrame, numChannels, maxPhoneNum);
_assignTotalScore << <uttNum, 1, 0, t_stream >> > (beta.Data(), gpuScores, uttNum, gpuUttToChanInd, gpuBeginFrame, numParallelSequences, maxPhoneNum);
dim3 block_tail_2((uttNum + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM, (maxFrameNum + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM);
_assignCTCScore << < block_tail_2, thread_tail, 0, t_stream >> >(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, gpuUttToChanInd,
gpuBeginFrame, gpuPhoneNum, gpuFrameNum, numChannels, maxPhoneNum, totalPhoneNum);
gpuBeginFrame, gpuPhoneNum, gpuFrameNum, numParallelSequences, maxPhoneNum, totalPhoneNum);
vector<ElemType>scores(uttNum);
CUDA_CALL(cudaMemcpyAsync(scores.data(), gpuScores, sizeof(ElemType) * uttNum, cudaMemcpyDeviceToHost, t_stream));

Просмотреть файл

@ -350,8 +350,8 @@ public:
GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha);
GPUMatrix<ElemType>& AssignCTCScore(const GPUMatrix<ElemType>& prob, GPUMatrix<ElemType>& alpha, GPUMatrix<ElemType>& beta,
GPUMatrix<ElemType> phoneSeq, GPUMatrix<ElemType> phoneBoundary, ElemType &totalScore, std::vector<size_t>& uttMap, std::vector<size_t> & uttBeginFrame, std::vector<size_t> & uttFrameNum,
std::vector<size_t> & uttPhoneNum, size_t samplesInRecurrentStep, const size_t maxFrameNum, int delayConstraint, const bool isColWise);
const GPUMatrix<ElemType> phoneSeq, const GPUMatrix<ElemType> phoneBoundary, ElemType &totalScore, const std::vector<size_t>& uttMap, const std::vector<size_t> & uttBeginFrame, const std::vector<size_t> & uttFrameNum,
const std::vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const int delayConstraint, const bool isColWise);
GPUMatrix<ElemType>& InplaceSqrt();
GPUMatrix<ElemType>& AssignSqrtOf(const GPUMatrix<ElemType>& a);

Просмотреть файл

@ -5291,7 +5291,7 @@ __global__ void _assignAlphaScore(
x = logaddk(x, alphaScore[alphaId_0]);
if (phoneId != 65535)
if (phoneId != SIZE_MAX)
ascore = prob[probId]; // Probability of observing given label at given time
else
ascore = 0;
@ -5382,7 +5382,7 @@ __global__ void _assignBetaScore(
x = logaddk(x, betaScore[betaid_0]);
if (phoneId != 65535)
if (phoneId != SIZE_MAX)
ascore = prob[probId];
else
ascore = 0;
@ -5439,7 +5439,7 @@ __global__ void _assignCTCScore(
LONG64 alphaId = maxPhoneNum* timeId + s;
LONG64 probId = timeId*totalPhoneNum + phoneId;
if (phoneId != 65535)
if (phoneId != SIZE_MAX)
{
ElemType logoccu = alphaScore[alphaId] + betaScore[alphaId] - prob[probId] - (ElemType)P_lx;
CTCscore[probId] = logaddk(CTCscore[probId], logoccu);

Просмотреть файл

@ -5681,16 +5681,16 @@ Matrix<ElemType>& Matrix<ElemType>::AssignSequenceError(const ElemType hsmoothin
// uttBeginFrame(input): the position of the first frame of each utterance in the minibatch channel. We need this because each channel may contain more than one utterance.
// uttFrameNum (input): the frame number of each utterance. The size of this vector = the number of all utterances in this minibatch
// uttPhoneNum (input): the phone number of each utterance. The size of this vector = the number of all utterances in this minibatch
// numChannels (input): channel number in this minibatch
// numParallelSequences (input): num of parallel sequences
// mbsize (input): the maximum channel frame number
// delayConstraint -- label output delay constraint introduced during training that allows a shorter delay during inference. It uses the original time information to enforce that CTC tokens only get aligned within a time margin.
// Setting this parameter smaller will result in a shorter delay between label outputs during decoding, yet may hurt accuracy.
// delayConstraint=-1 means no constraint
template<class ElemType>
Matrix<ElemType>& Matrix<ElemType>::AssignCTCScore(const Matrix<ElemType>& prob, Matrix<ElemType>& alpha, Matrix<ElemType>& beta,
Matrix<ElemType>& phoneSeq, Matrix<ElemType>& phoneBound, ElemType &totalScore, std::vector<size_t> & uttToChanInd,
std::vector<size_t> & uttBeginFrame, std::vector<size_t> & uttFrameNum, std::vector<size_t> & uttPhoneNum,
size_t numChannels, size_t & mbsize, int& delayConstraint, const bool isColWise)
const Matrix<ElemType>& phoneSeq, const Matrix<ElemType>& phoneBound, ElemType &totalScore, const std::vector<size_t> & uttToChanInd,
const std::vector<size_t> & uttBeginFrame, const std::vector<size_t> & uttFrameNum, const std::vector<size_t> & uttPhoneNum,
const size_t numParallelSequences, const size_t mbsize, const int delayConstraint, const bool isColWise)
{
DecideAndMoveToRightDevice(prob, *this);
alpha.Resize(phoneSeq.GetNumRows(), prob.GetNumCols());
@ -5705,9 +5705,9 @@ Matrix<ElemType>& Matrix<ElemType>::AssignCTCScore(const Matrix<ElemType>& prob,
DISPATCH_MATRIX_ON_FLAG(&prob,
this,
this->m_CPUMatrix->AssignCTCScore(*prob.m_CPUMatrix, *alpha.m_CPUMatrix, *beta.m_CPUMatrix, *phoneSeq.m_CPUMatrix, *phoneBound.m_CPUMatrix, totalScore,
uttToChanInd, uttBeginFrame, uttFrameNum, uttPhoneNum, numChannels, mbsize, delayConstraint, isColWise),
uttToChanInd, uttBeginFrame, uttFrameNum, uttPhoneNum, numParallelSequences, mbsize, delayConstraint, isColWise),
this->m_GPUMatrix->AssignCTCScore(*prob.m_GPUMatrix, *alpha.m_GPUMatrix, *beta.m_GPUMatrix, *phoneSeq.m_GPUMatrix, *phoneBound.m_GPUMatrix, totalScore,
uttToChanInd, uttBeginFrame, uttFrameNum, uttPhoneNum, numChannels, mbsize, delayConstraint, isColWise),
uttToChanInd, uttBeginFrame, uttFrameNum, uttPhoneNum, numParallelSequences, mbsize, delayConstraint, isColWise),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED
);

Просмотреть файл

@ -381,9 +381,9 @@ public:
Matrix<ElemType>& DropFrame(const Matrix<ElemType>& label, const Matrix<ElemType>& gamma, const ElemType& threshhold);
Matrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label, const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha);
Matrix<ElemType>& AssignCTCScore(const Matrix<ElemType>& prob, Matrix<ElemType>& alpha, Matrix<ElemType>& beta, Matrix<ElemType>& phoneSeq, Matrix<ElemType>& phoneBound, ElemType &totalScore,
std::vector<size_t> & extraUttMap, std::vector<size_t> & uttBeginFrame, std::vector<size_t> & uttFrameNum, std::vector<size_t> & uttPhoneNum, size_t samplesInRecurrentStep,
size_t & mbSize, int& delayConstraint, const bool isColWise);
Matrix<ElemType>& AssignCTCScore(const Matrix<ElemType>& prob, Matrix<ElemType>& alpha, Matrix<ElemType>& beta, const Matrix<ElemType>& phoneSeq, const Matrix<ElemType>& phoneBound, ElemType &totalScore,
const std::vector<size_t> & extraUttMap, const std::vector<size_t> & uttBeginFrame, const std::vector<size_t> & uttFrameNum, const std::vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep,
const size_t mbSize, const int delayConstraint, const bool isColWise);
Matrix<ElemType>& InplaceSqrt();
Matrix<ElemType>& AssignSqrtOf(const Matrix<ElemType>& a);

Просмотреть файл

@ -1394,8 +1394,8 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignSequenceError(const ElemType hsm
template <class ElemType>
GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignCTCScore(const GPUMatrix<ElemType>& prob, GPUMatrix<ElemType>& alpha, GPUMatrix<ElemType>& beta,
GPUMatrix<ElemType> phoneSeq, GPUMatrix<ElemType> phoneBound, ElemType &totalScore, std::vector<size_t>& uttMap, std::vector<size_t> & uttBeginFrame, std::vector<size_t> & uttFrameNum,
std::vector<size_t> & uttPhoneNum, size_t samplesInRecurrentStep, const size_t maxFrameNum, int delayConstraint, const bool isColWise)
const GPUMatrix<ElemType> phoneSeq, const GPUMatrix<ElemType> phoneBound, ElemType &totalScore, const std::vector<size_t>& uttMap, const std::vector<size_t> & uttBeginFrame, const std::vector<size_t> & uttFrameNum,
const std::vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const int delayConstraint, const bool isColWise)
{
return *this;
}

Просмотреть файл

@ -258,6 +258,7 @@ public:
// maxValues (input): values of max elements in label input vectors
// labels (input): 1-hot vector with frame-level phone labels
// CTCPosterior (output): CTC posterior
// blankTokenId (input): id of the blank token
// delayConstraint -- label output delay constraint introduced during training that allows a shorter delay during inference. It uses the original time information to enforce that CTC tokens only get aligned within a time margin.
// Setting this parameter smaller will result in a shorter delay between label outputs during decoding, yet may hurt accuracy.
// delayConstraint=-1 means no constraint
@ -267,6 +268,7 @@ public:
const Microsoft::MSR::CNTK::Matrix<ElemType>& maxValues,
Microsoft::MSR::CNTK::Matrix<ElemType>& CTCPosterior,
const std::shared_ptr<Microsoft::MSR::CNTK::MBLayout> pMBLayout,
size_t blankTokenId,
int delayConstraint = -1)
{
const auto numParallelSequences = pMBLayout->GetNumParallelSequences();
@ -274,28 +276,28 @@ public:
const size_t numRows = prob.GetNumRows();
const size_t numCols = prob.GetNumCols();
m_deviceid = prob.GetDeviceId();
Microsoft::MSR::CNTK::Matrix<ElemType> alpha(m_deviceid);
Microsoft::MSR::CNTK::Matrix<ElemType> beta(m_deviceid);
Microsoft::MSR::CNTK::Matrix<ElemType> rowSum(m_deviceid);
Microsoft::MSR::CNTK::Matrix<ElemType> matrixPhoneSeqs(CPUDEVICE);
Microsoft::MSR::CNTK::Matrix<ElemType> matrixPhoneBounds(CPUDEVICE);
std::vector<std::vector<size_t>> allUttPhoneSeqs;
std::vector<std::vector<size_t>> allUttPhoneBounds;
const size_t maxSizeT = 65535;
int maxPhoneNum = 0;
std::vector<size_t> phoneSeq;
std::vector<size_t> phoneBound;
ElemType finalScore = 0;
const size_t blankid = numRows - 1;
if (blankTokenId == INT_MIN)
blankTokenId = numRows - 1;
size_t mbsize = numCols / numParallelSequences;
// cal gamma for each utterance
// Prepare data structures from the reader
// the position of the first frame of each utterance in the minibatch channel. We need this because each channel may contain more than one utterance.
std::vector<size_t> uttBeginFrame;
// the frame number of each utterance. The size of this vector = the number of all utterances in this minibatch
std::vector<size_t> uttFrameNum;
// the phone number of each utterance. The size of this vector = the number of all utterances in this minibatch
std::vector<size_t> uttPhoneNum;
// map from utterance ID to minibatch channel ID. We need this because each channel may contain more than one utterance.
std::vector<size_t> uttToChanInd;
uttBeginFrame.reserve(numSequences);
uttFrameNum.reserve(numSequences);
@ -304,46 +306,52 @@ public:
size_t seqId = 0;
for (const auto& seq : pMBLayout->GetAllSequences())
{
if (seq.seqId != GAP_SEQUENCE_ID) {
assert(seq.seqId == seqId++);
uttToChanInd.push_back(seq.s);
size_t numFrames = seq.GetNumTimeSteps();
uttBeginFrame.push_back(seq.tBegin);
uttFrameNum.push_back(numFrames);
if (seq.seqId == GAP_SEQUENCE_ID)
continue;
// Get the phone list and boundaries
phoneSeq.clear();
phoneSeq.push_back(maxSizeT);
phoneBound.clear();
phoneBound.push_back(0);
int prevPhoneId = -1;
size_t startFrameInd = seq.tBegin*numParallelSequences + seq.s;
size_t endFrameInd = seq.tEnd*numParallelSequences + seq.s;
size_t frameCounter = 0;
for (auto frameInd = startFrameInd; frameInd < endFrameInd; frameInd += numParallelSequences, frameCounter++) {
if (maxValues(0, frameInd) == 2)
{
prevPhoneId = (size_t)maxIndexes(0, frameInd);
assert(seq.seqId == seqId++);
uttToChanInd.push_back(seq.s);
size_t numFrames = seq.GetNumTimeSteps();
uttBeginFrame.push_back(seq.tBegin);
uttFrameNum.push_back(numFrames);
phoneSeq.push_back(blankid);
phoneBound.push_back(frameCounter);
phoneSeq.push_back(prevPhoneId);
phoneBound.push_back(frameCounter);
}
// Get the phone list and boundaries
phoneSeq.clear();
phoneSeq.push_back(SIZE_MAX);
phoneBound.clear();
phoneBound.push_back(0);
int prevPhoneId = -1;
size_t startFrameInd = seq.tBegin * numParallelSequences + seq.s;
size_t endFrameInd = seq.tEnd * numParallelSequences + seq.s;
size_t frameCounter = 0;
for (auto frameInd = startFrameInd; frameInd < endFrameInd; frameInd += numParallelSequences, frameCounter++)
{
// Labels are represented as 1-hot vectors for each frame
// The 1-hot vectors have either value 1 or 2 at the position of the phone corresponding to the frame:
// 1 means the frame is within phone boundary
// 2 means the frame is the phone boundary
if (maxValues(0, frameInd) == 2)
{
prevPhoneId = (size_t)maxIndexes(0, frameInd);
phoneSeq.push_back(blankTokenId);
phoneBound.push_back(frameCounter);
phoneSeq.push_back(prevPhoneId);
phoneBound.push_back(frameCounter);
}
phoneSeq.push_back(blankid);
phoneBound.push_back(numFrames);
phoneSeq.push_back(maxSizeT);
phoneBound.push_back(numFrames);
allUttPhoneSeqs.push_back(phoneSeq);
allUttPhoneBounds.push_back(phoneBound);
uttPhoneNum.push_back(phoneSeq.size());
if (phoneSeq.size() > maxPhoneNum)
maxPhoneNum = phoneSeq.size();
}
phoneSeq.push_back(blankTokenId);
phoneBound.push_back(numFrames);
phoneSeq.push_back(SIZE_MAX);
phoneBound.push_back(numFrames);
allUttPhoneSeqs.push_back(phoneSeq);
allUttPhoneBounds.push_back(phoneBound);
uttPhoneNum.push_back(phoneSeq.size());
if (phoneSeq.size() > maxPhoneNum)
maxPhoneNum = phoneSeq.size();
}
matrixPhoneSeqs.Resize(maxPhoneNum, numSequences);
@ -356,16 +364,22 @@ public:
matrixPhoneBounds(j, i) = (ElemType)allUttPhoneBounds[i][j];
}
}
// Once these matrices are populated, move them to the active device
matrixPhoneSeqs.TransferFromDeviceToDevice(CPUDEVICE, m_deviceid);
matrixPhoneBounds.TransferFromDeviceToDevice(CPUDEVICE, m_deviceid);
// compute alpha, beta and CTC scores
Microsoft::MSR::CNTK::Matrix<ElemType> alpha(m_deviceid);
Microsoft::MSR::CNTK::Matrix<ElemType> beta(m_deviceid);
CTCPosterior.AssignCTCScore(prob, alpha, beta, matrixPhoneSeqs, matrixPhoneBounds, finalScore, uttToChanInd, uttBeginFrame,
uttFrameNum, uttPhoneNum, numParallelSequences, mbsize, delayConstraint, true);
uttFrameNum, uttPhoneNum, numParallelSequences, mbsize, delayConstraint, /*isColWise=*/true );
Microsoft::MSR::CNTK::Matrix<ElemType> rowSum(m_deviceid);
rowSum.Resize(1, numCols);
// Normalize the CTC scores
CTCPosterior.VectorSum(CTCPosterior, rowSum, true);
CTCPosterior.VectorSum(CTCPosterior, rowSum, /*isColWise=*/true);
CTCPosterior.RowElementDivideBy(rowSum);
totalScore(0, 0) = -finalScore;

Просмотреть файл

@ -14,7 +14,11 @@ BOOST_AUTO_TEST_CASE(ComputeEditDistanceErrorTest)
{
Matrix<float> firstSeq(CPUDEVICE);
Matrix<float> secondSeq(CPUDEVICE);
// Tokens to exclude from the edit-distance computation; left empty for this test.
// (Resolves the merge-conflict markers that were committed here, keeping the
// vector<int> side that matches ComputeEditDistanceError's tokensToIgnore parameter.)
vector<int> tokensToIgnore;
size_t seqSize = 10;
firstSeq.Resize(1, seqSize);
secondSeq.Resize(1, seqSize);
@ -26,7 +30,7 @@ BOOST_AUTO_TEST_CASE(ComputeEditDistanceErrorTest)
MBLayoutPtr pMBLayout = make_shared<MBLayout>(1, seqSize, L"X");
pMBLayout->AddSequence(0, 0, 0, seqSize);
float ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, pMBLayout, 1, 1, 1, true, samplesToIgnore);
float ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, pMBLayout, 1, 1, 1, true, tokensToIgnore);
assert((int)ed == 2);
for (size_t i = 0; i < seqSize; i++)
@ -34,12 +38,12 @@ BOOST_AUTO_TEST_CASE(ComputeEditDistanceErrorTest)
secondSeq(0, i) = (float)i;
}
ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, pMBLayout, 1, 1, 1, true, samplesToIgnore);
ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, pMBLayout, 1, 1, 1, true, tokensToIgnore);
assert((int)ed == 0);
secondSeq(0, seqSize-1) = (float)123;
ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, pMBLayout, 1, 1, 1, true, samplesToIgnore);
ed = EditDistanceErrorNode<float>::ComputeEditDistanceError(firstSeq, secondSeq, pMBLayout, 1, 1, 1, true, tokensToIgnore);
assert((int)ed == 1);
}