Unify soft alignments in decoder and scorer

This commit is contained in:
Roman Grundkiewicz 2018-08-17 10:49:48 +01:00
Родитель 883ad0b97b
Коммит cc8cc4a31d
4 изменённых файлов: 14 добавлений и 22 удалений

Просмотреть файл

@ -34,7 +34,6 @@ std::string WordAlignment::toString() const {
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
float threshold /*= 1.f*/,
bool reversed /*= true*/,
bool skipEOS /*= false*/) {
size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0;
WordAlignment align;
@ -42,10 +41,9 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
if(threshold == 1.f) {
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
// Retrieved alignments are in reversed order
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
size_t maxArg = 0;
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
if(alignSoft[rev][maxArg] < alignSoft[rev][s]) {
if(alignSoft[t][maxArg] < alignSoft[t][s]) {
maxArg = s;
}
}
@ -55,9 +53,8 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
// Alignments by greather-than-threshold
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
// Retrieved alignments are in reversed order
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
if(alignSoft[rev][s] > threshold) {
if(alignSoft[t][s] > threshold) {
align.push_back(s, t);
}
}
@ -70,20 +67,17 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
return align;
}
std::string SoftAlignToString(SoftAlignment align,
bool reversed /*= true*/,
bool skipEOS /*= false*/) {
std::string SoftAlignToString(SoftAlignment align, bool skipEOS /*= false*/) {
std::stringstream str;
size_t shift = align.size() > 0 && skipEOS ? 1 : 0;
bool first = true;
for(size_t t = 0; t < align.size() - shift; ++t) {
size_t rev = reversed ? align.size() - t - 1 : t;
if(!first)
str << " ";
for(size_t s = 0; s < align[rev].size(); ++s) {
for(size_t s = 0; s < align[t].size(); ++s) {
if(s != 0)
str << ",";
str << align[rev][s];
str << align[t][s];
}
first = false;
}

Просмотреть файл

@ -51,12 +51,9 @@ typedef std::vector<std::vector<float>> SoftAlignment;
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
float threshold = 1.f,
bool reversed = true,
bool skipEOS = false);
std::string SoftAlignToString(SoftAlignment align,
bool reversed = true,
bool skipEOS = false);
std::string SoftAlignToString(SoftAlignment align, bool skipEOS = false);
} // namespace data
} // namespace marian

Просмотреть файл

@ -87,13 +87,11 @@ protected:
std::string getAlignment(const data::SoftAlignment& align) {
if(alignment_ == "soft") {
return data::SoftAlignToString(align, false, true);
return data::SoftAlignToString(align, true);
} else if(alignment_ == "hard") {
return data::ConvertSoftAlignToHardAlign(align, 1.f, false, true)
.toString();
return data::ConvertSoftAlignToHardAlign(align, 1.f, true).toString();
} else if(alignmentThreshold_ > 0.f) {
return data::ConvertSoftAlignToHardAlign(
align, alignmentThreshold_, false, true)
return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_, true)
.toString();
} else {
ABORT("Unrecognized word alignment type");

Просмотреть файл

@ -4,14 +4,17 @@ namespace marian {
std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
data::SoftAlignment align;
// Skip EOS
// skip EOS
auto last = hyp->GetPrevHyp();
// Get soft alignments for each target word
// get soft alignments for each target word starting from the last token
while(last->GetPrevHyp().get() != nullptr) {
align.push_back(last->GetAlignment());
last = last->GetPrevHyp();
}
// reverse alignments
std::reverse(align.begin(), align.end());
if(alignment_ == "soft") {
return data::SoftAlignToString(align);
} else if(alignment_ == "hard") {