Unify soft alignments in decoder and scorer

This commit is contained in:
Roman Grundkiewicz 2018-08-17 10:49:48 +01:00
Родитель 883ad0b97b
Коммит cc8cc4a31d
4 изменённых файлов: 14 добавлений и 22 удалений

Просмотреть файл

@ -34,7 +34,6 @@ std::string WordAlignment::toString() const {
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
float threshold /*= 1.f*/, float threshold /*= 1.f*/,
bool reversed /*= true*/,
bool skipEOS /*= false*/) { bool skipEOS /*= false*/) {
size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0; size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0;
WordAlignment align; WordAlignment align;
@ -42,10 +41,9 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
if(threshold == 1.f) { if(threshold == 1.f) {
for(size_t t = 0; t < alignSoft.size() - shift; ++t) { for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
// Retrieved alignments are in reversed order // Retrieved alignments are in reversed order
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
size_t maxArg = 0; size_t maxArg = 0;
for(size_t s = 0; s < alignSoft[0].size(); ++s) { for(size_t s = 0; s < alignSoft[0].size(); ++s) {
if(alignSoft[rev][maxArg] < alignSoft[rev][s]) { if(alignSoft[t][maxArg] < alignSoft[t][s]) {
maxArg = s; maxArg = s;
} }
} }
@ -55,9 +53,8 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
// Alignments by greather-than-threshold // Alignments by greather-than-threshold
for(size_t t = 0; t < alignSoft.size() - shift; ++t) { for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
// Retrieved alignments are in reversed order // Retrieved alignments are in reversed order
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
for(size_t s = 0; s < alignSoft[0].size(); ++s) { for(size_t s = 0; s < alignSoft[0].size(); ++s) {
if(alignSoft[rev][s] > threshold) { if(alignSoft[t][s] > threshold) {
align.push_back(s, t); align.push_back(s, t);
} }
} }
@ -70,20 +67,17 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
return align; return align;
} }
std::string SoftAlignToString(SoftAlignment align, std::string SoftAlignToString(SoftAlignment align, bool skipEOS /*= false*/) {
bool reversed /*= true*/,
bool skipEOS /*= false*/) {
std::stringstream str; std::stringstream str;
size_t shift = align.size() > 0 && skipEOS ? 1 : 0; size_t shift = align.size() > 0 && skipEOS ? 1 : 0;
bool first = true; bool first = true;
for(size_t t = 0; t < align.size() - shift; ++t) { for(size_t t = 0; t < align.size() - shift; ++t) {
size_t rev = reversed ? align.size() - t - 1 : t;
if(!first) if(!first)
str << " "; str << " ";
for(size_t s = 0; s < align[rev].size(); ++s) { for(size_t s = 0; s < align[t].size(); ++s) {
if(s != 0) if(s != 0)
str << ","; str << ",";
str << align[rev][s]; str << align[t][s];
} }
first = false; first = false;
} }

Просмотреть файл

@ -51,12 +51,9 @@ typedef std::vector<std::vector<float>> SoftAlignment;
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
float threshold = 1.f, float threshold = 1.f,
bool reversed = true,
bool skipEOS = false); bool skipEOS = false);
std::string SoftAlignToString(SoftAlignment align, std::string SoftAlignToString(SoftAlignment align, bool skipEOS = false);
bool reversed = true,
bool skipEOS = false);
} // namespace data } // namespace data
} // namespace marian } // namespace marian

Просмотреть файл

@ -87,13 +87,11 @@ protected:
std::string getAlignment(const data::SoftAlignment& align) { std::string getAlignment(const data::SoftAlignment& align) {
if(alignment_ == "soft") { if(alignment_ == "soft") {
return data::SoftAlignToString(align, false, true); return data::SoftAlignToString(align, true);
} else if(alignment_ == "hard") { } else if(alignment_ == "hard") {
return data::ConvertSoftAlignToHardAlign(align, 1.f, false, true) return data::ConvertSoftAlignToHardAlign(align, 1.f, true).toString();
.toString();
} else if(alignmentThreshold_ > 0.f) { } else if(alignmentThreshold_ > 0.f) {
return data::ConvertSoftAlignToHardAlign( return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_, true)
align, alignmentThreshold_, false, true)
.toString(); .toString();
} else { } else {
ABORT("Unrecognized word alignment type"); ABORT("Unrecognized word alignment type");

Просмотреть файл

@ -4,14 +4,17 @@ namespace marian {
std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) { std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
data::SoftAlignment align; data::SoftAlignment align;
// Skip EOS // skip EOS
auto last = hyp->GetPrevHyp(); auto last = hyp->GetPrevHyp();
// Get soft alignments for each target word // get soft alignments for each target word starting from the last token
while(last->GetPrevHyp().get() != nullptr) { while(last->GetPrevHyp().get() != nullptr) {
align.push_back(last->GetAlignment()); align.push_back(last->GetAlignment());
last = last->GetPrevHyp(); last = last->GetPrevHyp();
} }
// reverse alignments
std::reverse(align.begin(), align.end());
if(alignment_ == "soft") { if(alignment_ == "soft") {
return data::SoftAlignToString(align); return data::SoftAlignToString(align);
} else if(alignment_ == "hard") { } else if(alignment_ == "hard") {