refactored batch-major select() down to State::select() where it belongs

This commit is contained in:
Frank Seide 2018-08-20 11:55:20 -07:00
Parent e723425804
Commit 7a7e5ab864
4 changed files with 39 additions and 41 deletions

View file

@ -27,7 +27,7 @@ public:
for(auto i : selIdx)
selectedAttentionIndices.push_back(attentionIndices_[i]);
auto selectedState = New<DecoderStateHardAtt>(states_.select(selIdx, beamSize),
auto selectedState = New<DecoderStateHardAtt>(states_.select(selIdx, beamSize, /*isBatchMajor=*/false),
probs_,
encStates_,
batch_);
@ -235,12 +235,13 @@ public:
logits = out->apply(rnnInputs, decContext);
}
auto newState = New<DecoderStateHardAtt>(decStates,
auto nextState = New<DecoderStateHardAtt>(decStates,
logits,
stateHardAtt->getEncoderStates(),
stateHardAtt->getBatch());
newState->setAttentionIndices(std::vector<size_t>(stateHardAtt->getAttentionIndices()));
return newState;
nextState->setAttentionIndices(std::vector<size_t>(stateHardAtt->getAttentionIndices()));
nextState->setPosition(state->getPosition() + 1); // @TODO: I added this for consistency. Correct?
return nextState;
}
const std::vector<Expr> getAlignments() {

View file

@ -59,7 +59,7 @@ public:
virtual Ptr<DecoderState> selectHyps(const std::vector<size_t>& selIdx,
int beamSize) const {
auto selectedState = New<DecoderState>(
states_.select(selIdx, beamSize), probs_, encStates_, batch_);
states_.select(selIdx, beamSize, /*isBatchMajor=*/false), probs_, encStates_, batch_);
// Set positon of new state based on the target token position of current
// state

View file

@ -548,25 +548,8 @@ public:
virtual Ptr<DecoderState> selectHyps(const std::vector<size_t>& selIdx,
int beamSize) const override {
// @TODO: merge the reordering bits with base DecoderState::select()
int dimDepth = states_[0].output->shape()[-1];
int dimTime = states_[0].output->shape()[-2];
int dimBatch = selIdx.size() / beamSize;
std::vector<size_t> selIdx2;
for(auto i : selIdx)
for(int j = 0; j < dimTime; ++j)
selIdx2.push_back(i * dimTime + j);
rnn::States selectedStates;
for(const auto& state : states_) {
auto sel = rows(flatten_2d(state.output), selIdx2);
sel = reshape(sel, {beamSize, dimBatch, dimTime, dimDepth});
selectedStates.push_back({sel, nullptr});
}
// Create hypothesis-selected state based on current state and hyp indices
auto selectedState = New<TransformerState>(selectedStates, probs_, encStates_, batch_);
auto selectedState = New<TransformerState>(states_.select(selIdx, beamSize, /*isBatchMajor=*/true), probs_, encStates_, batch_);
// Set the same target token position as the current state
// @TODO: This is the same as in base function.

View file

@ -12,28 +12,42 @@ struct State {
Expr output;
Expr cell;
// Re-orders (and possibly drops) hypotheses within this state according to
// selIdx, as produced by beam search after pruning.
//   selIdx[i] = beamIndex * activeBatchSize + batchIndex for each surviving hypothesis
//   beamSize    : number of hypotheses kept per batch entry
//   isBatchMajor: tensor layout flag —
//     false (time-major, RNN):         output is [beamSize, dimTime, dimBatch, dimDepth], dimTime == 1
//     true  (batch-major, Transformer): output is [beamSize, dimBatch, dimTime, dimDepth]
// Returns a new State holding the selected output (and cell, time-major only).
State select(const std::vector<size_t>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
             int beamSize, bool isBatchMajor) const {
  auto selectedOutput = atleast_4d(output);
  auto selectedCell = cell;
  int dimBatch = (int)selIdx.size() / beamSize;
  int dimDepth = selectedOutput->shape()[-1];
  // the time axis sits in a different position depending on the layout
  int dimTime = isBatchMajor ? selectedOutput->shape()[-2] : selectedOutput->shape()[-3];
  if (isBatchMajor) {
    // Batch-major states carry a full time axis; expand each hypothesis index
    // into dimTime consecutive row indices of the flattened tensor.
    // @TODO: I think this can be done more efficiently by not using flatten_2d(), but instead merging dimTime with dimDepth
    std::vector<size_t> selIdx2;
    selIdx2.reserve(selIdx.size() * (size_t)dimTime);
    for (auto i : selIdx)
      for (int j = 0; j < dimTime; ++j)
        selIdx2.push_back(i * dimTime + j);
    selectedOutput = reshape(rows(flatten_2d(selectedOutput), selIdx2),
                             { beamSize, dimBatch, dimTime, dimDepth });
    // Transformer states have no cell; presumably guaranteed by the decoder — verify
    ABORT_IF(selectedCell, "selectedCell must be null for Transformer");
  } else {
    ABORT_IF(dimTime != 1, "unexpected time extent for RNN state");
    // dimTime == 1, so one row per hypothesis: select rows directly
    selectedOutput = reshape(rows(flatten_2d(selectedOutput), selIdx),
                             { beamSize, dimTime, dimBatch, dimDepth });
    if (selectedCell) {
      selectedCell = atleast_4d(selectedCell);
      selectedCell = reshape(rows(flatten_2d(selectedCell), selIdx),
                             { beamSize, dimTime, dimBatch, dimDepth });
    }
  }
  return { selectedOutput, selectedCell };
}
};
@ -75,11 +89,11 @@ public:
void push_back(const State& state) { states_.push_back(state); }
// Creates an updated set of states that reflects re-ordering and dropping of
// hypotheses after a beam-search step.
//   selIdx[i] = beamIndex * activeBatchSize + batchIndex per surviving hypothesis
//   beamSize    : number of hypotheses kept per batch entry
//   isBatchMajor: layout flag, forwarded to State::select()
//     (true for batch-major Transformer states, false for time-major RNN states)
States select(const std::vector<size_t>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
              int beamSize, bool isBatchMajor) const {
  States selected;
  for (const auto& state : states_)
    selected.push_back(state.select(selIdx, beamSize, isBatchMajor));
  return selected;
}