diff --git a/src/models/hardatt.h b/src/models/hardatt.h index 546b1206..81aecd22 100755 --- a/src/models/hardatt.h +++ b/src/models/hardatt.h @@ -27,7 +27,7 @@ public: for(auto i : selIdx) selectedAttentionIndices.push_back(attentionIndices_[i]); - auto selectedState = New(states_.select(selIdx, beamSize), + auto selectedState = New(states_.select(selIdx, beamSize, /*isBatchMajor=*/false), probs_, encStates_, batch_); @@ -235,12 +235,13 @@ public: logits = out->apply(rnnInputs, decContext); } - auto newState = New(decStates, + auto nextState = New(decStates, logits, stateHardAtt->getEncoderStates(), stateHardAtt->getBatch()); - newState->setAttentionIndices(std::vector(stateHardAtt->getAttentionIndices())); - return newState; + nextState->setAttentionIndices(std::vector(stateHardAtt->getAttentionIndices())); + nextState->setPosition(state->getPosition() + 1); // @TODO: I added this for consistency. Correct? + return nextState; } const std::vector getAlignments() { diff --git a/src/models/states.h b/src/models/states.h index 8b30075a..c7103466 100755 --- a/src/models/states.h +++ b/src/models/states.h @@ -59,7 +59,7 @@ public: virtual Ptr selectHyps(const std::vector& selIdx, int beamSize) const { auto selectedState = New( - states_.select(selIdx, beamSize), probs_, encStates_, batch_); + states_.select(selIdx, beamSize, /*isBatchMajor=*/false), probs_, encStates_, batch_); // Set positon of new state based on the target token position of current // state diff --git a/src/models/transformer.h b/src/models/transformer.h index a1790433..ca498643 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -548,25 +548,8 @@ public: virtual Ptr selectHyps(const std::vector& selIdx, int beamSize) const override { - // @TODO: merge the reordering bits with base DecoderState::select() - int dimDepth = states_[0].output->shape()[-1]; - int dimTime = states_[0].output->shape()[-2]; - int dimBatch = selIdx.size() / beamSize; - - std::vector selIdx2; - for(auto i : 
selIdx) - for(int j = 0; j < dimTime; ++j) - selIdx2.push_back(i * dimTime + j); - - rnn::States selectedStates; - for(const auto& state : states_) { - auto sel = rows(flatten_2d(state.output), selIdx2); - sel = reshape(sel, {beamSize, dimBatch, dimTime, dimDepth}); - selectedStates.push_back({sel, nullptr}); - } - // Create hypothesis-selected state based on current state and hyp indices - auto selectedState = New(selectedStates, probs_, encStates_, batch_); + auto selectedState = New(states_.select(selIdx, beamSize, /*isBatchMajor=*/true), probs_, encStates_, batch_); // Set the same target token position as the current state // @TODO: This is the same as in base function. diff --git a/src/rnn/types.h b/src/rnn/types.h index d20fa064..b90b6cca 100755 --- a/src/rnn/types.h +++ b/src/rnn/types.h @@ -12,28 +12,42 @@ struct State { Expr output; Expr cell; - // @TODO: This version only for time-major. Add flag so that this can be shared. - State select(const std::vector& indices, // [beamIndex * activeBatchSize + batchIndex] - int beamSize) const { - auto selectedOutput = output; // [beamSize, dimTime, dimBatch, dimDepth] (dimTime = 1 for RNN) - auto selectedCell = cell; // [beamSize, dimTime, dimBatch, dimDepth] + State select(const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] + int beamSize, bool isBatchMajor) const { + auto selectedOutput = output; // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN) + auto selectedCell = cell; // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] selectedOutput = atleast_4d(selectedOutput); + int dimBatch = selIdx.size() / beamSize; int dimDepth = selectedOutput->shape()[-1]; - int dimTime = selectedOutput->shape()[-3]; + int dimTime = isBatchMajor ? 
selectedOutput->shape()[-2] : selectedOutput->shape()[-3]; - int dimBatch = indices.size() / beamSize; + if (isBatchMajor) { + // @TODO: I think this can be done more efficiently by not using flatten_2d(), but instead merging dimTime with dimDepth + std::vector selIdx2; + for (auto i : selIdx) + for (int j = 0; j < dimTime; ++j) + selIdx2.push_back(i * dimTime + j); - selectedOutput = reshape(rows(flatten_2d(selectedOutput), indices), - { beamSize, dimTime, dimBatch, dimDepth }); - if (selectedCell) - { - selectedCell = atleast_4d(selectedCell); - selectedCell = reshape(rows(flatten_2d(selectedCell), indices), - { beamSize, dimTime, dimBatch, dimDepth }); + selectedOutput = flatten_2d(selectedOutput); + selectedOutput = rows(selectedOutput, selIdx2); + selectedOutput = reshape(selectedOutput, { beamSize, dimBatch, dimTime, dimDepth }); + ABORT_IF(selectedCell, "selectedCell must be null for Transformer"); + } else { + ABORT_IF(dimTime != 1, "unexpected time extent for RNN state"); + selectedOutput = flatten_2d(selectedOutput); + selectedOutput = rows(selectedOutput, selIdx); + selectedOutput = reshape(selectedOutput, { beamSize, dimTime, dimBatch, dimDepth }); + if (selectedCell) + { + selectedCell = atleast_4d(selectedCell); + selectedCell = flatten_2d(selectedCell); + selectedCell = rows(selectedCell, selIdx); + selectedCell = reshape(selectedCell, { beamSize, dimTime, dimBatch, dimDepth });
+ } } - return { selectedOutput, selectedCell }; + return { selectedOutput, selectedCell }; } }; @@ -75,11 +89,11 @@ public: void push_back(const State& state) { states_.push_back(state); } // create updated set of states that reflect reordering and dropping of hypotheses - States select(const std::vector& indices, // [beamIndex * activeBatchSize + batchIndex] - int beamSize) const { + States select(const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] + int beamSize, bool isBatchMajor) const { States selected; for(auto& state : states_) - selected.push_back(state.select(indices, beamSize)); + selected.push_back(state.select(selIdx, beamSize, isBatchMajor)); return selected; }