temp gpu 2
This commit is contained in:
Родитель
f2c8084510
Коммит
5a2fb887d8
|
@ -4736,10 +4736,7 @@ uttPhoneBeginIdx;
|
|||
SyncGuard syncGuard;
|
||||
_assignRNNTScore(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, uttFrameNum, uttPhoneNum, uttFrameBeginIdx, uttFrameToChanInd,
|
||||
uttBeginForOutputditribution, numParallelSequences, maxPhoneNum, totalPhoneNum, blankTokenId);
|
||||
|
||||
|
||||
_assignCTCScore << < block_tail_2, thread_tail, 0, t_stream >> >(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, gpuUttToChanInd,
|
||||
gpuBeginFrame, gpuPhoneNum, gpuFrameNum, numParallelSequences, maxPhoneNum, totalPhoneNum);
|
||||
|
||||
|
||||
CUDA_CALL(cudaFree(gpuFrameNum));
|
||||
CUDA_CALL(cudaFree(gpuPhoneNum));
|
||||
|
|
|
@ -5824,89 +5824,152 @@ __global__ void _assignRNNTAlphaScore(
|
|||
const int delayConstraint)
|
||||
{
|
||||
typedef typename TypeSelector<ElemType>::comp_t comp_t;
|
||||
int uttid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
int uttId = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
|
||||
// Number of phones and frames in this utterance
|
||||
LONG64 phoneNum = uttPhoneNum[uttId];
|
||||
LONG64 frameNum = uttFrameNum[uttId];
|
||||
size_t frameNum = uttFrameNum[uttId];
|
||||
if (t >= frameNum)
|
||||
continue;
|
||||
|
||||
if (uttId >= uttNum || phoneSeqId >= phoneNum - 1 || t >= frameNum || phoneSeqId == 0)
|
||||
return;
|
||||
size_t phoneNum = uttPhoneNum[uttId];
|
||||
if (u >= phoneNum)
|
||||
continue;
|
||||
|
||||
// Current and previous phone indices in phoneSeq matrix
|
||||
LONG64 labelid = uttId * maxPhoneNum + phoneSeqId;
|
||||
LONG64 labelid_2 = labelid - 2;
|
||||
// Current phone indices in phoneSeq matrix
|
||||
size_t labelid = uttId * maxPhoneNum + u;
|
||||
|
||||
// Actual current phone label
|
||||
LONG64 phoneId = (LONG64)(phoneSeq[labelid]);
|
||||
size_t phoneId = (size_t)(phoneSeq[labelid]); //phone ID of u
|
||||
|
||||
// Index of the current frame in minibatch
|
||||
LONG64 timeId = (t + uttBeginFrame[uttId]) * numChannels + uttToChanInd[uttId];
|
||||
//time Index of the current frame in minibatch
|
||||
size_t timeId = (t + uttBeginFrame[uttId]) * numChannels + uttToChanInd[uttId];
|
||||
|
||||
// Index of probability of observing phoneId at frame timeId
|
||||
LONG64 probId = timeId * totalPhoneNum + phoneId;
|
||||
// phone Index of the current frame in minibatch
|
||||
// size_t unitId = (u + uttBeginPhonePos[uttId]) * numChannels + uttToChanInd[uttId];
|
||||
|
||||
LONG64 alphaId = maxPhoneNum * timeId + phoneSeqId; // alpha_t(s)
|
||||
// Previous time frame
|
||||
LONG64 timeId_1 = timeId - numChannels; // Index corresponding to (t-1)
|
||||
LONG64 alphaId_0 = maxPhoneNum * timeId_1 + phoneSeqId; // alpha_{t-1}(s)
|
||||
LONG64 alphaId_1 = alphaId_0 - 1; // alpha_{t-1}(s-1)
|
||||
LONG64 alphaId_2 = alphaId_0 - 2; // alpha_{t-1}(s-2)
|
||||
//(t,u) index of outputdistribution in minibatch
|
||||
size_t tuID = uttBeginForMerge[uttId] + t * phoneNum + u; //tuID for (t,u)
|
||||
|
||||
if (t == 0)
|
||||
// Index of outputdistribution of observing phoneId at frame timeId
|
||||
//size_t probId = tuID*totalPhoneNum + phoneId;// ID for p(y(u)|t,u)
|
||||
|
||||
//index for alpha
|
||||
size_t alphaId = maxPhoneNum * timeId + u; // alpha_t(s)
|
||||
|
||||
if (t == 0 && u == 0)
|
||||
{
|
||||
// Initialize recursion
|
||||
if (phoneSeqId == 1 || phoneSeqId == 2)
|
||||
{
|
||||
alphaScore[alphaId] = prob[probId];
|
||||
}
|
||||
alphaScore[alphaId] = 0.0;
|
||||
}
|
||||
else if (t == 0)
|
||||
{
|
||||
size_t alphaId_1 = alphaId - 1; // alpha ID for [t,u-1]
|
||||
size_t tuID_1 = tuID - 1; //tuID for [t,u-1]
|
||||
size_t probId_1 = tuID_1 * totalPhoneNum + phoneId; //ID for p(y(u)|t,u-1)
|
||||
|
||||
alphaScore[alphaId] = alphaScore[alphaId_1] + prob[probId_1];
|
||||
}
|
||||
else if (u == 0)
|
||||
{
|
||||
size_t tuID_2 = tuID - phoneNum; //tuID for [t-1,u]
|
||||
size_t alphaId_2 = alphaId - numChannels * maxPhoneNum; //alpha ID for [t-1, u]
|
||||
size_t probId_2 = tuID_2 * totalPhoneNum + blankTokenId; //ID for p(phi|t-1,u)
|
||||
alphaScore[alphaId] = alphaScore[alphaId_2] + prob[probId_2];
|
||||
}
|
||||
else
|
||||
{
|
||||
if (phoneSeqId >= 1)
|
||||
{
|
||||
comp_t x = LZERO;
|
||||
size_t alphaId_1 = alphaId - 1; // alpha ID for [t,u-1]
|
||||
size_t tuID_1 = tuID - 1; //tuID for [t,u-1]
|
||||
size_t probId_1 = tuID_1 * totalPhoneNum + phoneId; //ID for p(y(u)|t,u-1)
|
||||
size_t tuID_2 = tuID - phoneNum; //tuID for [t-1,u]
|
||||
size_t alphaId_2 = alphaId - numChannels * maxPhoneNum; //alpha ID for [t-1, u]
|
||||
size_t probId_2 = tuID_2 * totalPhoneNum + blankTokenId; //ID for p(phi|t-1,u)
|
||||
|
||||
comp_t ascore;
|
||||
if (phoneSeqId > 2)
|
||||
{
|
||||
// if current label is not blank and not equal prev non-blank label
|
||||
if ((LONG64)(phoneSeq[labelid]) != blankTokenId && phoneId != (LONG64)(phoneSeq[labelid_2]))
|
||||
{
|
||||
x = logaddk(x, (comp_t) alphaScore[alphaId_2]);
|
||||
}
|
||||
}
|
||||
ElemType x = LZERO, y = LZERO;
|
||||
x = alphaScore[alphaId_1] + prob[probId_1];
|
||||
y = alphaScore[alphaId_2] + prob[probId_2];
|
||||
|
||||
if (phoneSeqId > 1)
|
||||
{
|
||||
x = logaddk(x, (comp_t) alphaScore[alphaId_1]);
|
||||
}
|
||||
alphaScore[alphaId] = LogAdd(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
x = logaddk(x, (comp_t) alphaScore[alphaId_0]);
|
||||
// Calculate beta in forward-backward calculation for RNNT. equation (18) in "Sequence Transduction with Recurrent Neural Networks"
|
||||
template <class ElemType>
|
||||
__global__ void _assignRNNTBetaScore(
|
||||
const ElemType* prob,
|
||||
ElemType* alphaScore,
|
||||
ElemType* phoneSeq,
|
||||
ElemType* phoneBound,
|
||||
const size_t* uttFrameNum,
|
||||
const size_t* uttPhoneNum,
|
||||
const size_t* uttBeginFrame,
|
||||
const size_t* uttToChanInd,
|
||||
const size_t* uttBeginForMerge,
|
||||
const size_t numChannels,
|
||||
const size_t t,
|
||||
const size_t u,
|
||||
const size_t maxPhoneNum, // Maximum length of utterance in this MB
|
||||
const size_t totalPhoneNum, // Total number of phones
|
||||
const size_t blankTokenId,
|
||||
const int delayConstraint)
|
||||
{
|
||||
typedef typename TypeSelector<ElemType>::comp_t comp_t;
|
||||
int uttId = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
if (phoneId != SIZE_MAX)
|
||||
ascore = prob[probId]; // Probability of observing given label at given time
|
||||
else
|
||||
ascore = 0;
|
||||
alphaScore[alphaId] = x + ascore;
|
||||
if (delayConstraint != -1)
|
||||
{
|
||||
LONG64 labelid_r = labelid + 2;
|
||||
LONG64 phoneBoundId_r = (LONG64)(phoneBound[labelid_r]);
|
||||
if (phoneId == blankTokenId)
|
||||
{
|
||||
// only constraint right side
|
||||
if (t > phoneBoundId_r + delayConstraint - 1)
|
||||
alphaScore[alphaId] = LZERO;
|
||||
}
|
||||
else if (phoneId != blankTokenId)
|
||||
{
|
||||
if (t > phoneBoundId_r + delayConstraint)
|
||||
alphaScore[alphaId] = LZERO;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Number of phones and frames in this utterance
|
||||
size_t frameNum = uttFrameNum[uttId];
|
||||
if (t >= frameNum)
|
||||
continue;
|
||||
|
||||
size_t phoneNum = uttPhoneNum[uttId];
|
||||
if (u >= phoneNum)
|
||||
continue;
|
||||
|
||||
// Current and previous phone indices in phoneSeq matrix
|
||||
size_t labelid = uttId * maxPhoneNum + u;
|
||||
|
||||
// Actual current phone label
|
||||
size_t phoneId = (size_t)(phoneSeq[labelid + 1]); //phone ID of u+1
|
||||
|
||||
// Index of the current frame in minibatch
|
||||
size_t timeId = (t + uttBeginFrame[uttId]) * numChannels + uttToChanInd[uttId]; //timeid in chunk for t
|
||||
|
||||
// Index of the current frame in minibatch
|
||||
// size_t unitId = (u + uttBeginPhonePos[uttId] )* numChannels + uttToChanInd[uttId]; //phoneseq id in chunk for u
|
||||
size_t tuID = uttBeginForMerge[uttId] + t * phoneNum + u; //tuID for (t,u)
|
||||
|
||||
// Index of probability of observing phoneId at frame timeId
|
||||
size_t probId = tuID * totalPhoneNum + phoneId; // ID for p(y(u+1)|t,u)
|
||||
|
||||
size_t betaId = maxPhoneNum * timeId + u; //betaid for (t,u)
|
||||
|
||||
if (u == phoneNum - 1 && t == frameNum - 1)
|
||||
{
|
||||
size_t probId_1 = tuID * totalPhoneNum + blankTokenId; //ID for p(phi|t,u)
|
||||
betaScore[betaId] = prob[probId_1];
|
||||
}
|
||||
else if (u == phoneNum - 1)
|
||||
{
|
||||
size_t probId_1 = tuID * totalPhoneNum + blankTokenId; //ID for p(phi|t,u)
|
||||
size_t betaId_1 = betaId + numChannels * maxPhoneNum; //beta ID for (t+1,u)
|
||||
betaScore[betaId] = betaScore[betaId_1] + prob[probId_1];
|
||||
}
|
||||
else if (t == frameNum - 1)
|
||||
{
|
||||
size_t betaId_2 = betaId + 1; //beid for (t,u+1)
|
||||
betaScore[betaId] = betaScore[betaId_2] + prob[probId];
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t probId_1 = tuID * totalPhoneNum + blankTokenId; //ID for p(phi|t,u)
|
||||
size_t betaId_1 = betaId + numChannels * maxPhoneNum; //beta ID for (t+1,u)
|
||||
size_t betaId_2 = betaId + 1; //beta ID for (t,u+1)
|
||||
|
||||
ElemType x = LZERO, y = LZERO;
|
||||
x = betaScore[betaId_1] + prob[probId_1];
|
||||
y = betaScore[betaId_2] + prob[probId];
|
||||
|
||||
betaScore[betaId] = LogAdd(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -358,7 +358,7 @@ void CheckEnumValuesNotModified() {
|
|||
static_cast<size_t>(PrimitiveOpType::Tan) == 95 &&
|
||||
static_cast<size_t>(PrimitiveOpType::Atan) == 96 &&
|
||||
static_cast<size_t>(PrimitiveOpType::ConvolutionSequenceShape) == 97 &&
|
||||
static_cast<size_t>(PrimitiveOpType::RNNT) == 98,
|
||||
static_cast<size_t>(PrimitiveOpType::RNNT) == 98 &&
|
||||
static_cast<size_t>(PrimitiveOpType::PlusBroadcast) == 99,
|
||||
"PrimitiveOpType enum value was modified.");
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче