Mirror of https://github.com/mozilla/marian.git

refactored batch-major select() down to State::select() where it belongs

Parent: e723425804
Commit: 7a7e5ab864

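This commit routes all hypothesis re-selection through a single rnn::States::select() call that now takes an explicit layout flag: the hard-attention and base decoder states pass false (time-major), while TransformerState passes true (batch-major) and drops its own reordering code. A minimal sketch of the resulting call pattern, assembled from the hunks below (only names that actually appear in the diff; everything else omitted):

    // Time-major RNN-style decoder states (dimTime = 1 per step):
    auto selectedRnn = states_.select(selIdx, beamSize, /*isBatchMajor=*/false);

    // Batch-major Transformer states accumulated over dimTime decoded steps:
    auto selectedTrf = states_.select(selIdx, beamSize, /*isBatchMajor=*/true);
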
@@ -27,7 +27,7 @@ public:
     for(auto i : selIdx)
       selectedAttentionIndices.push_back(attentionIndices_[i]);
 
-    auto selectedState = New<DecoderStateHardAtt>(states_.select(selIdx, beamSize),
+    auto selectedState = New<DecoderStateHardAtt>(states_.select(selIdx, beamSize, /*isBatchMajor=*/false),
                                                   probs_,
                                                   encStates_,
                                                   batch_);

@@ -235,12 +235,13 @@ public:
       logits = out->apply(rnnInputs, decContext);
     }
 
-    auto newState = New<DecoderStateHardAtt>(decStates,
+    auto nextState = New<DecoderStateHardAtt>(decStates,
                                              logits,
                                              stateHardAtt->getEncoderStates(),
                                              stateHardAtt->getBatch());
-    newState->setAttentionIndices(std::vector<size_t>(stateHardAtt->getAttentionIndices()));
-    return newState;
+    nextState->setAttentionIndices(std::vector<size_t>(stateHardAtt->getAttentionIndices()));
+    nextState->setPosition(state->getPosition() + 1); // @TODO: I added this for consistency. Correct?
+    return nextState;
   }
 
   const std::vector<Expr> getAlignments() {

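The added setPosition() call advances the target-token position by one per decoder step; compare the position-propagation comment in the base DecoderState::selectHyps() in the next hunk.
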
@@ -59,7 +59,7 @@ public:
   virtual Ptr<DecoderState> selectHyps(const std::vector<size_t>& selIdx,
                                        int beamSize) const {
     auto selectedState = New<DecoderState>(
-        states_.select(selIdx, beamSize), probs_, encStates_, batch_);
+        states_.select(selIdx, beamSize, /*isBatchMajor=*/false), probs_, encStates_, batch_);
 
     // Set positon of new state based on the target token position of current
     // state

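As documented on the select() signature further down, the entries of selIdx are flat hypothesis indices of the form beamIndex * activeBatchSize + batchIndex: beam slot 2 of sentence 1 in an active batch of 3 sentences is addressed as 2 * 3 + 1 = 7, and there is one entry per surviving hypothesis, which is why the code recovers dimBatch as selIdx.size() / beamSize.
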
@@ -548,25 +548,8 @@ public:
 
   virtual Ptr<DecoderState> selectHyps(const std::vector<size_t>& selIdx,
                                        int beamSize) const override {
-    // @TODO: merge the reordering bits with base DecoderState::select()
-    int dimDepth = states_[0].output->shape()[-1];
-    int dimTime = states_[0].output->shape()[-2];
-    int dimBatch = selIdx.size() / beamSize;
-
-    std::vector<size_t> selIdx2;
-    for(auto i : selIdx)
-      for(int j = 0; j < dimTime; ++j)
-        selIdx2.push_back(i * dimTime + j);
-
-    rnn::States selectedStates;
-    for(const auto& state : states_) {
-      auto sel = rows(flatten_2d(state.output), selIdx2);
-      sel = reshape(sel, {beamSize, dimBatch, dimTime, dimDepth});
-      selectedStates.push_back({sel, nullptr});
-    }
-
     // Create hypothesis-selected state based on current state and hyp indices
-    auto selectedState = New<TransformerState>(selectedStates, probs_, encStates_, batch_);
+    auto selectedState = New<TransformerState>(states_.select(selIdx, beamSize, /*isBatchMajor=*/true), probs_, encStates_, batch_);
 
     // Set the same target token position as the current state
     // @TODO: This is the same as in base function.

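The removed reordering loop survives as the isBatchMajor branch of State::select() in the next hunk. For concreteness (values chosen purely for illustration): with dimTime = 3 and selIdx = {2, 0}, the expansion yields selIdx2 = {6, 7, 8, 0, 1, 2}, i.e. each selected hypothesis drags along all dimTime rows of its cached state.
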
@@ -12,28 +12,42 @@ struct State {
   Expr output;
   Expr cell;
 
-  // @TODO: This version only for time-major. Add flag so that this can be shared.
-  State select(const std::vector<size_t>& indices, // [beamIndex * activeBatchSize + batchIndex]
-               int beamSize) const {
-    auto selectedOutput = output; // [beamSize, dimTime, dimBatch, dimDepth] (dimTime = 1 for RNN)
-    auto selectedCell = cell;     // [beamSize, dimTime, dimBatch, dimDepth]
+  State select(const std::vector<size_t>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
+               int beamSize, bool isBatchMajor) const {
+    auto selectedOutput = output; // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN)
+    auto selectedCell = cell;     // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth]
 
     selectedOutput = atleast_4d(selectedOutput);
 
+    int dimBatch = selIdx.size() / beamSize;
     int dimDepth = selectedOutput->shape()[-1];
-    int dimTime = selectedOutput->shape()[-3];
+    int dimTime = isBatchMajor ? selectedOutput->shape()[-2] : selectedOutput->shape()[-3];
 
-    int dimBatch = indices.size() / beamSize;
-
-    selectedOutput = reshape(rows(flatten_2d(selectedOutput), indices),
-                             { beamSize, dimTime, dimBatch, dimDepth });
-    if (selectedCell)
-    {
-      selectedCell = atleast_4d(selectedCell);
-      selectedCell = reshape(rows(flatten_2d(selectedCell), indices),
-                             { beamSize, dimTime, dimBatch, dimDepth });
-    }
-    return { selectedOutput, selectedCell };
+    if (isBatchMajor) {
+      // @TODO: I think this can be done more efficiently by not using flatten_2d(), but instead merging dimTime with dimDepth
+      std::vector<size_t> selIdx2;
+      for (auto i : selIdx)
+        for (int j = 0; j < dimTime; ++j)
+          selIdx2.push_back(i * dimTime + j);
+
+      selectedOutput = flatten_2d(selectedOutput);
+      selectedOutput = rows(selectedOutput, selIdx2);
+      selectedOutput = reshape(selectedOutput, { beamSize, isBatchMajor ? dimBatch : dimTime, isBatchMajor ? dimTime : dimBatch, dimDepth });
+      ABORT_IF(selectedCell, "selectedCell must be null for Transformer");
+    } else {
+      ABORT_IF(dimTime != 1, "unexpected time extent for RNN state");
+      selectedOutput = flatten_2d(selectedOutput);
+      selectedOutput = rows(selectedOutput, selIdx);
+      selectedOutput = reshape(selectedOutput, { beamSize, isBatchMajor ? dimBatch : dimTime, isBatchMajor ? dimTime : dimBatch, dimDepth });
+      if (selectedCell)
+      {
+        selectedCell = atleast_4d(selectedCell);
+        selectedCell = flatten_2d(selectedCell);
+        selectedCell = rows(selectedCell, selIdx);
+        selectedCell = reshape(selectedCell, { beamSize, isBatchMajor ? dimBatch : dimTime, isBatchMajor ? dimTime : dimBatch, dimDepth });
+      }
+    }
+    return{ selectedOutput, selectedCell };
   }
 };

@@ -75,11 +89,11 @@ public:
   void push_back(const State& state) { states_.push_back(state); }
 
   // create updated set of states that reflect reordering and dropping of hypotheses
-  States select(const std::vector<size_t>& indices, // [beamIndex * activeBatchSize + batchIndex]
-                int beamSize) const {
+  States select(const std::vector<size_t>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
+                int beamSize, bool isBatchMajor) const {
     States selected;
     for(auto& state : states_)
-      selected.push_back(state.select(indices, beamSize));
+      selected.push_back(state.select(selIdx, beamSize, isBatchMajor));
     return selected;
   }

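For readers unfamiliar with the expression helpers used above, here is a standalone sketch (plain C++ on std::vector, not the Marian API; shapes taken from the comments in the diff) of what flatten_2d() + rows() + reshape() accomplish in State::select(): the state tensor is viewed as a matrix with dimDepth columns, whole rows are gathered per hypothesis, and in the batch-major case each hypothesis contributes dimTime consecutive rows.

    // Illustration only; not Marian code. Assumes the layouts documented in the
    // diff: a batch-major state is [beamSize, dimBatch, dimTime, dimDepth], so,
    // viewed as a [beamSize*dimBatch*dimTime, dimDepth] matrix, hypothesis i owns
    // the dimTime consecutive rows starting at i * dimTime.
    #include <cstddef>
    #include <vector>

    // Gather whole rows of a row-major [nRows, dimDepth] matrix (the effect of rows()).
    std::vector<float> selectRows(const std::vector<float>& flat,
                                  const std::vector<size_t>& rowIdx,
                                  size_t dimDepth) {
      std::vector<float> out;
      out.reserve(rowIdx.size() * dimDepth);
      for(size_t r : rowIdx)
        out.insert(out.end(), flat.begin() + r * dimDepth, flat.begin() + (r + 1) * dimDepth);
      return out;
    }

    // Mirror of the selIdx2 loop: expand each selected hypothesis to all its time steps.
    std::vector<size_t> expandBatchMajor(const std::vector<size_t>& selIdx, size_t dimTime) {
      std::vector<size_t> selIdx2;
      selIdx2.reserve(selIdx.size() * dimTime);
      for(size_t i : selIdx)
        for(size_t j = 0; j < dimTime; ++j)
          selIdx2.push_back(i * dimTime + j);
      return selIdx2;
    }

    // Time-major RNN states have dimTime == 1, so selIdx is used directly:
    //   auto picked = selectRows(flatState, selIdx, dimDepth);
    // Batch-major Transformer states expand the indices first:
    //   auto picked = selectRows(flatState, expandBatchMajor(selIdx, dimTime), dimDepth);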