diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8226b012..047b3168 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,12 +8,18 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## [Unreleased]
 ### Added
-- Returning hard alignments by scorer
+- Word alignment generation in scorer
+- Attention output generation in decoder and scorer with `--alignment soft`
+
+### Fixed
+- Delayed output in line-by-line translation
+
+### Changed
+- Generated word alignments include alignments for target EOS tokens
 ## [1.6.0] - 2018-08-08
 ### Added
-
 - Faster training (20-30%) by optimizing gradient propagation of biases
 - Returning Moses-style hard alignments during decoding single models,
   ensembles and n-best lists
@@ -36,7 +42,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - Seamless training continuation with exponential smoothing
 ### Fixed
-
 - A couple of bugs in "selection" (transpose, shift, cols, rows) operators
   during back-prop for a very specific case: one of the operators is the first
   operator after a branch, in that case gradient propagation might be
@@ -49,14 +54,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## [1.5.0] - 2018-06-17
 ### Added
-
 - Average Attention Networks for Transformer model
 - 16-bit matrix multiplication on CPU
 - Memoization for constant nodes for decoding
 - Autotuning for decoding
 ### Fixed
-
 - GPU decoding optimizations, about 2x faster decoding of transformer models
 - Multi-node MPI-based training on GPUs
diff --git a/CMakeLists.txt b/CMakeLists.txt
index fb5c79f6..816386b8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,6 +25,7 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
 # Set compilation flags
 set(CMAKE_CXX_FLAGS_RELEASE " -std=c++11 -O3 -Ofast -m64 -pthread -march=native -Wl,--no-as-needed -funroll-loops -ffinite-math-only -fPIC -Wno-unused-result -Wno-deprecated -Wno-deprecated-gpu-targets")
+set(CMAKE_CXX_FLAGS_NONATIVE " -std=c++11 -O3 -Ofast -m64 -pthread -march=x86-64 -mavx -Wl,--no-as-needed -funroll-loops -ffinite-math-only -fPIC -Wno-unused-result -Wno-deprecated -Wno-deprecated-gpu-targets")
 set(CMAKE_CXX_FLAGS_DEBUG " -std=c++11 -g -O0 -pthread -fPIC -Wno-unused-result -Wno-deprecated -Wno-deprecated-gpu-targets")
 set(CMAKE_CXX_FLAGS_ST "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
 set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg -g")
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 6176103e..9a6af6f8 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -52,6 +52,8 @@ add_library(marian STATIC
   models/model_factory.cpp
   models/encoder_decoder.cpp
+  rescorer/score_collector.cpp
+
   translator/history.cpp
   translator/output_collector.cpp
   translator/output_printer.cpp
diff --git a/src/common/cli_helper.h b/src/common/cli_helper.h
new file mode 100644
index 00000000..036e6185
--- /dev/null
+++ b/src/common/cli_helper.h
@@ -0,0 +1,124 @@
+#pragma once
+
+#include
+
+#include "3rd_party/yaml-cpp/yaml.h"
+#include "common/logging.h"
+
+namespace marian {
+namespace cli {
+
+// helper to replace environment-variable expressions of the form ${VARNAME} in
+// a string
+static std::string InterpolateEnvVars(std::string str) {
+// temporary workaround for MS-internal PhillyOnAzure cluster: warm storage
+// presently has the form /hdfs/VC instead of /{gfs,hdfs}/CLUSTER/VC
+#if 1
+  if(getenv("PHILLY_JOB_ID")) {
+    const char* cluster = getenv("PHILLY_CLUSTER");
+    const char* vc = getenv("PHILLY_VC");
+    // this
environment variable exists when running on the cluster + if(cluster && vc) { + static const std::string s_gfsPrefix + = std::string("/gfs/") + cluster + "/" + vc + "/"; + static const std::string s_hdfsPrefix + = std::string("/hdfs/") + cluster + "/" + vc + "/"; + if(str.find(s_gfsPrefix) == 0) + str = std::string("/hdfs/") + vc + "/" + str.substr(s_gfsPrefix.size()); + else if(str.find(s_hdfsPrefix) == 0) + str = std::string("/hdfs/") + vc + "/" + + str.substr(s_hdfsPrefix.size()); + } + } +#endif + for(;;) { + const auto pos = str.find("${"); + if(pos == std::string::npos) + return str; + const auto epos = str.find("}", pos + 2); + ABORT_IF(epos == std::string::npos, + "interpolate-env-vars option: ${{ without matching }} in '{}'", + str.c_str()); + // isolate the variable name + const auto var = str.substr(pos + 2, epos - (pos + 2)); + const auto val = getenv(var.c_str()); + ABORT_IF(!val, + "interpolate-env-vars option: environment variable '{}' not " + "defined in '{}'", + var.c_str(), + str.c_str()); + // replace it; then try again for further replacements + str = str.substr(0, pos) + val + str.substr(epos + 1); + } +} + +// helper to implement interpolate-env-vars and relative-paths options +static void ProcessPaths( + YAML::Node& node, + const std::function& TransformPath, + const std::set& PATHS, + bool isPath = false) { + if(isPath) { + if(node.Type() == YAML::NodeType::Scalar) { + std::string nodePath = node.as(); + // transform the path + if(!nodePath.empty()) + node = TransformPath(nodePath); + } + + if(node.Type() == YAML::NodeType::Sequence) { + for(auto&& sub : node) { + ProcessPaths(sub, TransformPath, PATHS, true); + } + } + } else { + switch(node.Type()) { + case YAML::NodeType::Sequence: + for(auto&& sub : node) { + ProcessPaths(sub, TransformPath, PATHS, false); + } + break; + case YAML::NodeType::Map: + for(auto&& sub : node) { + std::string key = sub.first.as(); + ProcessPaths(sub.second, TransformPath, PATHS, PATHS.count(key) > 0); + } + break; + default: + // it is OK + break; + } + } +} + +// helper to convert a Yaml node recursively into a string +static void OutputYaml(const YAML::Node node, YAML::Emitter& out) { + std::set sorter; + switch(node.Type()) { + case YAML::NodeType::Null: out << node; break; + case YAML::NodeType::Scalar: out << node; break; + case YAML::NodeType::Sequence: + out << YAML::BeginSeq; + for(auto&& n : node) + OutputYaml(n, out); + out << YAML::EndSeq; + break; + case YAML::NodeType::Map: + for(auto& n : node) + sorter.insert(n.first.as()); + out << YAML::BeginMap; + for(auto& key : sorter) { + out << YAML::Key; + out << key; + out << YAML::Value; + OutputYaml(node[key], out); + } + out << YAML::EndMap; + break; + case YAML::NodeType::Undefined: out << node; break; + } +} + + +} // namespace cli +} // namespace marian diff --git a/src/common/config.cpp b/src/common/config.cpp index f92a01d5..d189ef38 100644 --- a/src/common/config.cpp +++ b/src/common/config.cpp @@ -29,7 +29,7 @@ YAML::Node& Config::get() { void Config::log() { YAML::Emitter out; - OutputYaml(config_, out); + cli::OutputYaml(config_, out); std::string configString = out.c_str(); std::vector results; diff --git a/src/common/config.h b/src/common/config.h index d8afc2c2..a3db596d 100644 --- a/src/common/config.h +++ b/src/common/config.h @@ -1,14 +1,18 @@ #pragma once #include + #include "3rd_party/yaml-cpp/yaml.h" +#include "common/cli_helper.h" #include "common/config_parser.h" #include "common/file_stream.h" #include "common/io.h" #include "common/logging.h" #include 
"common/utils.h" -#ifndef _WIN32 // TODO: why are these needed by a config parser? Can they be - // removed for Linux as well? + +// TODO: why are these needed by a config parser? Can they be removed for Linux +// as well? +#ifndef _WIN32 #include #include #endif @@ -125,7 +129,7 @@ public: friend std::ostream& operator<<(std::ostream& out, const Config& config) { YAML::Emitter outYaml; - OutputYaml(config.get(), outYaml); + cli::OutputYaml(config.get(), outYaml); out << outYaml.c_str(); return out; } diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index eb39abee..ca93a7c3 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -1,28 +1,27 @@ #include -#include #include #include #include +#include + #if MKL_FOUND -//#include #include #else #if BLAS_FOUND -//#include #include #endif #endif #include "common/definitions.h" +#include "common/cli_helper.h" #include "common/config.h" #include "common/config_parser.h" #include "common/file_stream.h" #include "common/logging.h" -#include "common/version.h" - #include "common/regex.h" +#include "common/version.h" #define SET_OPTION(key, type) \ do { \ @@ -61,34 +60,6 @@ uint16_t guess_terminal_width(uint16_t max_width) { return max_width ? std::min(cols, max_width) : cols; } -// helper to convert a Yaml node recursively into a string -void OutputYaml(const YAML::Node node, YAML::Emitter& out) { - std::set sorter; - switch(node.Type()) { - case YAML::NodeType::Null: out << node; break; - case YAML::NodeType::Scalar: out << node; break; - case YAML::NodeType::Sequence: - out << YAML::BeginSeq; - for(auto&& n : node) - OutputYaml(n, out); - out << YAML::EndSeq; - break; - case YAML::NodeType::Map: - for(auto& n : node) - sorter.insert(n.first.as()); - out << YAML::BeginMap; - for(auto& key : sorter) { - out << YAML::Key; - out << key; - out << YAML::Value; - OutputYaml(node[key], out); - } - out << YAML::EndMap; - break; - case YAML::NodeType::Undefined: out << node; break; - } -} - const std::set PATHS = {"model", "models", "train-sets", @@ -100,88 +71,6 @@ const std::set PATHS = {"model", "valid-translation-output", "log"}; -// helper to implement interpolate-env-vars and relative-paths options -static void processPaths( - YAML::Node& node, - const std::function& TransformPath, - bool isPath = false) { - if(isPath) { - if(node.Type() == YAML::NodeType::Scalar) { - std::string nodePath = node.as(); - // transform the path - if(!nodePath.empty()) - node = TransformPath(nodePath); - } - - if(node.Type() == YAML::NodeType::Sequence) { - for(auto&& sub : node) { - processPaths(sub, TransformPath, true); - } - } - } else { - switch(node.Type()) { - case YAML::NodeType::Sequence: - for(auto&& sub : node) { - processPaths(sub, TransformPath, false); - } - break; - case YAML::NodeType::Map: - for(auto&& sub : node) { - std::string key = sub.first.as(); - processPaths(sub.second, TransformPath, PATHS.count(key) > 0); - } - break; - default: - // it is OK - break; - } - } -} - -// helper to replace environment-variable expressions of the form ${VARNAME} in -// a string -static std::string interpolateEnvVars(std::string str) { -// temporary workaround for MS-internal PhillyOnAzure cluster: warm storage -// presently has the form /hdfs/VC instead of /{gfs,hdfs}/CLUSTER/VC -// @TODO: remove this workaround -#if 1 - if(getenv("PHILLY_JOB_ID")) { - const char* cluster = getenv("PHILLY_CLUSTER"); - const char* vc = getenv("PHILLY_VC"); - // this environment variable exists when running on the cluster - if(cluster && 
vc) { - static const std::string s_gfsPrefix - = std::string("/gfs/") + cluster + "/" + vc + "/"; - static const std::string s_hdfsPrefix - = std::string("/hdfs/") + cluster + "/" + vc + "/"; - if(str.find(s_gfsPrefix) == 0) - str = std::string("/hdfs/") + vc + "/" + str.substr(s_gfsPrefix.size()); - else if(str.find(s_hdfsPrefix) == 0) - str = std::string("/hdfs/") + vc + "/" - + str.substr(s_hdfsPrefix.size()); - } - } -#endif - for(;;) { - const auto pos = str.find("${"); - if(pos == std::string::npos) - return str; - const auto epos = str.find("}", pos + 2); - ABORT_IF(epos == std::string::npos, - "interpolate-env-vars option: ${{ without matching }} in '{}'", - str.c_str()); - // isolate the variable name - const auto var = str.substr(pos + 2, epos - (pos + 2)); - const auto val = getenv(var.c_str()); - ABORT_IF(!val, - "interpolate-env-vars option: environment variable '{}' not " - "defined in '{}'", - var.c_str(), - str.c_str()); - // replace it; then try again for further replacements - str = str.substr(0, pos) + val + str.substr(epos + 1); - } -} bool ConfigParser::has(const std::string& key) const { return config_[key]; @@ -353,7 +242,7 @@ void ConfigParser::addOptionsModel(po::options_description& desc) { ("ignore-model-config", po::value()->zero_tokens()->default_value(false), "Ignore the model configuration saved in npz file") ("type", po::value()->default_value("amun"), - "Model type (possible values: amun, nematus, s2s, multi-s2s, transformer)") + "Model type: amun, nematus, s2s, multi-s2s, transformer") ("dim-vocabs", po::value>() ->multitoken() ->default_value(std::vector({0, 0}), "0 0"), @@ -541,7 +430,7 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) { ("maxi-batch", po::value()->default_value(100), "Number of batches to preload for length-based sorting") ("maxi-batch-sort", po::value()->default_value("trg"), - "Sorting strategy for maxi-batch: trg (default) src none") + "Sorting strategy for maxi-batch: trg, src, none") ("optimizer,o", po::value()->default_value("adam"), "Optimization algorithm (possible values: sgd, adagrad, adam") ("optimizer-params", po::value>() @@ -554,8 +443,8 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) { ("lr-decay", po::value()->default_value(0.0), "Decay factor for learning rate: lr = lr * arg (0 to disable)") ("lr-decay-strategy", po::value()->default_value("epoch+stalled"), - "Strategy for learning rate decaying " - "(possible values: epoch, batches, stalled, epoch+batches, epoch+stalled)") + "Strategy for learning rate decaying: epoch, batches, stalled, " + "epoch+batches, epoch+stalled") ("lr-decay-start", po::value>() ->multitoken() ->default_value(std::vector({10,1}), "10 1"), @@ -603,14 +492,14 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) { ("guided-alignment", po::value(), "Use guided alignment to guide attention") ("guided-alignment-cost", po::value()->default_value("ce"), - "Cost type for guided alignment. Possible values: ce (cross-entropy), " - "mse (mean square error), mult (multiplication)") + "Cost type for guided alignment: ce (cross-entropy), mse (mean square " + "error), mult (multiplication)") ("guided-alignment-weight", po::value()->default_value(1), "Weight for guided alignment cost") ("data-weighting", po::value(), "File with sentence or word weights") ("data-weighting-type", po::value()->default_value("sentence"), - "Processing level for data weighting. 
Possible values: sentence, word") + "Processing level for data weighting: sentence, word") //("drop-rate", po::value()->default_value(0), // "Gradient drop ratio (read: https://arxiv.org/abs/1704.05021)") @@ -743,15 +632,15 @@ void ConfigParser::addOptionsTranslate(po::options_description& desc) { ("maxi-batch", po::value()->default_value(1), "Number of batches to preload for length-based sorting") ("maxi-batch-sort", po::value()->default_value("none"), - "Sorting strategy for maxi-batch: none (default) src") + "Sorting strategy for maxi-batch: none, src") ("n-best", po::value()->zero_tokens()->default_value(false), "Display n-best list") ("shortlist", po::value>()->multitoken(), "Use softmax shortlist: path first best prune") ("weights", po::value>()->multitoken(), "Scorer weights") - ("alignment", po::value()->default_value(0.f)->implicit_value(1.f), - "Return word alignments") + ("alignment", po::value()->implicit_value("1"), + "Return word alignment. Possible values: 0.0-1.0, hard, soft") // TODO: the options should be available only in server ("port,p", po::value()->default_value(8080), "Port number for web socket server") @@ -805,9 +694,9 @@ void ConfigParser::addOptionsRescore(po::options_description& desc) { ("maxi-batch", po::value()->default_value(100), "Number of batches to preload for length-based sorting") ("maxi-batch-sort", po::value()->default_value("trg"), - "Sorting strategy for maxi-batch: trg (default) src none") - ("alignment", po::value()->default_value(0.f)->implicit_value(1.f), - "Return word alignments") + "Sorting strategy for maxi-batch: trg (default), src, none") + ("alignment", po::value()->implicit_value("1"), + "Return word alignments. Possible values: 0.0-1.0, hard, soft") ; // clang-format on desc.add(rescore); @@ -856,17 +745,17 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) { exit(0); } - const auto& interpolateEnvVarsIfRequested + const auto& InterpolateEnvVarsIfRequested = [&](std::string str) -> std::string { if(vm_["interpolate-env-vars"].as()) - str = interpolateEnvVars(str); + str = cli::InterpolateEnvVars(str); return str; }; bool loadConfig = vm_.count("config"); bool reloadConfig = (mode_ == ConfigMode::training) - && boost::filesystem::exists(interpolateEnvVarsIfRequested( + && boost::filesystem::exists(InterpolateEnvVarsIfRequested( vm_["model"].as() + ".yml")) && !vm_["no-reload"].as(); std::vector configPaths; @@ -875,14 +764,14 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) { configPaths = vm_["config"].as>(); config_ = YAML::Node(); for(auto& configPath : configPaths) { - configPath = interpolateEnvVarsIfRequested( + configPath = InterpolateEnvVarsIfRequested( configPath); // (note: this updates the configPaths array) for(const auto& it : YAML::Load(InputFileStream(configPath))) // later file overrides config_[it.first.as()] = it.second; } } else if(reloadConfig) { - auto configPath = interpolateEnvVarsIfRequested( + auto configPath = InterpolateEnvVarsIfRequested( vm_["model"].as() + ".yml"); config_ = YAML::Load(InputFileStream(configPath)); configPaths = {configPath}; @@ -1042,7 +931,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) { SET_OPTION("n-best-feature", std::string); SET_OPTION_NONDEFAULT("summary", std::string); SET_OPTION("optimize", bool); - SET_OPTION("alignment", float); + SET_OPTION_NONDEFAULT("alignment", std::string); } if(mode_ == ConfigMode::translating) { @@ -1055,7 +944,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool 
doValidate) { SET_OPTION("mini-batch-words", int); SET_OPTION_NONDEFAULT("weights", std::vector); SET_OPTION_NONDEFAULT("shortlist", std::vector); - SET_OPTION("alignment", float); + SET_OPTION_NONDEFAULT("alignment", std::string); SET_OPTION("port", size_t); SET_OPTION("optimize", bool); SET_OPTION("max-length-factor", float); @@ -1120,7 +1009,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) { } if(get("interpolate-env-vars")) { - processPaths(config_, interpolateEnvVars); + cli::ProcessPaths(config_, cli::InterpolateEnvVars, PATHS); } if(get("relative-paths") && !vm_["dump-config"].as()) { @@ -1134,20 +1023,22 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) { ABORT_IF(boost::filesystem::path{configPath}.parent_path() != configDir, "relative-paths option requires all config files to be in the " "same directory"); - processPaths(config_, [&](const std::string& nodePath) -> std::string { + + auto transformFunc = [&](const std::string& nodePath) -> std::string { // replace relative path w.r.t. configDir using namespace boost::filesystem; try { return canonical(path{nodePath}, configDir).string(); - } catch( - boost::filesystem::filesystem_error& - e) { // will fail if file does not exist; use parent in that case + } catch(boost::filesystem::filesystem_error& e) { + // will fail if file does not exist; use parent in that case std::cerr << e.what() << std::endl; auto parentPath = path{nodePath}.parent_path(); return (canonical(parentPath, configDir) / path{nodePath}.filename()) .string(); } - }); + }; + + cli::ProcessPaths(config_, transformFunc, PATHS); } if(doValidate) { @@ -1165,18 +1056,10 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) { if(vm_["dump-config"].as()) { YAML::Emitter emit; - OutputYaml(config_, emit); + cli::OutputYaml(config_, emit); std::cout << emit.c_str() << std::endl; exit(0); } - - // @TODO: this should probably be in processOptionDevices() - //#ifdef BLAS_FOUND - // //omp_set_num_threads(vm_["omp-threads"].as()); - //#ifdef MKL_FOUND - // mkl_set_num_threads(vm_["omp-threads"].as()); - //#endif - //#endif } std::vector ConfigParser::getDevices() { diff --git a/src/common/config_parser.h b/src/common/config_parser.h index 0d6c2579..a4020a27 100644 --- a/src/common/config_parser.h +++ b/src/common/config_parser.h @@ -1,12 +1,15 @@ #pragma once #include + #include "3rd_party/yaml-cpp/yaml.h" #include "common/definitions.h" #include "common/file_stream.h" #include "common/logging.h" -#ifndef _WIN32 // TODO: why are these needed by a config parser? Can they be - // removed for Linux as well? + +// TODO: why are these needed by a config parser? Can they be removed for Linux +// as well? +#ifndef _WIN32 #include #include #endif @@ -22,8 +25,6 @@ enum struct ConfigMode { // try to determine the width of the terminal uint16_t guess_terminal_width(uint16_t max_width = 180); -void OutputYaml(const YAML::Node node, YAML::Emitter& out); - class ConfigParser { public: ConfigParser(int argc, char** argv, ConfigMode mode, bool validate = false) diff --git a/src/common/io.h b/src/common/io.h index d0b9b0c8..8f2578d3 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -14,7 +14,6 @@ // CPU decoding. 
namespace marian { - namespace io { bool isNpz(const std::string& fileName); diff --git a/src/common/logging.h b/src/common/logging.h index ea7c3aff..bf6a31ca 100644 --- a/src/common/logging.h +++ b/src/common/logging.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "spdlog/spdlog.h" /** diff --git a/src/data/alignment.cpp b/src/data/alignment.cpp index 57b0ce52..708de91d 100644 --- a/src/data/alignment.cpp +++ b/src/data/alignment.cpp @@ -33,19 +33,15 @@ std::string WordAlignment::toString() const { } WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, - float threshold /*= 1.f*/, - bool reversed /*= true*/, - bool skipEOS /*= false*/) { - size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0; + float threshold /*= 1.f*/) { WordAlignment align; // Alignments by maximum value if(threshold == 1.f) { - for(size_t t = 0; t < alignSoft.size() - shift; ++t) { + for(size_t t = 0; t < alignSoft.size(); ++t) { // Retrieved alignments are in reversed order - size_t rev = reversed ? alignSoft.size() - t - 1 : t; size_t maxArg = 0; for(size_t s = 0; s < alignSoft[0].size(); ++s) { - if(alignSoft[rev][maxArg] < alignSoft[rev][s]) { + if(alignSoft[t][maxArg] < alignSoft[t][s]) { maxArg = s; } } @@ -53,11 +49,10 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, } } else { // Alignments by greather-than-threshold - for(size_t t = 0; t < alignSoft.size() - shift; ++t) { + for(size_t t = 0; t < alignSoft.size(); ++t) { // Retrieved alignments are in reversed order - size_t rev = reversed ? alignSoft.size() - t - 1 : t; for(size_t s = 0; s < alignSoft[0].size(); ++s) { - if(alignSoft[rev][s] > threshold) { + if(alignSoft[t][s] > threshold) { align.push_back(s, t); } } @@ -70,5 +65,21 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, return align; } +std::string SoftAlignToString(SoftAlignment align) { + std::stringstream str; + bool first = true; + for(size_t t = 0; t < align.size(); ++t) { + if(!first) + str << " "; + for(size_t s = 0; s < align[t].size(); ++s) { + if(s != 0) + str << ","; + str << align[t][s]; + } + first = false; + } + return str.str(); +} + } // namespace data } // namespace marian diff --git a/src/data/alignment.h b/src/data/alignment.h index 375429fc..2ec99ff3 100644 --- a/src/data/alignment.h +++ b/src/data/alignment.h @@ -50,9 +50,9 @@ public: typedef std::vector> SoftAlignment; WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, - float threshold = 1.f, - bool reversed = true, - bool skipEOS = false); + float threshold = 1.f); + +std::string SoftAlignToString(SoftAlignment align); } // namespace data } // namespace marian diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index 85bf5b78..a6130a67 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -13,7 +13,6 @@ #include "training/training_state.h" namespace marian { - namespace data { template diff --git a/src/data/batch_stats.h b/src/data/batch_stats.h index ad840648..993c3d11 100644 --- a/src/data/batch_stats.h +++ b/src/data/batch_stats.h @@ -8,7 +8,6 @@ #include "data/vocab.h" namespace marian { - namespace data { class BatchStats { diff --git a/src/graph/chainable.h b/src/graph/chainable.h index da380818..39e9d6b4 100644 --- a/src/graph/chainable.h +++ b/src/graph/chainable.h @@ -7,7 +7,6 @@ #include "3rd_party/exception.h" #include "common/definitions.h" -// Parent namespace for the Marian project namespace marian { #define NodeOp(op) [=]() { op; } diff --git a/src/graph/expression_graph.h 
b/src/graph/expression_graph.h index 99a7588e..a2952d9e 100644 --- a/src/graph/expression_graph.h +++ b/src/graph/expression_graph.h @@ -37,7 +37,7 @@ public: shortterm_(New()), longterm_(New()) {} -Tensors(Ptr backend, Ptr device) + Tensors(Ptr backend, Ptr device) : tensors_(New(backend, device)), cache_(New(backend)), shortterm_(New()), @@ -156,7 +156,8 @@ public: params_->clear(); } - void setDevice(DeviceId deviceId = {0, DeviceType::gpu}, Ptr device = nullptr); + void setDevice(DeviceId deviceId = {0, DeviceType::gpu}, + Ptr device = nullptr); DeviceId getDeviceId() { return backend_->getDeviceId(); } diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp index 27c881e1..a3198421 100644 --- a/src/microsoft/quicksand.cpp +++ b/src/microsoft/quicksand.cpp @@ -97,9 +97,7 @@ public: } } - void setWorkspace(uint8_t* data, size_t size) { - device_->set(data, size); - } + void setWorkspace(uint8_t* data, size_t size) { device_->set(data, size); } QSNBestBatch decode(const QSBatch& qsBatch, size_t maxLength, diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp index f9dd0bf7..3afd1b12 100755 --- a/src/models/encoder_decoder.cpp +++ b/src/models/encoder_decoder.cpp @@ -1,4 +1,5 @@ #include "encoder_decoder.h" +#include "common/cli_helper.h" namespace marian { @@ -91,7 +92,7 @@ Config::YamlNode EncoderDecoder::getModelParameters() { std::string EncoderDecoder::getModelParametersAsString() { auto yaml = getModelParameters(); YAML::Emitter out; - OutputYaml(yaml, out); + cli::OutputYaml(yaml, out); return std::string(out.c_str()); } diff --git a/src/models/model_factory.h b/src/models/model_factory.h index 2b23ab9a..2ec7fe75 100644 --- a/src/models/model_factory.h +++ b/src/models/model_factory.h @@ -6,7 +6,6 @@ #include "models/encoder_decoder.h" namespace marian { - namespace models { class EncoderFactory : public Factory { diff --git a/src/models/transformer.h b/src/models/transformer.h index cc7c5954..e49d2ab6 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -702,7 +702,7 @@ public: rnn::States decoderStates; // apply decoder layers auto decDepth = opt("dec-depth"); - std::vector tiedLayers = opt>("transformer-tied-layers", + std::vector tiedLayers = opt>("transformer-tied-layers", std::vector()); ABORT_IF(!tiedLayers.empty() && tiedLayers.size() != decDepth, "Specified layer tying for {} layers, but decoder has {} layers", diff --git a/src/rescorer/rescorer.h b/src/rescorer/rescorer.h index 773851e6..d2d6df04 100644 --- a/src/rescorer/rescorer.h +++ b/src/rescorer/rescorer.h @@ -54,7 +54,7 @@ public: ? std::static_pointer_cast( New(options_)) : std::static_pointer_cast(New(options_))) { - ABORT_IF(options_->has("summary") && options_->get("alignment", .0f), + ABORT_IF(options_->has("summary") && options_->has("alignment"), "Alignments can not be produced with summarized score"); corpus_->prepare(); @@ -96,9 +96,9 @@ public: Ptr output = options_->get("n-best") ? std::static_pointer_cast( New(options_)) - : New(); + : New(options_); - float alignment = options_->get("alignment", .0f); + std::string alignment = options_->get("alignment", ""); bool summarize = options_->has("summary"); std::string summary = summarize ? 
options_->get("summary") : "cross-entropy"; @@ -134,7 +134,7 @@ public: // soft alignments for each sentence in the batch std::vector aligns(batch->size()); - if(alignment > .0f) { + if(!alignment.empty()) { getAlignmentsForBatch(builder->getAlignment(), batch, aligns); } diff --git a/src/rescorer/score_collector.cpp b/src/rescorer/score_collector.cpp new file mode 100644 index 00000000..f8e0d85f --- /dev/null +++ b/src/rescorer/score_collector.cpp @@ -0,0 +1,122 @@ +#include "rescorer/score_collector.h" + +#include "common/logging.h" +#include "common/utils.h" + +#include + +namespace marian { + +ScoreCollector::ScoreCollector(const Ptr& options) + : nextId_(0), + outStrm_(new OutputFileStream(std::cout)), + alignment_(options->get("alignment", "")), + alignmentThreshold_(getAlignmentThreshold(alignment_)) {} + +void ScoreCollector::Write(long id, const std::string& message) { + boost::mutex::scoped_lock lock(mutex_); + if(id == nextId_) { + ((std::ostream&)*outStrm_) << message << std::endl; + + ++nextId_; + + typename Outputs::const_iterator iter, iterNext; + iter = outputs_.begin(); + while(iter != outputs_.end()) { + long currId = iter->first; + + if(currId == nextId_) { + // 1st element in the map is the next + ((std::ostream&)*outStrm_) << iter->second << std::endl; + + ++nextId_; + + // delete current record, move iter on 1 + iterNext = iter; + ++iterNext; + outputs_.erase(iter); + iter = iterNext; + } else { + // not the next. stop iterating + assert(nextId_ < currId); + break; + } + } + + } else { + // save for later + outputs_[id] = message; + } +} + +void ScoreCollector::Write(long id, + float score, + const data::SoftAlignment& align) { + auto msg = std::to_string(score); + if(!alignment_.empty() && !align.empty()) + msg += " ||| " + getAlignment(align); + Write(id, msg); +} + +std::string ScoreCollector::getAlignment(const data::SoftAlignment& align) { + if(alignment_ == "soft") { + return data::SoftAlignToString(align); + } else if(alignment_ == "hard") { + return data::ConvertSoftAlignToHardAlign(align, 1.f).toString(); + } else if(alignmentThreshold_ > 0.f) { + return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_) + .toString(); + } else { + ABORT("Unrecognized word alignment type"); + } + return ""; +} + +ScoreCollectorNBest::ScoreCollectorNBest(const Ptr& options) + : ScoreCollector(options), + nBestList_(options->get>("train-sets").back()), + fname_(options->get("n-best-feature")) { + file_.reset(new InputFileStream(nBestList_)); +} + +void ScoreCollectorNBest::Write(long id, + float score, + const data::SoftAlignment& align) { + std::string line; + { + boost::mutex::scoped_lock lock(mutex_); + auto iter = buffer_.find(id); + if(iter == buffer_.end()) { + ABORT_IF(lastRead_ >= id, + "Entry {} < {} already read but not in buffer", + id, + lastRead_); + std::string line; + while(lastRead_ < id && utils::GetLine((std::istream&)*file_, line)) { + lastRead_++; + iter = buffer_.emplace(lastRead_, line).first; + } + } + + line = iter->second; + buffer_.erase(iter); + } + + ScoreCollector::Write(id, addToNBest(line, fname_, score, align)); +} + +std::string ScoreCollectorNBest::addToNBest(const std::string nbest, + const std::string feature, + float score, + const data::SoftAlignment& align) { + std::vector fields; + utils::Split(nbest, fields, "|||"); + std::stringstream ss; + if(!alignment_.empty() && !align.empty()) + ss << " " << getAlignment(align) << " |||"; + ss << fields[2] << feature << "= " << score << " "; + fields[2] = ss.str(); + return 
utils::Join(fields, "|||"); +} + +} // namespace marian diff --git a/src/rescorer/score_collector.h b/src/rescorer/score_collector.h index a31f0065..d0c979c2 100644 --- a/src/rescorer/score_collector.h +++ b/src/rescorer/score_collector.h @@ -1,69 +1,23 @@ #pragma once #include -#include -#include #include +#include "common/config.h" #include "common/definitions.h" #include "common/file_stream.h" -#include "common/logging.h" -#include "common/utils.h" #include "data/alignment.h" namespace marian { class ScoreCollector { public: - ScoreCollector() : nextId_(0), outStrm_(new OutputFileStream(std::cout)){}; - - virtual void Write(long id, const std::string& message) { - boost::mutex::scoped_lock lock(mutex_); - if(id == nextId_) { - ((std::ostream&)*outStrm_) << message << std::endl; - - ++nextId_; - - typename Outputs::const_iterator iter, iterNext; - iter = outputs_.begin(); - while(iter != outputs_.end()) { - long currId = iter->first; - - if(currId == nextId_) { - // 1st element in the map is the next - ((std::ostream&)*outStrm_) << iter->second << std::endl; - - ++nextId_; - - // delete current record, move iter on 1 - iterNext = iter; - ++iterNext; - outputs_.erase(iter); - iter = iterNext; - } else { - // not the next. stop iterating - assert(nextId_ < currId); - break; - } - } - - } else { - // save for later - outputs_[id] = message; - } - } + ScoreCollector(const Ptr& options); + virtual void Write(long id, const std::string& message); virtual void Write(long id, float score, - const data::SoftAlignment& align = {}) { - auto msg = std::to_string(score); - if(!align.empty()) { - auto wordAlign - = data::ConvertSoftAlignToHardAlign(align, 1.f, false, true); - msg += " ||| " + wordAlign.toString(); - } - Write(id, msg); - } + const data::SoftAlignment& align = {}); protected: long nextId_{0}; @@ -72,69 +26,42 @@ protected: typedef std::map Outputs; Outputs outputs_; + + std::string alignment_; + float alignmentThreshold_{0.f}; + + std::string getAlignment(const data::SoftAlignment& align); + + float getAlignmentThreshold(const std::string& str) { + try { + return std::max(std::stof(str), 0.f); + } catch(...) 
{ + return 0.f; + } + } }; class ScoreCollectorNBest : public ScoreCollector { -private: - Ptr options_; +public: + ScoreCollectorNBest() = delete; + ScoreCollectorNBest(const Ptr& options); + ScoreCollectorNBest(const ScoreCollectorNBest&) = delete; + + virtual void Write(long id, + float score, + const data::SoftAlignment& align = {}); + +private: std::string nBestList_; std::string fname_; long lastRead_{-1}; UPtr file_; std::map buffer_; -public: - ScoreCollectorNBest() = delete; - - ScoreCollectorNBest(const Ptr& options) : options_(options) { - auto paths = options_->get>("train-sets"); - nBestList_ = paths.back(); - fname_ = options_->get("n-best-feature"); - file_.reset(new InputFileStream(nBestList_)); - } - - ScoreCollectorNBest(const ScoreCollectorNBest&) = delete; - std::string addToNBest(const std::string nbest, const std::string feature, float score, - const data::SoftAlignment& align = {}) { - std::vector fields; - utils::Split(nbest, fields, "|||"); - std::stringstream ss; - if(!align.empty()) { - auto wordAlign - = data::ConvertSoftAlignToHardAlign(align, 1.f, false, true); - ss << " " << wordAlign.toString() << " |||"; - } - ss << fields[2] << feature << "= " << score << " "; - fields[2] = ss.str(); - return utils::Join(fields, "|||"); - } - - virtual void Write(long id, float score, const data::SoftAlignment& align) { - std::string line; - { - boost::mutex::scoped_lock lock(mutex_); - auto iter = buffer_.find(id); - if(iter == buffer_.end()) { - ABORT_IF(lastRead_ >= id, - "Entry {} < {} already read but not in buffer", - id, - lastRead_); - std::string line; - while(lastRead_ < id && utils::GetLine((std::istream&)*file_, line)) { - lastRead_++; - iter = buffer_.emplace(lastRead_, line).first; - } - } - - line = iter->second; - buffer_.erase(iter); - } - - ScoreCollector::Write(id, addToNBest(line, fname_, score, align)); - } + const data::SoftAlignment& align = {}); }; } // namespace marian diff --git a/src/rnn/attention.h b/src/rnn/attention.h index a744cc07..74c40172 100644 --- a/src/rnn/attention.h +++ b/src/rnn/attention.h @@ -5,7 +5,6 @@ #include "rnn/types.h" namespace marian { - namespace rnn { Expr attOps(Expr va, Expr context, Expr state); diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h index 17d9d5cf..068905ea 100644 --- a/src/tensors/allocator.h +++ b/src/tensors/allocator.h @@ -173,10 +173,7 @@ public: size_t bytes, size_t step, size_t alignment = 256) - : device_(device), - available_(0), - step_(step), - alignment_(alignment) { + : device_(device), available_(0), step_(step), alignment_(alignment) { reserve(bytes); } diff --git a/src/tensors/device.h b/src/tensors/device.h index 9f03c7cf..297153a9 100644 --- a/src/tensors/device.h +++ b/src/tensors/device.h @@ -60,7 +60,7 @@ public: class WrappedDevice : public marian::Device { public: WrappedDevice(DeviceId deviceId, size_t alignment = 256) - : marian::Device(deviceId, alignment) {} + : marian::Device(deviceId, alignment) {} ~WrappedDevice() {} void set(uint8_t* data, size_t size) { @@ -70,9 +70,11 @@ public: // doesn't allocate anything, just checks size. 
void reserve(size_t size) { - ABORT_IF(size > size_, "Requested size {} is larger than pre-allocated size {}", size, size_); + ABORT_IF(size > size_, + "Requested size {} is larger than pre-allocated size {}", + size, + size_); } - }; } // namespace cpu diff --git a/src/tensors/dispatch.h b/src/tensors/dispatch.h index f0ef22f4..094f156c 100644 --- a/src/tensors/dispatch.h +++ b/src/tensors/dispatch.h @@ -2,46 +2,46 @@ #ifdef CUDA_FOUND -#define DISPATCH1(Function, Arg1) \ - namespace gpu { \ - void Function(Arg1); \ - } \ - namespace cpu { \ - void Function(Arg1); \ - } \ - void Function(Arg1 arg1) { \ +#define DISPATCH1(Function, Arg1) \ + namespace gpu { \ + void Function(Arg1); \ + } \ + namespace cpu { \ + void Function(Arg1); \ + } \ + void Function(Arg1 arg1) { \ if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ - gpu::Function(arg1); \ - else \ - cpu::Function(arg1); \ + gpu::Function(arg1); \ + else \ + cpu::Function(arg1); \ } -#define DISPATCH2(Function, Arg1, Arg2) \ - namespace gpu { \ - void Function(Arg1, Arg2); \ - } \ - namespace cpu { \ - void Function(Arg1, Arg2); \ - } \ - static inline void Function(Arg1 arg1, Arg2 arg2) { \ +#define DISPATCH2(Function, Arg1, Arg2) \ + namespace gpu { \ + void Function(Arg1, Arg2); \ + } \ + namespace cpu { \ + void Function(Arg1, Arg2); \ + } \ + static inline void Function(Arg1 arg1, Arg2 arg2) { \ if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ - gpu::Function(arg1, arg2); \ - else \ - cpu::Function(arg1, arg2); \ + gpu::Function(arg1, arg2); \ + else \ + cpu::Function(arg1, arg2); \ } -#define DISPATCH3(Function, Arg1, Arg2, Arg3) \ - namespace gpu { \ - void Function(Arg1, Arg2, Arg3); \ - } \ - namespace cpu { \ - void Function(Arg1, Arg2, Arg3); \ - } \ - static inline void Function(Arg1 arg1, Arg2 arg2, Arg3 arg3) { \ - if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ - gpu::Function(arg1, arg2, arg3); \ - else \ - cpu::Function(arg1, arg2, arg3); \ +#define DISPATCH3(Function, Arg1, Arg2, Arg3) \ + namespace gpu { \ + void Function(Arg1, Arg2, Arg3); \ + } \ + namespace cpu { \ + void Function(Arg1, Arg2, Arg3); \ + } \ + static inline void Function(Arg1 arg1, Arg2 arg2, Arg3 arg3) { \ + if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ + gpu::Function(arg1, arg2, arg3); \ + else \ + cpu::Function(arg1, arg2, arg3); \ } #define DISPATCH4(Function, Arg1, Arg2, Arg3, Arg4) \ @@ -52,25 +52,25 @@ void Function(Arg1, Arg2, Arg3, Arg4); \ } \ static inline void Function(Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4) { \ - if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ + if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ gpu::Function(arg1, arg2, arg3, arg4); \ else \ cpu::Function(arg1, arg2, arg3, arg4); \ } -#define DISPATCH5(Function, Arg1, Arg2, Arg3, Arg4, Arg5) \ - namespace gpu { \ - void Function(Arg1, Arg2, Arg3, Arg4, Arg5); \ - } \ - namespace cpu { \ - void Function(Arg1, Arg2, Arg3, Arg4, Arg5); \ - } \ - static inline void Function( \ - Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, Arg5 arg5) { \ +#define DISPATCH5(Function, Arg1, Arg2, Arg3, Arg4, Arg5) \ + namespace gpu { \ + void Function(Arg1, Arg2, Arg3, Arg4, Arg5); \ + } \ + namespace cpu { \ + void Function(Arg1, Arg2, Arg3, Arg4, Arg5); \ + } \ + static inline void Function( \ + Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, Arg5 arg5) { \ if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ - gpu::Function(arg1, arg2, arg3, arg4, arg5); \ - else \ - cpu::Function(arg1, 
arg2, arg3, arg4, arg5); \ + gpu::Function(arg1, arg2, arg3, arg4, arg5); \ + else \ + cpu::Function(arg1, arg2, arg3, arg4, arg5); \ } #define DISPATCH6(Function, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6) \ @@ -82,7 +82,7 @@ } \ static inline void Function( \ Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, Arg5 arg5, Arg6 arg6) { \ - if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ + if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ gpu::Function(arg1, arg2, arg3, arg4, arg5, arg6); \ else \ cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6); \ @@ -102,7 +102,7 @@ Arg5 arg5, \ Arg6 arg6, \ Arg7 arg7) { \ - if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ + if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ gpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7); \ else \ cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7); \ @@ -123,7 +123,7 @@ Arg6 arg6, \ Arg7 arg7, \ Arg8 arg8) { \ - if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ + if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ gpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8); \ else \ cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8); \ @@ -146,7 +146,7 @@ Arg7 arg7, \ Arg8 arg8, \ Arg9 arg9) { \ - if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ + if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \ gpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \ else \ cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \ diff --git a/src/training/gradient_dropping/sparse_tensor.h b/src/training/gradient_dropping/sparse_tensor.h index ac1726b4..462d9b88 100644 --- a/src/training/gradient_dropping/sparse_tensor.h +++ b/src/training/gradient_dropping/sparse_tensor.h @@ -177,8 +177,8 @@ public: } #ifdef CUDA_FOUND else { - std::vector outputs - = gpu::lower_bounds(indices(), values, size(), backend_->getDeviceId()); + std::vector outputs = gpu::lower_bounds( + indices(), values, size(), backend_->getDeviceId()); startOffset = outputs[0]; endOffset = outputs[1]; diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h index dc5ababd..c532c5f4 100755 --- a/src/translator/beam_search.h +++ b/src/translator/beam_search.h @@ -41,10 +41,10 @@ public: Ptr batch) { Beams newBeams(beams.size()); - std::vector alignments; - if(options_->get("alignment", 0.f)) + std::vector align; + if(options_->has("alignment")) // Use alignments from the first scorer, even if ensemble - alignments = scorers_[0]->getAlignment(); + align = scorers_[0]->getAlignment(); for(size_t i = 0; i < keys.size(); ++i) { // Keys contains indices to vocab items in the entire beam. 
@@ -93,10 +93,9 @@ public: } // Set alignments - if(!alignments.empty()) { - auto align = getAlignmentsForHypothesis( - alignments, batch, beamSize, beamHypIdx, beamIdx); - hyp->SetAlignment(align); + if(!align.empty()) { + hyp->SetAlignment( + getAlignmentsForHypothesis(align, batch, beamHypIdx, beamIdx)); } newBeam.push_back(hyp); @@ -106,9 +105,8 @@ public: } std::vector getAlignmentsForHypothesis( - const std::vector alignments, + const std::vector alignAll, Ptr batch, - int beamSize, int beamHypIdx, int beamIdx) { // Let's B be the beam size, N be the number of batched sentences, @@ -136,7 +134,7 @@ public: size_t a = ((batchWidth * beamHypIdx) + beamIdx) + (batchSize * w); size_t m = a % batchWidth; if(batch->front()->mask()[m] != 0) - align.emplace_back(alignments[a]); + align.emplace_back(alignAll[a]); } return align; diff --git a/src/translator/output_printer.cpp b/src/translator/output_printer.cpp index 45ec61ad..9a5df364 100644 --- a/src/translator/output_printer.cpp +++ b/src/translator/output_printer.cpp @@ -2,18 +2,30 @@ namespace marian { -data::WordAlignment OutputPrinter::getAlignment(const Ptr& hyp, - float threshold) { - data::SoftAlignment aligns; - // Skip EOS - auto last = hyp->GetPrevHyp(); - // Get soft alignments for each target word +std::string OutputPrinter::getAlignment(const Ptr& hyp) { + data::SoftAlignment align; + auto last = hyp; + // get soft alignments for each target word starting from the last one while(last->GetPrevHyp().get() != nullptr) { - aligns.push_back(last->GetAlignment()); + align.push_back(last->GetAlignment()); last = last->GetPrevHyp(); } - return data::ConvertSoftAlignToHardAlign(aligns, threshold, true); + // reverse alignments + std::reverse(align.begin(), align.end()); + + if(alignment_ == "soft") { + return data::SoftAlignToString(align); + } else if(alignment_ == "hard") { + return data::ConvertSoftAlignToHardAlign(align, 1.f).toString(); + } else if(alignmentThreshold_ > 0.f) { + return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_) + .toString(); + } else { + ABORT("Unrecognized word alignment type"); + } + + return ""; } } // namespace marian diff --git a/src/translator/output_printer.h b/src/translator/output_printer.h index b1ec5647..ce4cfa7d 100644 --- a/src/translator/output_printer.h +++ b/src/translator/output_printer.h @@ -19,7 +19,8 @@ public: nbest_(options->get("n-best", false) ? 
options->get("beam-size") : 0), - alignment_(options->get("alignment", 0.f)) {} + alignment_(options->get("alignment", "")), + alignmentThreshold_(getAlignmentThreshold(alignment_)) {} template void print(Ptr history, OStream& best1, OStream& bestn) { @@ -33,12 +34,10 @@ public: std::string translation = utils::Join((*vocab_)(words), " ", reverse_); bestn << history->GetLineNum() << " ||| " << translation; - if(alignment_ > 0.f) { - bestn << " ||| " << getAlignment(hypo, alignment_).toString(); - } + if(!alignment_.empty()) + bestn << " ||| " << getAlignment(hypo); bestn << " |||"; - if(hypo->GetCostBreakdown().empty()) { bestn << " F0=" << hypo->GetCost(); } else { @@ -62,9 +61,9 @@ public: std::string translation = utils::Join((*vocab_)(words), " ", reverse_); best1 << translation; - if(alignment_ > 0.f) { + if(!alignment_.empty()) { const auto& hypo = std::get<1>(result); - best1 << " ||| " << getAlignment(hypo, alignment_).toString(); + best1 << " ||| " << getAlignment(hypo); } best1 << std::flush; } @@ -73,8 +72,17 @@ private: Ptr vocab_; bool reverse_{false}; size_t nbest_{0}; - float alignment_{0.f}; + std::string alignment_; + float alignmentThreshold_{0.f}; - data::WordAlignment getAlignment(const Ptr& hyp, float threshold); + std::string getAlignment(const Ptr& hyp); + + float getAlignmentThreshold(const std::string& str) { + try { + return std::max(std::stof(str), 0.f); + } catch(...) { + return 0.f; + } + } }; } // namespace marian
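For context, the reworked `--alignment` option in this patch takes a string instead of a float: "soft" prints the full attention matrix, "hard" prints Moses-style argmax alignments, and a numeric value in (0, 1) keeps every source position whose attention weight exceeds that threshold. Below is a minimal standalone sketch (not part of the patch) of that interpretation, mirroring the behaviour of getAlignmentThreshold(), SoftAlignToString() and ConvertSoftAlignToHardAlign() from the diff; the simplified types, helper names and the exact "s-t" pair formatting are assumptions made for illustration only.

#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-in for marian::data::SoftAlignment:
// one row of source-attention weights per target word.
using SoftAlignment = std::vector<std::vector<float>>;

// Mirrors getAlignmentThreshold() from the patch: a numeric --alignment value
// is parsed as a threshold, anything non-numeric yields 0.f.
float getAlignmentThreshold(const std::string& str) {
  try {
    return std::max(std::stof(str), 0.f);
  } catch(...) {
    return 0.f;
  }
}

// "soft": print all attention weights, comma-separated per target word.
std::string softToString(const SoftAlignment& align) {
  std::ostringstream out;
  for(size_t t = 0; t < align.size(); ++t) {
    if(t > 0)
      out << " ";
    for(size_t s = 0; s < align[t].size(); ++s)
      out << (s == 0 ? "" : ",") << align[t][s];
  }
  return out.str();
}

// threshold == 1.f: argmax per target word ("hard");
// otherwise: keep every source position above the threshold.
std::string hardToString(const SoftAlignment& align, float threshold) {
  std::ostringstream out;
  for(size_t t = 0; t < align.size(); ++t) {
    if(threshold == 1.f) {
      size_t maxArg = 0;
      for(size_t s = 0; s < align[t].size(); ++s)
        if(align[t][s] > align[t][maxArg])
          maxArg = s;
      out << (t == 0 ? "" : " ") << maxArg << "-" << t;
    } else {
      for(size_t s = 0; s < align[t].size(); ++s)
        if(align[t][s] > threshold)
          out << (out.tellp() == 0 ? "" : " ") << s << "-" << t;
    }
  }
  return out.str();
}

// Dispatch on the --alignment value, as the new getAlignment() helpers
// in ScoreCollector and OutputPrinter do.
std::string formatAlignment(const SoftAlignment& align, const std::string& mode) {
  if(mode == "soft")
    return softToString(align);
  if(mode == "hard")
    return hardToString(align, 1.f);
  float threshold = getAlignmentThreshold(mode);
  if(threshold > 0.f)
    return hardToString(align, threshold);
  return "";  // unrecognized value; the real code aborts here
}

int main() {
  SoftAlignment align = {{0.7f, 0.2f, 0.1f}, {0.1f, 0.6f, 0.3f}};
  std::cout << formatAlignment(align, "soft") << "\n";   // 0.7,0.2,0.1 0.1,0.6,0.3
  std::cout << formatAlignment(align, "hard") << "\n";   // 0-0 1-1
  std::cout << formatAlignment(align, "0.25") << "\n";   // 0-0 1-1 2-1
}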