This commit is contained in:
Roman Grundkiewicz 2018-08-17 10:09:05 +01:00
Родитель 699454a42f
Коммит 883ad0b97b
4 изменённых файлов: 57 добавлений и 21 удалений

Просмотреть файл

@ -70,16 +70,22 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
return align;
}
// NOTE(review): this hunk (header "-70,16 +70,22") appears to show the pre- and
// post-change lines interleaved with the +/- markers stripped, so the
// one-argument signature directly below and the three-argument signature after
// it are presumably the old and new versions of the same function. The same
// applies to the two outer `for` loops and to the `align[t - 1]` / `align[rev]`
// line pairs — confirm against the repository before compiling this text.
//
// Three-argument version: builds a string with one space-separated group per
// visited row of `align`; the values inside a row are joined by commas.
//   reversed - when true, rows are visited from the last index downwards;
//              otherwise from index 0 upwards.
//   skipEOS  - when true and `align` is non-empty, the loop runs one iteration
//              fewer, which leaves out row 0 if `reversed` is true and the
//              last row otherwise.
// Returns the accumulated string (empty when no row is visited).
std::string SoftAlignToString(SoftAlignment align) {
std::string SoftAlignToString(SoftAlignment align,
bool reversed /*= true*/,
bool skipEOS /*= false*/) {
std::stringstream str;
for(size_t t = align.size(); t > 0; --t) {
if(t != align.size())
// The size() > 0 guard keeps the unsigned `align.size() - shift` below from
// wrapping around when `align` is empty.
size_t shift = align.size() > 0 && skipEOS ? 1 : 0;
// Tracks whether a row has already been written, so the separator is only
// emitted between rows.
bool first = true;
for(size_t t = 0; t < align.size() - shift; ++t) {
// Index of the row actually printed in this iteration.
size_t rev = reversed ? align.size() - t - 1 : t;
if(!first)
str << " ";
for(size_t s = 0; s < align[t - 1].size(); ++s) {
for(size_t s = 0; s < align[rev].size(); ++s) {
if(s != 0)
str << ",";
str << align[t - 1][s];
str << align[rev][s];
}
first = false;
}
return str.str();
}

Просмотреть файл

@ -54,7 +54,9 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
bool reversed = true,
bool skipEOS = false);
std::string SoftAlignToString(SoftAlignment align);
// Formats a soft alignment as text and returns it.
//   align    - the soft alignment to print (taken by value).
//   reversed - defaults to true.
//   skipEOS  - defaults to false.
// NOTE(review): per the definition shown in this commit, rows are separated by
// spaces and values within a row by commas; `reversed` selects back-to-front
// row order and `skipEOS` drops one row from the output — confirm against the
// implementation file.
std::string SoftAlignToString(SoftAlignment align,
bool reversed = true,
bool skipEOS = false);
} // namespace data
} // namespace marian

Просмотреть файл

@ -54,6 +54,7 @@ public:
? std::static_pointer_cast<CorpusBase>(
New<CorpusNBest>(options_))
: std::static_pointer_cast<CorpusBase>(New<Corpus>(options_))) {
// @TODO: move to doValidation in Config Parser
ABORT_IF(options_->has("summary") && options_->has("alignment"),
"Alignments can not be produced with summarized score");
@ -96,9 +97,9 @@ public:
Ptr<ScoreCollector> output = options_->get<bool>("n-best")
? std::static_pointer_cast<ScoreCollector>(
New<ScoreCollectorNBest>(options_))
: New<ScoreCollector>();
: New<ScoreCollector>(options_);
float alignment = options_->get<float>("alignment", .0f);
std::string alignment = options_->get<std::string>("alignment", "");
bool summarize = options_->has("summary");
std::string summary
= summarize ? options_->get<std::string>("summary") : "cross-entropy";
@ -134,7 +135,7 @@ public:
// soft alignments for each sentence in the batch
std::vector<data::SoftAlignment> aligns(batch->size());
if(alignment > .0f) {
if(!alignment.empty()) {
getAlignmentsForBatch(builder->getAlignment(), batch, aligns);
}

Просмотреть файл

@ -15,7 +15,11 @@ namespace marian {
class ScoreCollector {
public:
ScoreCollector() : nextId_(0), outStrm_(new OutputFileStream(std::cout)){};
// Constructs a collector configured from `options`.
// Starts the id counter at 0, wraps std::cout in an OutputFileStream, and
// reads the "alignment" option as a string (with "" passed as the second
// argument — presumably the fallback when the option is absent; confirm
// against Config::get).
ScoreCollector(const Ptr<Config>& options)
: nextId_(0),
outStrm_(new OutputFileStream(std::cout)),
alignment_(options->get<std::string>("alignment", "")),
// Derived from alignment_: getAlignmentThreshold() yields 0 when the option
// text is not a number (e.g. "soft" or "hard").
alignmentThreshold_(getAlignmentThreshold(alignment_)){};
virtual void Write(long id, const std::string& message) {
boost::mutex::scoped_lock lock(mutex_);
@ -57,11 +61,8 @@ public:
float score,
const data::SoftAlignment& align = {}) {
auto msg = std::to_string(score);
if(!align.empty()) {
auto wordAlign
= data::ConvertSoftAlignToHardAlign(align, 1.f, false, true);
msg += " ||| " + wordAlign.toString();
}
if(!alignment_.empty() && !align.empty())
msg += " ||| " + getAlignment(align);
Write(id, msg);
}
@ -72,6 +73,33 @@ protected:
typedef std::map<long, std::string> Outputs;
Outputs outputs_;
// Raw text of the "alignment" option; getAlignment() compares it against
// "soft" and "hard".
std::string alignment_;
// Numeric threshold obtained from alignment_ via getAlignmentThreshold();
// stays 0 when the option text is not a number.
float alignmentThreshold_{0.f};
// Interprets `str` as a numeric alignment threshold.
// Returns the value parsed by std::stof, with negative results replaced by 0.
// Any exception raised while parsing (non-numeric text such as "soft", an
// empty string, an out-of-range value) yields 0 instead of propagating.
float getAlignmentThreshold(const std::string& str) {
  float threshold = 0.f;
  try {
    const float parsed = std::stof(str);
    // Same selection std::max(parsed, 0.f) makes: keep `parsed` unless it
    // compares less than zero.
    threshold = (parsed < 0.f) ? 0.f : parsed;
  } catch(...) {
    threshold = 0.f;
  }
  return threshold;
}
// Renders `align` in the format selected by the "alignment" option text:
//   "soft"            -> data::SoftAlignToString(align, false, true)
//   "hard"            -> hard alignment computed with the value 1.f
//   positive number   -> hard alignment computed with alignmentThreshold_
// Any other setting aborts with "Unrecognized word alignment type".
// NOTE(review): the second argument of ConvertSoftAlignToHardAlign is
// presumably the probability threshold, and the trailing `false, true` its
// reversed / skipEOS flags — confirm against the declaration.
std::string getAlignment(const data::SoftAlignment& align) {
  if(alignment_ == "soft")
    return data::SoftAlignToString(align, false, true);

  const bool wantsHard = (alignment_ == "hard");
  if(wantsHard || alignmentThreshold_ > 0.f) {
    // "hard" takes precedence over whatever numeric threshold was parsed.
    const float cutoff = wantsHard ? 1.f : alignmentThreshold_;
    auto hardAlign
        = data::ConvertSoftAlignToHardAlign(align, cutoff, false, true);
    return hardAlign.toString();
  }

  ABORT("Unrecognized word alignment type");
  return "";
}
};
class ScoreCollectorNBest : public ScoreCollector {
@ -87,7 +115,9 @@ private:
public:
ScoreCollectorNBest() = delete;
ScoreCollectorNBest(const Ptr<Config>& options) : options_(options) {
// TODO: get rid of the options_ attribute
ScoreCollectorNBest(const Ptr<Config>& options)
: ScoreCollector(options), options_(options) {
auto paths = options_->get<std::vector<std::string>>("train-sets");
nBestList_ = paths.back();
fname_ = options_->get<std::string>("n-best-feature");
@ -103,11 +133,8 @@ public:
std::vector<std::string> fields;
utils::Split(nbest, fields, "|||");
std::stringstream ss;
if(!align.empty()) {
auto wordAlign
= data::ConvertSoftAlignToHardAlign(align, 1.f, false, true);
ss << " " << wordAlign.toString() << " |||";
}
if(!alignment_.empty() && !align.empty())
ss << " " << getAlignment(align) << " |||";
ss << fields[2] << feature << "= " << score << " ";
fields[2] = ss.str();
return utils::Join(fields, "|||");