зеркало из https://github.com/mozilla/marian.git
Unify soft alignments in decoder and scorer
This commit is contained in:
Родитель
883ad0b97b
Коммит
cc8cc4a31d
|
@ -34,7 +34,6 @@ std::string WordAlignment::toString() const {
|
||||||
|
|
||||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||||
float threshold /*= 1.f*/,
|
float threshold /*= 1.f*/,
|
||||||
bool reversed /*= true*/,
|
|
||||||
bool skipEOS /*= false*/) {
|
bool skipEOS /*= false*/) {
|
||||||
size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0;
|
size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0;
|
||||||
WordAlignment align;
|
WordAlignment align;
|
||||||
|
@ -42,10 +41,9 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||||
if(threshold == 1.f) {
|
if(threshold == 1.f) {
|
||||||
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
||||||
// Retrieved alignments are in reversed order
|
// Retrieved alignments are in reversed order
|
||||||
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
|
|
||||||
size_t maxArg = 0;
|
size_t maxArg = 0;
|
||||||
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
||||||
if(alignSoft[rev][maxArg] < alignSoft[rev][s]) {
|
if(alignSoft[t][maxArg] < alignSoft[t][s]) {
|
||||||
maxArg = s;
|
maxArg = s;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -55,9 +53,8 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||||
// Alignments by greather-than-threshold
|
// Alignments by greather-than-threshold
|
||||||
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
||||||
// Retrieved alignments are in reversed order
|
// Retrieved alignments are in reversed order
|
||||||
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
|
|
||||||
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
||||||
if(alignSoft[rev][s] > threshold) {
|
if(alignSoft[t][s] > threshold) {
|
||||||
align.push_back(s, t);
|
align.push_back(s, t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -70,20 +67,17 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||||
return align;
|
return align;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string SoftAlignToString(SoftAlignment align,
|
std::string SoftAlignToString(SoftAlignment align, bool skipEOS /*= false*/) {
|
||||||
bool reversed /*= true*/,
|
|
||||||
bool skipEOS /*= false*/) {
|
|
||||||
std::stringstream str;
|
std::stringstream str;
|
||||||
size_t shift = align.size() > 0 && skipEOS ? 1 : 0;
|
size_t shift = align.size() > 0 && skipEOS ? 1 : 0;
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for(size_t t = 0; t < align.size() - shift; ++t) {
|
for(size_t t = 0; t < align.size() - shift; ++t) {
|
||||||
size_t rev = reversed ? align.size() - t - 1 : t;
|
|
||||||
if(!first)
|
if(!first)
|
||||||
str << " ";
|
str << " ";
|
||||||
for(size_t s = 0; s < align[rev].size(); ++s) {
|
for(size_t s = 0; s < align[t].size(); ++s) {
|
||||||
if(s != 0)
|
if(s != 0)
|
||||||
str << ",";
|
str << ",";
|
||||||
str << align[rev][s];
|
str << align[t][s];
|
||||||
}
|
}
|
||||||
first = false;
|
first = false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,12 +51,9 @@ typedef std::vector<std::vector<float>> SoftAlignment;
|
||||||
|
|
||||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||||
float threshold = 1.f,
|
float threshold = 1.f,
|
||||||
bool reversed = true,
|
|
||||||
bool skipEOS = false);
|
bool skipEOS = false);
|
||||||
|
|
||||||
std::string SoftAlignToString(SoftAlignment align,
|
std::string SoftAlignToString(SoftAlignment align, bool skipEOS = false);
|
||||||
bool reversed = true,
|
|
||||||
bool skipEOS = false);
|
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace marian
|
} // namespace marian
|
||||||
|
|
|
@ -87,13 +87,11 @@ protected:
|
||||||
|
|
||||||
std::string getAlignment(const data::SoftAlignment& align) {
|
std::string getAlignment(const data::SoftAlignment& align) {
|
||||||
if(alignment_ == "soft") {
|
if(alignment_ == "soft") {
|
||||||
return data::SoftAlignToString(align, false, true);
|
return data::SoftAlignToString(align, true);
|
||||||
} else if(alignment_ == "hard") {
|
} else if(alignment_ == "hard") {
|
||||||
return data::ConvertSoftAlignToHardAlign(align, 1.f, false, true)
|
return data::ConvertSoftAlignToHardAlign(align, 1.f, true).toString();
|
||||||
.toString();
|
|
||||||
} else if(alignmentThreshold_ > 0.f) {
|
} else if(alignmentThreshold_ > 0.f) {
|
||||||
return data::ConvertSoftAlignToHardAlign(
|
return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_, true)
|
||||||
align, alignmentThreshold_, false, true)
|
|
||||||
.toString();
|
.toString();
|
||||||
} else {
|
} else {
|
||||||
ABORT("Unrecognized word alignment type");
|
ABORT("Unrecognized word alignment type");
|
||||||
|
|
|
@ -4,14 +4,17 @@ namespace marian {
|
||||||
|
|
||||||
std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
|
std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
|
||||||
data::SoftAlignment align;
|
data::SoftAlignment align;
|
||||||
// Skip EOS
|
// skip EOS
|
||||||
auto last = hyp->GetPrevHyp();
|
auto last = hyp->GetPrevHyp();
|
||||||
// Get soft alignments for each target word
|
// get soft alignments for each target word starting from the last token
|
||||||
while(last->GetPrevHyp().get() != nullptr) {
|
while(last->GetPrevHyp().get() != nullptr) {
|
||||||
align.push_back(last->GetAlignment());
|
align.push_back(last->GetAlignment());
|
||||||
last = last->GetPrevHyp();
|
last = last->GetPrevHyp();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reverse alignments
|
||||||
|
std::reverse(align.begin(), align.end());
|
||||||
|
|
||||||
if(alignment_ == "soft") {
|
if(alignment_ == "soft") {
|
||||||
return data::SoftAlignToString(align);
|
return data::SoftAlignToString(align);
|
||||||
} else if(alignment_ == "hard") {
|
} else if(alignment_ == "hard") {
|
||||||
|
|
Загрузка…
Ссылка в новой задаче