Fix batch translation for lstm/mlstm

Batch translation often gave results that were off by a word or two.
This was due to the cell-state not being cleared in the encoder's
backward pass for the smaller sentences in the batch.
This commit is contained in:
Rihards Krišlauks 2017-11-16 10:10:01 +02:00
Родитель 9116172903
Коммит c0e9de9d00
1 изменённых файлов: 3 добавлений и 0 удалений

Просмотреть файл

@ -92,6 +92,9 @@ class Encoder {
//std::cerr << "mapping=" << mblas::Debug(*mapping) << std::endl; //std::cerr << "mapping=" << mblas::Debug(*mapping) << std::endl;
//mblas::MapMatrix(*(State_.cell), *sentencesMask, n - i - 1); //mblas::MapMatrix(*(State_.cell), *sentencesMask, n - i - 1);
mblas::MapMatrix(*(State_.output), *sentencesMask, n - i - 1); mblas::MapMatrix(*(State_.output), *sentencesMask, n - i - 1);
if (State_.cell->size()) {
mblas::MapMatrix(*(State_.cell), *sentencesMask, n - i - 1);
}
//std::cerr << "2State_=" << State_.Debug(1) << std::endl; //std::cerr << "2State_=" << State_.Debug(1) << std::endl;
mblas::PasteRows(Context, *(State_.output), (n - i - 1), gru_->GetStateLength().output); mblas::PasteRows(Context, *(State_.output), (n - i - 1), gru_->GetStateLength().output);