This commit is contained in:
Rui Zhao (SPEECH) 2018-12-19 23:15:41 -08:00
Родитель f2c8084510
Коммит 5a2fb887d8
3 изменённых файлов: 129 добавлений и 69 удалений

Просмотреть файл

@ -4736,10 +4736,7 @@ uttPhoneBeginIdx;
SyncGuard syncGuard;
_assignRNNTScore(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttFrameToChanInd,
uttBeginForOutputditribution, numParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId);
_assignCTCScore << < block_tail_2, thread_tail, 0, t_stream >> >(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, gpuUttToChanInd,
gpuBeginFrame, gpuPhoneNum, gpuFrameNum, numParallelSequences, maxPhoneNum, totalPhoneNum);
CUDA_CALL(cudaFree(gpuFrameNum));
CUDA_CALL(cudaFree(gpuPhoneNum));

Просмотреть файл

@ -5824,89 +5824,152 @@ __global__ void _assignRNNTAlphaScore(
const int delayConstraint)
{
typedef typename TypeSelector<ElemType>::comp_t comp_t;
int uttid = blockDim.x * blockIdx.x + threadIdx.x;
int uttId = blockDim.x * blockIdx.x + threadIdx.x;
// Number of phones and frames in this utterance
LONG64 phoneNum = uttPhoneNum[uttId];
LONG64 frameNum = uttFrameNum[uttId];
size_t frameNum = uttFrameNum[uttId];
if (t >= frameNum)
continue;
if (uttId >= uttNum || phoneSeqId >= phoneNum - 1 || t >= frameNum || phoneSeqId == 0)
return;
size_t phoneNum = uttPhoneNum[uttId];
if (u >= phoneNum)
continue;
// Current and previous phone indices in phoneSeq matrix
LONG64 labelid = uttId * maxPhoneNum + phoneSeqId;
LONG64 labelid_2 = labelid - 2;
// Current phone indices in phoneSeq matrix
size_t labelid = uttId * maxPhoneNum + u;
// Actual current phone label
LONG64 phoneId = (LONG64)(phoneSeq[labelid]);
size_t phoneId = (size_t)(phoneSeq[labelid]); //phone ID of u
// Index of the current frame in minibatch
LONG64 timeId = (t + uttBeginFrame[uttId]) * numChannels + uttToChanInd[uttId];
//time Index of the current frame in minibatch
size_t timeId = (t + uttBeginFrame[uttId]) * numChannels + uttToChanInd[uttId];
// Index of probability of observing phoneId at frame timeId
LONG64 probId = timeId * totalPhoneNum + phoneId;
// phone Index of the current frame in minibatch
// size_t unitId = (u + uttBeginPhonePos[uttId]) * numChannels + uttToChanInd[uttId];
LONG64 alphaId = maxPhoneNum * timeId + phoneSeqId; // alpha_t(s)
// Previous time frame
LONG64 timeId_1 = timeId - numChannels; // Index corresponding to (t-1)
LONG64 alphaId_0 = maxPhoneNum * timeId_1 + phoneSeqId; // alpha_{t-1}(s)
LONG64 alphaId_1 = alphaId_0 - 1; // alpha_{t-1}(s-1)
LONG64 alphaId_2 = alphaId_0 - 2; // alpha_{t-1}(s-2)
//(t,u) index of outputdistribution in minibatch
size_t tuID = uttBeginForMerge[uttId] + t * phoneNum + u; //tuID for (t,u)
if (t == 0)
// Index of outputdistribution of observing phoneId at frame timeId
//size_t probId = tuID*totalPhoneNum + phoneId;// ID for p(y(u)|t,u)
//index for alpha
size_t alphaId = maxPhoneNum * timeId + u; // alpha_t(s)
if (t == 0 && u == 0)
{
// Initialize recursion
if (phoneSeqId == 1 || phoneSeqId == 2)
{
alphaScore[alphaId] = prob[probId];
}
alphaScore[alphaId] = 0.0;
}
else if (t == 0)
{
size_t alphaId_1 = alphaId - 1; // alpha ID for [t,u-1]
size_t tuID_1 = tuID - 1; //tuID for [t,u-1]
size_t probId_1 = tuID_1 * totalPhoneNum + phoneId; //ID for p(y(u)|t,u-1)
alphaScore[alphaId] = alphaScore[alphaId_1] + prob[probId_1];
}
else if (u == 0)
{
size_t tuID_2 = tuID - phoneNum; //tuID for [t-1,u]
size_t alphaId_2 = alphaId - numChannels * maxPhoneNum; //alpha ID for [t-1, u]
size_t probId_2 = tuID_2 * totalPhoneNum + blankTokenId; //ID for p(phi|t-1,u)
alphaScore[alphaId] = alphaScore[alphaId_2] + prob[probId_2];
}
else
{
if (phoneSeqId >= 1)
{
comp_t x = LZERO;
size_t alphaId_1 = alphaId - 1; // alpha ID for [t,u-1]
size_t tuID_1 = tuID - 1; //tuID for [t,u-1]
size_t probId_1 = tuID_1 * totalPhoneNum + phoneId; //ID for p(y(u)|t,u-1)
size_t tuID_2 = tuID - phoneNum; //tuID for [t-1,u]
size_t alphaId_2 = alphaId - numChannels * maxPhoneNum; //alpha ID for [t-1, u]
size_t probId_2 = tuID_2 * totalPhoneNum + blankTokenId; //ID for p(phi|t-1,u)
comp_t ascore;
if (phoneSeqId > 2)
{
// if current label is not blank and not equal prev non-blank label
if ((LONG64)(phoneSeq[labelid]) != blankTokenId && phoneId != (LONG64)(phoneSeq[labelid_2]))
{
x = logaddk(x, (comp_t) alphaScore[alphaId_2]);
}
}
ElemType x = LZERO, y = LZERO;
x = alphaScore[alphaId_1] + prob[probId_1];
y = alphaScore[alphaId_2] + prob[probId_2];
if (phoneSeqId > 1)
{
x = logaddk(x, (comp_t) alphaScore[alphaId_1]);
}
alphaScore[alphaId] = LogAdd(x, y);
}
}
x = logaddk(x, (comp_t) alphaScore[alphaId_0]);
// Calculate beta in forward-backward calculation for RNNT. equation (18) in "Sequence Transduction with Recurrent Neural Networks"
template <class ElemType>
__global__ void _assignRNNTBetaScore(
const ElemType* prob,
ElemType* alphaScore,
ElemType* phoneSeq,
ElemType* phoneBound,
const size_t* uttFrameNum,
const size_t* uttPhoneNum,
const size_t* uttBeginFrame,
const size_t* uttToChanInd,
const size_t* uttBeginForMerge,
const size_t numChannels,
const size_t t,
const size_t u,
const size_t maxPhoneNum, // Maximum length of utterance in this MB
const size_t totalPhoneNum, // Total number of phones
const size_t blankTokenId,
const int delayConstraint)
{
typedef typename TypeSelector<ElemType>::comp_t comp_t;
int uttId = blockDim.x * blockIdx.x + threadIdx.x;
if (phoneId != SIZE_MAX)
ascore = prob[probId]; // Probability of observing given label at given time
else
ascore = 0;
alphaScore[alphaId] = x + ascore;
if (delayConstraint != -1)
{
LONG64 labelid_r = labelid + 2;
LONG64 phoneBoundId_r = (LONG64)(phoneBound[labelid_r]);
if (phoneId == blankTokenId)
{
// only constraint right side
if (t > phoneBoundId_r + delayConstraint - 1)
alphaScore[alphaId] = LZERO;
}
else if (phoneId != blankTokenId)
{
if (t > phoneBoundId_r + delayConstraint)
alphaScore[alphaId] = LZERO;
}
}
}
// Number of phones and frames in this utterance
size_t frameNum = uttFrameNum[uttId];
if (t >= frameNum)
continue;
size_t phoneNum = uttPhoneNum[uttId];
if (u >= phoneNum)
continue;
// Current and previous phone indices in phoneSeq matrix
size_t labelid = uttId * maxPhoneNum + u;
// Actual current phone label
size_t phoneId = (size_t)(phoneSeq[labelid + 1]); //phone ID of u+1
// Index of the current frame in minibatch
size_t timeId = (t + uttBeginFrame[uttId]) * numChannels + uttToChanInd[uttId]; //timeid in chunk for t
// Index of the current frame in minibatch
// size_t unitId = (u + uttBeginPhonePos[uttId] )* numChannels + uttToChanInd[uttId]; //phoneseq id in chunk for u
size_t tuID = uttBeginForMerge[uttId] + t * phoneNum + u; //tuID for (t,u)
// Index of probability of observing phoneId at frame timeId
size_t probId = tuID * totalPhoneNum + phoneId; // ID for p(y(u+1)|t,u)
size_t betaId = maxPhoneNum * timeId + u; //betaid for (t,u)
if (u == phoneNum - 1 && t == frameNum - 1)
{
size_t probId_1 = tuID * totalPhoneNum + blankTokenId; //ID for p(phi|t,u)
betaScore[betaId] = prob[probId_1];
}
else if (u == phoneNum - 1)
{
size_t probId_1 = tuID * totalPhoneNum + blankTokenId; //ID for p(phi|t,u)
size_t betaId_1 = betaId + numChannels * maxPhoneNum; //beta ID for (t+1,u)
betaScore[betaId] = betaScore[betaId_1] + prob[probId_1];
}
else if (t == frameNum - 1)
{
size_t betaId_2 = betaId + 1; //beid for (t,u+1)
betaScore[betaId] = betaScore[betaId_2] + prob[probId];
}
else
{
size_t probId_1 = tuID * totalPhoneNum + blankTokenId; //ID for p(phi|t,u)
size_t betaId_1 = betaId + numChannels * maxPhoneNum; //beta ID for (t+1,u)
size_t betaId_2 = betaId + 1; //beta ID for (t,u+1)
ElemType x = LZERO, y = LZERO;
x = betaScore[betaId_1] + prob[probId_1];
y = betaScore[betaId_2] + prob[probId];
betaScore[betaId] = LogAdd(x, y);
}
}

Просмотреть файл

@ -358,7 +358,7 @@ void CheckEnumValuesNotModified() {
static_cast<size_t>(PrimitiveOpType::Tan) == 95 &&
static_cast<size_t>(PrimitiveOpType::Atan) == 96 &&
static_cast<size_t>(PrimitiveOpType::ConvolutionSequenceShape) == 97 &&
static_cast<size_t>(PrimitiveOpType::RNNT) == 98,
static_cast<size_t>(PrimitiveOpType::RNNT) == 98 &&
static_cast<size_t>(PrimitiveOpType::PlusBroadcast) == 99,
"PrimitiveOpType enum value was modified.");
}