зеркало из https://github.com/mozilla/marian.git
Fix batch translation for lstm/mlstm
Batch translation often gave results that were off by a word or two. This was due to the cell-state not being cleared in the encoder's backward pass for the smaller sentences in the batch.
This commit is contained in:
Родитель
9116172903
Коммит
c0e9de9d00
|
@ -92,6 +92,9 @@ class Encoder {
|
||||||
//std::cerr << "mapping=" << mblas::Debug(*mapping) << std::endl;
|
//std::cerr << "mapping=" << mblas::Debug(*mapping) << std::endl;
|
||||||
//mblas::MapMatrix(*(State_.cell), *sentencesMask, n - i - 1);
|
//mblas::MapMatrix(*(State_.cell), *sentencesMask, n - i - 1);
|
||||||
mblas::MapMatrix(*(State_.output), *sentencesMask, n - i - 1);
|
mblas::MapMatrix(*(State_.output), *sentencesMask, n - i - 1);
|
||||||
|
if (State_.cell->size()) {
|
||||||
|
mblas::MapMatrix(*(State_.cell), *sentencesMask, n - i - 1);
|
||||||
|
}
|
||||||
//std::cerr << "2State_=" << State_.Debug(1) << std::endl;
|
//std::cerr << "2State_=" << State_.Debug(1) << std::endl;
|
||||||
|
|
||||||
mblas::PasteRows(Context, *(State_.output), (n - i - 1), gru_->GetStateLength().output);
|
mblas::PasteRows(Context, *(State_.output), (n - i - 1), gru_->GetStateLength().output);
|
||||||
|
|
Загрузка…
Ссылка в новой задаче