зеркало из https://github.com/mozilla/marian.git
Unify soft alignments in decoder and scorer
This commit is contained in:
Родитель
883ad0b97b
Коммит
cc8cc4a31d
|
@ -34,7 +34,6 @@ std::string WordAlignment::toString() const {
|
|||
|
||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||
float threshold /*= 1.f*/,
|
||||
bool reversed /*= true*/,
|
||||
bool skipEOS /*= false*/) {
|
||||
size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0;
|
||||
WordAlignment align;
|
||||
|
@ -42,10 +41,9 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||
if(threshold == 1.f) {
|
||||
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
||||
// Retrieved alignments are in reversed order
|
||||
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
|
||||
size_t maxArg = 0;
|
||||
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
||||
if(alignSoft[rev][maxArg] < alignSoft[rev][s]) {
|
||||
if(alignSoft[t][maxArg] < alignSoft[t][s]) {
|
||||
maxArg = s;
|
||||
}
|
||||
}
|
||||
|
@ -55,9 +53,8 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||
// Alignments by greater-than-threshold
|
||||
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
||||
// Retrieved alignments are in reversed order
|
||||
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
|
||||
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
||||
if(alignSoft[rev][s] > threshold) {
|
||||
if(alignSoft[t][s] > threshold) {
|
||||
align.push_back(s, t);
|
||||
}
|
||||
}
|
||||
|
@ -70,20 +67,17 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||
return align;
|
||||
}
|
||||
|
||||
std::string SoftAlignToString(SoftAlignment align,
|
||||
bool reversed /*= true*/,
|
||||
bool skipEOS /*= false*/) {
|
||||
std::string SoftAlignToString(SoftAlignment align, bool skipEOS /*= false*/) {
|
||||
std::stringstream str;
|
||||
size_t shift = align.size() > 0 && skipEOS ? 1 : 0;
|
||||
bool first = true;
|
||||
for(size_t t = 0; t < align.size() - shift; ++t) {
|
||||
size_t rev = reversed ? align.size() - t - 1 : t;
|
||||
if(!first)
|
||||
str << " ";
|
||||
for(size_t s = 0; s < align[rev].size(); ++s) {
|
||||
for(size_t s = 0; s < align[t].size(); ++s) {
|
||||
if(s != 0)
|
||||
str << ",";
|
||||
str << align[rev][s];
|
||||
str << align[t][s];
|
||||
}
|
||||
first = false;
|
||||
}
|
||||
|
|
|
@ -51,12 +51,9 @@ typedef std::vector<std::vector<float>> SoftAlignment;
|
|||
|
||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||
float threshold = 1.f,
|
||||
bool reversed = true,
|
||||
bool skipEOS = false);
|
||||
|
||||
std::string SoftAlignToString(SoftAlignment align,
|
||||
bool reversed = true,
|
||||
bool skipEOS = false);
|
||||
std::string SoftAlignToString(SoftAlignment align, bool skipEOS = false);
|
||||
|
||||
} // namespace data
|
||||
} // namespace marian
|
||||
|
|
|
@ -87,13 +87,11 @@ protected:
|
|||
|
||||
std::string getAlignment(const data::SoftAlignment& align) {
|
||||
if(alignment_ == "soft") {
|
||||
return data::SoftAlignToString(align, false, true);
|
||||
return data::SoftAlignToString(align, true);
|
||||
} else if(alignment_ == "hard") {
|
||||
return data::ConvertSoftAlignToHardAlign(align, 1.f, false, true)
|
||||
.toString();
|
||||
return data::ConvertSoftAlignToHardAlign(align, 1.f, true).toString();
|
||||
} else if(alignmentThreshold_ > 0.f) {
|
||||
return data::ConvertSoftAlignToHardAlign(
|
||||
align, alignmentThreshold_, false, true)
|
||||
return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_, true)
|
||||
.toString();
|
||||
} else {
|
||||
ABORT("Unrecognized word alignment type");
|
||||
|
|
|
@ -4,14 +4,17 @@ namespace marian {
|
|||
|
||||
std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
|
||||
data::SoftAlignment align;
|
||||
// Skip EOS
|
||||
// skip EOS
|
||||
auto last = hyp->GetPrevHyp();
|
||||
// Get soft alignments for each target word
|
||||
// get soft alignments for each target word starting from the last token
|
||||
while(last->GetPrevHyp().get() != nullptr) {
|
||||
align.push_back(last->GetAlignment());
|
||||
last = last->GetPrevHyp();
|
||||
}
|
||||
|
||||
// reverse alignments
|
||||
std::reverse(align.begin(), align.end());
|
||||
|
||||
if(alignment_ == "soft") {
|
||||
return data::SoftAlignToString(align);
|
||||
} else if(alignment_ == "hard") {
|
||||
|
|
Загрузка…
Ссылка в новой задаче