зеркало из https://github.com/mozilla/marian.git
Add soft alignments in scorer
This commit is contained in:
Родитель
699454a42f
Коммит
883ad0b97b
|
@ -70,16 +70,22 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||
return align;
|
||||
}
|
||||
|
||||
std::string SoftAlignToString(SoftAlignment align) {
|
||||
std::string SoftAlignToString(SoftAlignment align,
|
||||
bool reversed /*= true*/,
|
||||
bool skipEOS /*= false*/) {
|
||||
std::stringstream str;
|
||||
for(size_t t = align.size(); t > 0; --t) {
|
||||
if(t != align.size())
|
||||
size_t shift = align.size() > 0 && skipEOS ? 1 : 0;
|
||||
bool first = true;
|
||||
for(size_t t = 0; t < align.size() - shift; ++t) {
|
||||
size_t rev = reversed ? align.size() - t - 1 : t;
|
||||
if(!first)
|
||||
str << " ";
|
||||
for(size_t s = 0; s < align[t - 1].size(); ++s) {
|
||||
for(size_t s = 0; s < align[rev].size(); ++s) {
|
||||
if(s != 0)
|
||||
str << ",";
|
||||
str << align[t - 1][s];
|
||||
str << align[rev][s];
|
||||
}
|
||||
first = false;
|
||||
}
|
||||
return str.str();
|
||||
}
|
||||
|
|
|
@ -54,7 +54,9 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||
bool reversed = true,
|
||||
bool skipEOS = false);
|
||||
|
||||
std::string SoftAlignToString(SoftAlignment align);
|
||||
std::string SoftAlignToString(SoftAlignment align,
|
||||
bool reversed = true,
|
||||
bool skipEOS = false);
|
||||
|
||||
} // namespace data
|
||||
} // namespace marian
|
||||
|
|
|
@ -54,6 +54,7 @@ public:
|
|||
? std::static_pointer_cast<CorpusBase>(
|
||||
New<CorpusNBest>(options_))
|
||||
: std::static_pointer_cast<CorpusBase>(New<Corpus>(options_))) {
|
||||
// @TODO: move to doValidation in Config Parser
|
||||
ABORT_IF(options_->has("summary") && options_->has("alignment"),
|
||||
"Alignments can not be produced with summarized score");
|
||||
|
||||
|
@ -96,9 +97,9 @@ public:
|
|||
Ptr<ScoreCollector> output = options_->get<bool>("n-best")
|
||||
? std::static_pointer_cast<ScoreCollector>(
|
||||
New<ScoreCollectorNBest>(options_))
|
||||
: New<ScoreCollector>();
|
||||
: New<ScoreCollector>(options_);
|
||||
|
||||
float alignment = options_->get<float>("alignment", .0f);
|
||||
std::string alignment = options_->get<std::string>("alignment", "");
|
||||
bool summarize = options_->has("summary");
|
||||
std::string summary
|
||||
= summarize ? options_->get<std::string>("summary") : "cross-entropy";
|
||||
|
@ -134,7 +135,7 @@ public:
|
|||
|
||||
// soft alignments for each sentence in the batch
|
||||
std::vector<data::SoftAlignment> aligns(batch->size());
|
||||
if(alignment > .0f) {
|
||||
if(!alignment.empty()) {
|
||||
getAlignmentsForBatch(builder->getAlignment(), batch, aligns);
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,11 @@ namespace marian {
|
|||
|
||||
class ScoreCollector {
|
||||
public:
|
||||
ScoreCollector() : nextId_(0), outStrm_(new OutputFileStream(std::cout)){};
|
||||
// Builds a collector from the scorer options: the id counter starts at 0,
// output is wrapped around std::cout, and the "alignment" option is read as
// a string (empty string when the option is absent).
ScoreCollector(const Ptr<Config>& options)
|
||||
: nextId_(0),
|
||||
outStrm_(new OutputFileStream(std::cout)),
|
||||
alignment_(options->get<std::string>("alignment", "")),
|
||||
// Reads alignment_ just initialized above; this relies on alignment_ being
// declared before alignmentThreshold_ in the class.
alignmentThreshold_(getAlignmentThreshold(alignment_)){};
|
||||
|
||||
virtual void Write(long id, const std::string& message) {
|
||||
boost::mutex::scoped_lock lock(mutex_);
|
||||
|
@ -57,11 +61,8 @@ public:
|
|||
float score,
|
||||
const data::SoftAlignment& align = {}) {
|
||||
auto msg = std::to_string(score);
|
||||
if(!align.empty()) {
|
||||
auto wordAlign
|
||||
= data::ConvertSoftAlignToHardAlign(align, 1.f, false, true);
|
||||
msg += " ||| " + wordAlign.toString();
|
||||
}
|
||||
if(!alignment_.empty() && !align.empty())
|
||||
msg += " ||| " + getAlignment(align);
|
||||
Write(id, msg);
|
||||
}
|
||||
|
||||
|
@ -72,6 +73,33 @@ protected:
|
|||
|
||||
typedef std::map<long, std::string> Outputs;
|
||||
Outputs outputs_;
|
||||
|
||||
std::string alignment_;
|
||||
float alignmentThreshold_{0.f};
|
||||
|
||||
// Interprets the alignment option string as a numeric threshold.
// Returns the parsed float clamped from below at 0.f, or 0.f when the string
// cannot be parsed as a number (e.g. the keywords "soft" and "hard").
float getAlignmentThreshold(const std::string& str) {
|
||||
try {
|
||||
// Negative values are clamped to zero.
return std::max(std::stof(str), 0.f);
|
||||
} catch(...) {
|
||||
// std::stof throws on non-numeric or out-of-range input; any failure is
// mapped to 0.f, which getAlignment() treats as "no threshold given".
return 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
// Renders the soft alignment of one sentence as text, in the format selected
// by alignment_:
//   "soft"             -> data::SoftAlignToString output
//   "hard"             -> hard alignment, passing 1.f as the second argument
//   positive threshold -> hard alignment, passing alignmentThreshold_ instead
// Any other value aborts with "Unrecognized word alignment type".
std::string getAlignment(const data::SoftAlignment& align) {
|
||||
if(alignment_ == "soft") {
|
||||
// Arguments are reversed = false, skipEOS = true.
return data::SoftAlignToString(align, false, true);
|
||||
} else if(alignment_ == "hard") {
|
||||
// NOTE(review): the second argument (1.f) is presumably the threshold
// parameter of ConvertSoftAlignToHardAlign - confirm against its declaration.
return data::ConvertSoftAlignToHardAlign(align, 1.f, false, true)
|
||||
.toString();
|
||||
} else if(alignmentThreshold_ > 0.f) {
|
||||
return data::ConvertSoftAlignToHardAlign(
|
||||
align, alignmentThreshold_, false, true)
|
||||
.toString();
|
||||
} else {
|
||||
ABORT("Unrecognized word alignment type");
|
||||
}
|
||||
// NOTE(review): presumably unreachable because ABORT does not return; likely
// kept so that every control path returns a value - confirm.
return "";
|
||||
}
|
||||
};
|
||||
|
||||
class ScoreCollectorNBest : public ScoreCollector {
|
||||
|
@ -87,7 +115,9 @@ private:
|
|||
public:
|
||||
ScoreCollectorNBest() = delete;
|
||||
|
||||
ScoreCollectorNBest(const Ptr<Config>& options) : options_(options) {
|
||||
// TODO: get rid of the options_ attribute
|
||||
ScoreCollectorNBest(const Ptr<Config>& options)
|
||||
: ScoreCollector(options), options_(options) {
|
||||
auto paths = options_->get<std::vector<std::string>>("train-sets");
|
||||
nBestList_ = paths.back();
|
||||
fname_ = options_->get<std::string>("n-best-feature");
|
||||
|
@ -103,11 +133,8 @@ public:
|
|||
std::vector<std::string> fields;
|
||||
utils::Split(nbest, fields, "|||");
|
||||
std::stringstream ss;
|
||||
if(!align.empty()) {
|
||||
auto wordAlign
|
||||
= data::ConvertSoftAlignToHardAlign(align, 1.f, false, true);
|
||||
ss << " " << wordAlign.toString() << " |||";
|
||||
}
|
||||
if(!alignment_.empty() && !align.empty())
|
||||
ss << " " << getAlignment(align) << " |||";
|
||||
ss << fields[2] << feature << "= " << score << " ";
|
||||
fields[2] = ss.str();
|
||||
return utils::Join(fields, "|||");
|
||||
|
|
Загрузка…
Ссылка в новой задаче