зеркало из https://github.com/mozilla/marian.git
Merged PR 4622: Merge with public master
This is another update to public master, changes include: * Alignment and attention matrix output for decoder and scorer * First steps at refactoring command line options to get rid of boost::options * Make stand-alone decoder decode line by line. External regressions tests all pass, there will be more refactoring once we got rid of boost::options. Related work items: #90261, #90262
This commit is contained in:
Коммит
87c98ccdb0
13
CHANGELOG.md
13
CHANGELOG.md
|
@ -8,12 +8,18 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Returning hard alignments by scorer
|
||||
- Word alignment generation in scorer
|
||||
- Attention output generation in decoder and scorer with `--alignment soft`
|
||||
|
||||
### Fixed
|
||||
- Delayed output in line-by-line translation
|
||||
|
||||
### Changed
|
||||
- Generated word alignments include alignments for target EOS tokens
|
||||
|
||||
## [1.6.0] - 2018-08-08
|
||||
|
||||
### Added
|
||||
|
||||
- Faster training (20-30%) by optimizing gradient popagation of biases
|
||||
- Returning Moses-style hard alignments during decoding single models,
|
||||
ensembles and n-best lists
|
||||
|
@ -36,7 +42,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
- Seamless training continuation with exponential smoothing
|
||||
|
||||
### Fixed
|
||||
|
||||
- A couple of bugs in "selection" (transpose, shift, cols, rows) operators
|
||||
during back-prob for a very specific case: one of the operators is the first
|
||||
operator after a branch, in that case gradient propgation might be
|
||||
|
@ -49,14 +54,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
## [1.5.0] - 2018-06-17
|
||||
|
||||
### Added
|
||||
|
||||
- Average Attention Networks for Transformer model
|
||||
- 16-bit matrix multiplication on CPU
|
||||
- Memoization for constant nodes for decoding
|
||||
- Autotuning for decoding
|
||||
|
||||
### Fixed
|
||||
|
||||
- GPU decoding optimizations, about 2x faster decoding of transformer models
|
||||
- Multi-node MPI-based training on GPUs
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
|
|||
|
||||
# Set compilation flags
|
||||
set(CMAKE_CXX_FLAGS_RELEASE " -std=c++11 -O3 -Ofast -m64 -pthread -march=native -Wl,--no-as-needed -funroll-loops -ffinite-math-only -fPIC -Wno-unused-result -Wno-deprecated -Wno-deprecated-gpu-targets")
|
||||
set(CMAKE_CXX_FLAGS_NONATIVE " -std=c++11 -O3 -Ofast -m64 -pthread -march=x86-64 -mavx -Wl,--no-as-needed -funroll-loops -ffinite-math-only -fPIC -Wno-unused-result -Wno-deprecated -Wno-deprecated-gpu-targets")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG " -std=c++11 -g -O0 -pthread -fPIC -Wno-unused-result -Wno-deprecated -Wno-deprecated-gpu-targets")
|
||||
set(CMAKE_CXX_FLAGS_ST "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg -g")
|
||||
|
|
|
@ -52,6 +52,8 @@ add_library(marian STATIC
|
|||
models/model_factory.cpp
|
||||
models/encoder_decoder.cpp
|
||||
|
||||
rescorer/score_collector.cpp
|
||||
|
||||
translator/history.cpp
|
||||
translator/output_collector.cpp
|
||||
translator/output_printer.cpp
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "3rd_party/yaml-cpp/yaml.h"
|
||||
#include "common/logging.h"
|
||||
|
||||
namespace marian {
|
||||
namespace cli {
|
||||
|
||||
// helper to replace environment-variable expressions of the form ${VARNAME} in
|
||||
// a string
|
||||
static std::string InterpolateEnvVars(std::string str) {
|
||||
// temporary workaround for MS-internal PhillyOnAzure cluster: warm storage
|
||||
// presently has the form /hdfs/VC instead of /{gfs,hdfs}/CLUSTER/VC
|
||||
#if 1
|
||||
if(getenv("PHILLY_JOB_ID")) {
|
||||
const char* cluster = getenv("PHILLY_CLUSTER");
|
||||
const char* vc = getenv("PHILLY_VC");
|
||||
// this environment variable exists when running on the cluster
|
||||
if(cluster && vc) {
|
||||
static const std::string s_gfsPrefix
|
||||
= std::string("/gfs/") + cluster + "/" + vc + "/";
|
||||
static const std::string s_hdfsPrefix
|
||||
= std::string("/hdfs/") + cluster + "/" + vc + "/";
|
||||
if(str.find(s_gfsPrefix) == 0)
|
||||
str = std::string("/hdfs/") + vc + "/" + str.substr(s_gfsPrefix.size());
|
||||
else if(str.find(s_hdfsPrefix) == 0)
|
||||
str = std::string("/hdfs/") + vc + "/"
|
||||
+ str.substr(s_hdfsPrefix.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
for(;;) {
|
||||
const auto pos = str.find("${");
|
||||
if(pos == std::string::npos)
|
||||
return str;
|
||||
const auto epos = str.find("}", pos + 2);
|
||||
ABORT_IF(epos == std::string::npos,
|
||||
"interpolate-env-vars option: ${{ without matching }} in '{}'",
|
||||
str.c_str());
|
||||
// isolate the variable name
|
||||
const auto var = str.substr(pos + 2, epos - (pos + 2));
|
||||
const auto val = getenv(var.c_str());
|
||||
ABORT_IF(!val,
|
||||
"interpolate-env-vars option: environment variable '{}' not "
|
||||
"defined in '{}'",
|
||||
var.c_str(),
|
||||
str.c_str());
|
||||
// replace it; then try again for further replacements
|
||||
str = str.substr(0, pos) + val + str.substr(epos + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// helper to implement interpolate-env-vars and relative-paths options
|
||||
static void ProcessPaths(
|
||||
YAML::Node& node,
|
||||
const std::function<std::string(std::string)>& TransformPath,
|
||||
const std::set<std::string>& PATHS,
|
||||
bool isPath = false) {
|
||||
if(isPath) {
|
||||
if(node.Type() == YAML::NodeType::Scalar) {
|
||||
std::string nodePath = node.as<std::string>();
|
||||
// transform the path
|
||||
if(!nodePath.empty())
|
||||
node = TransformPath(nodePath);
|
||||
}
|
||||
|
||||
if(node.Type() == YAML::NodeType::Sequence) {
|
||||
for(auto&& sub : node) {
|
||||
ProcessPaths(sub, TransformPath, PATHS, true);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch(node.Type()) {
|
||||
case YAML::NodeType::Sequence:
|
||||
for(auto&& sub : node) {
|
||||
ProcessPaths(sub, TransformPath, PATHS, false);
|
||||
}
|
||||
break;
|
||||
case YAML::NodeType::Map:
|
||||
for(auto&& sub : node) {
|
||||
std::string key = sub.first.as<std::string>();
|
||||
ProcessPaths(sub.second, TransformPath, PATHS, PATHS.count(key) > 0);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// it is OK
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// helper to convert a Yaml node recursively into a string
|
||||
static void OutputYaml(const YAML::Node node, YAML::Emitter& out) {
|
||||
std::set<std::string> sorter;
|
||||
switch(node.Type()) {
|
||||
case YAML::NodeType::Null: out << node; break;
|
||||
case YAML::NodeType::Scalar: out << node; break;
|
||||
case YAML::NodeType::Sequence:
|
||||
out << YAML::BeginSeq;
|
||||
for(auto&& n : node)
|
||||
OutputYaml(n, out);
|
||||
out << YAML::EndSeq;
|
||||
break;
|
||||
case YAML::NodeType::Map:
|
||||
for(auto& n : node)
|
||||
sorter.insert(n.first.as<std::string>());
|
||||
out << YAML::BeginMap;
|
||||
for(auto& key : sorter) {
|
||||
out << YAML::Key;
|
||||
out << key;
|
||||
out << YAML::Value;
|
||||
OutputYaml(node[key], out);
|
||||
}
|
||||
out << YAML::EndMap;
|
||||
break;
|
||||
case YAML::NodeType::Undefined: out << node; break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace cli
|
||||
} // namespace marian
|
|
@ -29,7 +29,7 @@ YAML::Node& Config::get() {
|
|||
|
||||
void Config::log() {
|
||||
YAML::Emitter out;
|
||||
OutputYaml(config_, out);
|
||||
cli::OutputYaml(config_, out);
|
||||
std::string configString = out.c_str();
|
||||
|
||||
std::vector<std::string> results;
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
#pragma once
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "3rd_party/yaml-cpp/yaml.h"
|
||||
#include "common/cli_helper.h"
|
||||
#include "common/config_parser.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/io.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/utils.h"
|
||||
#ifndef _WIN32 // TODO: why are these needed by a config parser? Can they be
|
||||
// removed for Linux as well?
|
||||
|
||||
// TODO: why are these needed by a config parser? Can they be removed for Linux
|
||||
// as well?
|
||||
#ifndef _WIN32
|
||||
#include <sys/ioctl.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
@ -125,7 +129,7 @@ public:
|
|||
|
||||
friend std::ostream& operator<<(std::ostream& out, const Config& config) {
|
||||
YAML::Emitter outYaml;
|
||||
OutputYaml(config.get(), outYaml);
|
||||
cli::OutputYaml(config.get(), outYaml);
|
||||
out << outYaml.c_str();
|
||||
return out;
|
||||
}
|
||||
|
|
|
@ -1,28 +1,27 @@
|
|||
#include <algorithm>
|
||||
#include <boost/algorithm/string.hpp>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include <boost/algorithm/string.hpp>
|
||||
|
||||
#if MKL_FOUND
|
||||
//#include <omp.h>
|
||||
#include <mkl.h>
|
||||
#else
|
||||
#if BLAS_FOUND
|
||||
//#include <omp.h>
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include "common/definitions.h"
|
||||
|
||||
#include "common/cli_helper.h"
|
||||
#include "common/config.h"
|
||||
#include "common/config_parser.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/version.h"
|
||||
|
||||
#include "common/regex.h"
|
||||
#include "common/version.h"
|
||||
|
||||
#define SET_OPTION(key, type) \
|
||||
do { \
|
||||
|
@ -61,34 +60,6 @@ uint16_t guess_terminal_width(uint16_t max_width) {
|
|||
return max_width ? std::min(cols, max_width) : cols;
|
||||
}
|
||||
|
||||
// helper to convert a Yaml node recursively into a string
|
||||
void OutputYaml(const YAML::Node node, YAML::Emitter& out) {
|
||||
std::set<std::string> sorter;
|
||||
switch(node.Type()) {
|
||||
case YAML::NodeType::Null: out << node; break;
|
||||
case YAML::NodeType::Scalar: out << node; break;
|
||||
case YAML::NodeType::Sequence:
|
||||
out << YAML::BeginSeq;
|
||||
for(auto&& n : node)
|
||||
OutputYaml(n, out);
|
||||
out << YAML::EndSeq;
|
||||
break;
|
||||
case YAML::NodeType::Map:
|
||||
for(auto& n : node)
|
||||
sorter.insert(n.first.as<std::string>());
|
||||
out << YAML::BeginMap;
|
||||
for(auto& key : sorter) {
|
||||
out << YAML::Key;
|
||||
out << key;
|
||||
out << YAML::Value;
|
||||
OutputYaml(node[key], out);
|
||||
}
|
||||
out << YAML::EndMap;
|
||||
break;
|
||||
case YAML::NodeType::Undefined: out << node; break;
|
||||
}
|
||||
}
|
||||
|
||||
const std::set<std::string> PATHS = {"model",
|
||||
"models",
|
||||
"train-sets",
|
||||
|
@ -100,88 +71,6 @@ const std::set<std::string> PATHS = {"model",
|
|||
"valid-translation-output",
|
||||
"log"};
|
||||
|
||||
// helper to implement interpolate-env-vars and relative-paths options
|
||||
static void processPaths(
|
||||
YAML::Node& node,
|
||||
const std::function<std::string(std::string)>& TransformPath,
|
||||
bool isPath = false) {
|
||||
if(isPath) {
|
||||
if(node.Type() == YAML::NodeType::Scalar) {
|
||||
std::string nodePath = node.as<std::string>();
|
||||
// transform the path
|
||||
if(!nodePath.empty())
|
||||
node = TransformPath(nodePath);
|
||||
}
|
||||
|
||||
if(node.Type() == YAML::NodeType::Sequence) {
|
||||
for(auto&& sub : node) {
|
||||
processPaths(sub, TransformPath, true);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch(node.Type()) {
|
||||
case YAML::NodeType::Sequence:
|
||||
for(auto&& sub : node) {
|
||||
processPaths(sub, TransformPath, false);
|
||||
}
|
||||
break;
|
||||
case YAML::NodeType::Map:
|
||||
for(auto&& sub : node) {
|
||||
std::string key = sub.first.as<std::string>();
|
||||
processPaths(sub.second, TransformPath, PATHS.count(key) > 0);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// it is OK
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// helper to replace environment-variable expressions of the form ${VARNAME} in
|
||||
// a string
|
||||
static std::string interpolateEnvVars(std::string str) {
|
||||
// temporary workaround for MS-internal PhillyOnAzure cluster: warm storage
|
||||
// presently has the form /hdfs/VC instead of /{gfs,hdfs}/CLUSTER/VC
|
||||
// @TODO: remove this workaround
|
||||
#if 1
|
||||
if(getenv("PHILLY_JOB_ID")) {
|
||||
const char* cluster = getenv("PHILLY_CLUSTER");
|
||||
const char* vc = getenv("PHILLY_VC");
|
||||
// this environment variable exists when running on the cluster
|
||||
if(cluster && vc) {
|
||||
static const std::string s_gfsPrefix
|
||||
= std::string("/gfs/") + cluster + "/" + vc + "/";
|
||||
static const std::string s_hdfsPrefix
|
||||
= std::string("/hdfs/") + cluster + "/" + vc + "/";
|
||||
if(str.find(s_gfsPrefix) == 0)
|
||||
str = std::string("/hdfs/") + vc + "/" + str.substr(s_gfsPrefix.size());
|
||||
else if(str.find(s_hdfsPrefix) == 0)
|
||||
str = std::string("/hdfs/") + vc + "/"
|
||||
+ str.substr(s_hdfsPrefix.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
for(;;) {
|
||||
const auto pos = str.find("${");
|
||||
if(pos == std::string::npos)
|
||||
return str;
|
||||
const auto epos = str.find("}", pos + 2);
|
||||
ABORT_IF(epos == std::string::npos,
|
||||
"interpolate-env-vars option: ${{ without matching }} in '{}'",
|
||||
str.c_str());
|
||||
// isolate the variable name
|
||||
const auto var = str.substr(pos + 2, epos - (pos + 2));
|
||||
const auto val = getenv(var.c_str());
|
||||
ABORT_IF(!val,
|
||||
"interpolate-env-vars option: environment variable '{}' not "
|
||||
"defined in '{}'",
|
||||
var.c_str(),
|
||||
str.c_str());
|
||||
// replace it; then try again for further replacements
|
||||
str = str.substr(0, pos) + val + str.substr(epos + 1);
|
||||
}
|
||||
}
|
||||
|
||||
bool ConfigParser::has(const std::string& key) const {
|
||||
return config_[key];
|
||||
|
@ -353,7 +242,7 @@ void ConfigParser::addOptionsModel(po::options_description& desc) {
|
|||
("ignore-model-config", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Ignore the model configuration saved in npz file")
|
||||
("type", po::value<std::string>()->default_value("amun"),
|
||||
"Model type (possible values: amun, nematus, s2s, multi-s2s, transformer)")
|
||||
"Model type: amun, nematus, s2s, multi-s2s, transformer")
|
||||
("dim-vocabs", po::value<std::vector<int>>()
|
||||
->multitoken()
|
||||
->default_value(std::vector<int>({0, 0}), "0 0"),
|
||||
|
@ -541,7 +430,7 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) {
|
|||
("maxi-batch", po::value<int>()->default_value(100),
|
||||
"Number of batches to preload for length-based sorting")
|
||||
("maxi-batch-sort", po::value<std::string>()->default_value("trg"),
|
||||
"Sorting strategy for maxi-batch: trg (default) src none")
|
||||
"Sorting strategy for maxi-batch: trg, src, none")
|
||||
("optimizer,o", po::value<std::string>()->default_value("adam"),
|
||||
"Optimization algorithm (possible values: sgd, adagrad, adam")
|
||||
("optimizer-params", po::value<std::vector<float>>()
|
||||
|
@ -554,8 +443,8 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) {
|
|||
("lr-decay", po::value<double>()->default_value(0.0),
|
||||
"Decay factor for learning rate: lr = lr * arg (0 to disable)")
|
||||
("lr-decay-strategy", po::value<std::string>()->default_value("epoch+stalled"),
|
||||
"Strategy for learning rate decaying "
|
||||
"(possible values: epoch, batches, stalled, epoch+batches, epoch+stalled)")
|
||||
"Strategy for learning rate decaying: epoch, batches, stalled, "
|
||||
"epoch+batches, epoch+stalled")
|
||||
("lr-decay-start", po::value<std::vector<size_t>>()
|
||||
->multitoken()
|
||||
->default_value(std::vector<size_t>({10,1}), "10 1"),
|
||||
|
@ -603,14 +492,14 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) {
|
|||
("guided-alignment", po::value<std::string>(),
|
||||
"Use guided alignment to guide attention")
|
||||
("guided-alignment-cost", po::value<std::string>()->default_value("ce"),
|
||||
"Cost type for guided alignment. Possible values: ce (cross-entropy), "
|
||||
"mse (mean square error), mult (multiplication)")
|
||||
"Cost type for guided alignment: ce (cross-entropy), mse (mean square "
|
||||
"error), mult (multiplication)")
|
||||
("guided-alignment-weight", po::value<double>()->default_value(1),
|
||||
"Weight for guided alignment cost")
|
||||
("data-weighting", po::value<std::string>(),
|
||||
"File with sentence or word weights")
|
||||
("data-weighting-type", po::value<std::string>()->default_value("sentence"),
|
||||
"Processing level for data weighting. Possible values: sentence, word")
|
||||
"Processing level for data weighting: sentence, word")
|
||||
|
||||
//("drop-rate", po::value<double>()->default_value(0),
|
||||
// "Gradient drop ratio (read: https://arxiv.org/abs/1704.05021)")
|
||||
|
@ -743,15 +632,15 @@ void ConfigParser::addOptionsTranslate(po::options_description& desc) {
|
|||
("maxi-batch", po::value<int>()->default_value(1),
|
||||
"Number of batches to preload for length-based sorting")
|
||||
("maxi-batch-sort", po::value<std::string>()->default_value("none"),
|
||||
"Sorting strategy for maxi-batch: none (default) src")
|
||||
"Sorting strategy for maxi-batch: none, src")
|
||||
("n-best", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Display n-best list")
|
||||
("shortlist", po::value<std::vector<std::string>>()->multitoken(),
|
||||
"Use softmax shortlist: path first best prune")
|
||||
("weights", po::value<std::vector<float>>()->multitoken(),
|
||||
"Scorer weights")
|
||||
("alignment", po::value<float>()->default_value(0.f)->implicit_value(1.f),
|
||||
"Return word alignments")
|
||||
("alignment", po::value<std::string>()->implicit_value("1"),
|
||||
"Return word alignment. Possible values: 0.0-1.0, hard, soft")
|
||||
// TODO: the options should be available only in server
|
||||
("port,p", po::value<size_t>()->default_value(8080),
|
||||
"Port number for web socket server")
|
||||
|
@ -805,9 +694,9 @@ void ConfigParser::addOptionsRescore(po::options_description& desc) {
|
|||
("maxi-batch", po::value<int>()->default_value(100),
|
||||
"Number of batches to preload for length-based sorting")
|
||||
("maxi-batch-sort", po::value<std::string>()->default_value("trg"),
|
||||
"Sorting strategy for maxi-batch: trg (default) src none")
|
||||
("alignment", po::value<float>()->default_value(0.f)->implicit_value(1.f),
|
||||
"Return word alignments")
|
||||
"Sorting strategy for maxi-batch: trg (default), src, none")
|
||||
("alignment", po::value<std::string>()->implicit_value("1"),
|
||||
"Return word alignments. Possible values: 0.0-1.0, hard, soft")
|
||||
;
|
||||
// clang-format on
|
||||
desc.add(rescore);
|
||||
|
@ -856,17 +745,17 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
exit(0);
|
||||
}
|
||||
|
||||
const auto& interpolateEnvVarsIfRequested
|
||||
const auto& InterpolateEnvVarsIfRequested
|
||||
= [&](std::string str) -> std::string {
|
||||
if(vm_["interpolate-env-vars"].as<bool>())
|
||||
str = interpolateEnvVars(str);
|
||||
str = cli::InterpolateEnvVars(str);
|
||||
return str;
|
||||
};
|
||||
|
||||
bool loadConfig = vm_.count("config");
|
||||
bool reloadConfig
|
||||
= (mode_ == ConfigMode::training)
|
||||
&& boost::filesystem::exists(interpolateEnvVarsIfRequested(
|
||||
&& boost::filesystem::exists(InterpolateEnvVarsIfRequested(
|
||||
vm_["model"].as<std::string>() + ".yml"))
|
||||
&& !vm_["no-reload"].as<bool>();
|
||||
std::vector<std::string> configPaths;
|
||||
|
@ -875,14 +764,14 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
configPaths = vm_["config"].as<std::vector<std::string>>();
|
||||
config_ = YAML::Node();
|
||||
for(auto& configPath : configPaths) {
|
||||
configPath = interpolateEnvVarsIfRequested(
|
||||
configPath = InterpolateEnvVarsIfRequested(
|
||||
configPath); // (note: this updates the configPaths array)
|
||||
for(const auto& it :
|
||||
YAML::Load(InputFileStream(configPath))) // later file overrides
|
||||
config_[it.first.as<std::string>()] = it.second;
|
||||
}
|
||||
} else if(reloadConfig) {
|
||||
auto configPath = interpolateEnvVarsIfRequested(
|
||||
auto configPath = InterpolateEnvVarsIfRequested(
|
||||
vm_["model"].as<std::string>() + ".yml");
|
||||
config_ = YAML::Load(InputFileStream(configPath));
|
||||
configPaths = {configPath};
|
||||
|
@ -1042,7 +931,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
SET_OPTION("n-best-feature", std::string);
|
||||
SET_OPTION_NONDEFAULT("summary", std::string);
|
||||
SET_OPTION("optimize", bool);
|
||||
SET_OPTION("alignment", float);
|
||||
SET_OPTION_NONDEFAULT("alignment", std::string);
|
||||
}
|
||||
|
||||
if(mode_ == ConfigMode::translating) {
|
||||
|
@ -1055,7 +944,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
SET_OPTION("mini-batch-words", int);
|
||||
SET_OPTION_NONDEFAULT("weights", std::vector<float>);
|
||||
SET_OPTION_NONDEFAULT("shortlist", std::vector<std::string>);
|
||||
SET_OPTION("alignment", float);
|
||||
SET_OPTION_NONDEFAULT("alignment", std::string);
|
||||
SET_OPTION("port", size_t);
|
||||
SET_OPTION("optimize", bool);
|
||||
SET_OPTION("max-length-factor", float);
|
||||
|
@ -1120,7 +1009,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
}
|
||||
|
||||
if(get<bool>("interpolate-env-vars")) {
|
||||
processPaths(config_, interpolateEnvVars);
|
||||
cli::ProcessPaths(config_, cli::InterpolateEnvVars, PATHS);
|
||||
}
|
||||
|
||||
if(get<bool>("relative-paths") && !vm_["dump-config"].as<bool>()) {
|
||||
|
@ -1134,20 +1023,22 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
ABORT_IF(boost::filesystem::path{configPath}.parent_path() != configDir,
|
||||
"relative-paths option requires all config files to be in the "
|
||||
"same directory");
|
||||
processPaths(config_, [&](const std::string& nodePath) -> std::string {
|
||||
|
||||
auto transformFunc = [&](const std::string& nodePath) -> std::string {
|
||||
// replace relative path w.r.t. configDir
|
||||
using namespace boost::filesystem;
|
||||
try {
|
||||
return canonical(path{nodePath}, configDir).string();
|
||||
} catch(
|
||||
boost::filesystem::filesystem_error&
|
||||
e) { // will fail if file does not exist; use parent in that case
|
||||
} catch(boost::filesystem::filesystem_error& e) {
|
||||
// will fail if file does not exist; use parent in that case
|
||||
std::cerr << e.what() << std::endl;
|
||||
auto parentPath = path{nodePath}.parent_path();
|
||||
return (canonical(parentPath, configDir) / path{nodePath}.filename())
|
||||
.string();
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
cli::ProcessPaths(config_, transformFunc, PATHS);
|
||||
}
|
||||
|
||||
if(doValidate) {
|
||||
|
@ -1165,18 +1056,10 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
|
||||
if(vm_["dump-config"].as<bool>()) {
|
||||
YAML::Emitter emit;
|
||||
OutputYaml(config_, emit);
|
||||
cli::OutputYaml(config_, emit);
|
||||
std::cout << emit.c_str() << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// @TODO: this should probably be in processOptionDevices()
|
||||
//#ifdef BLAS_FOUND
|
||||
// //omp_set_num_threads(vm_["omp-threads"].as<size_t>());
|
||||
//#ifdef MKL_FOUND
|
||||
// mkl_set_num_threads(vm_["omp-threads"].as<size_t>());
|
||||
//#endif
|
||||
//#endif
|
||||
}
|
||||
|
||||
std::vector<DeviceId> ConfigParser::getDevices() {
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
#pragma once
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "3rd_party/yaml-cpp/yaml.h"
|
||||
#include "common/definitions.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/logging.h"
|
||||
#ifndef _WIN32 // TODO: why are these needed by a config parser? Can they be
|
||||
// removed for Linux as well?
|
||||
|
||||
// TODO: why are these needed by a config parser? Can they be removed for Linux
|
||||
// as well?
|
||||
#ifndef _WIN32
|
||||
#include <sys/ioctl.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
@ -22,8 +25,6 @@ enum struct ConfigMode {
|
|||
// try to determine the width of the terminal
|
||||
uint16_t guess_terminal_width(uint16_t max_width = 180);
|
||||
|
||||
void OutputYaml(const YAML::Node node, YAML::Emitter& out);
|
||||
|
||||
class ConfigParser {
|
||||
public:
|
||||
ConfigParser(int argc, char** argv, ConfigMode mode, bool validate = false)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
// CPU decoding.
|
||||
|
||||
namespace marian {
|
||||
|
||||
namespace io {
|
||||
|
||||
bool isNpz(const std::string& fileName);
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "spdlog/spdlog.h"
|
||||
|
||||
/**
|
||||
|
|
|
@ -33,19 +33,15 @@ std::string WordAlignment::toString() const {
|
|||
}
|
||||
|
||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||
float threshold /*= 1.f*/,
|
||||
bool reversed /*= true*/,
|
||||
bool skipEOS /*= false*/) {
|
||||
size_t shift = alignSoft.size() > 0 && skipEOS ? 1 : 0;
|
||||
float threshold /*= 1.f*/) {
|
||||
WordAlignment align;
|
||||
// Alignments by maximum value
|
||||
if(threshold == 1.f) {
|
||||
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
||||
for(size_t t = 0; t < alignSoft.size(); ++t) {
|
||||
// Retrieved alignments are in reversed order
|
||||
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
|
||||
size_t maxArg = 0;
|
||||
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
||||
if(alignSoft[rev][maxArg] < alignSoft[rev][s]) {
|
||||
if(alignSoft[t][maxArg] < alignSoft[t][s]) {
|
||||
maxArg = s;
|
||||
}
|
||||
}
|
||||
|
@ -53,11 +49,10 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||
}
|
||||
} else {
|
||||
// Alignments by greather-than-threshold
|
||||
for(size_t t = 0; t < alignSoft.size() - shift; ++t) {
|
||||
for(size_t t = 0; t < alignSoft.size(); ++t) {
|
||||
// Retrieved alignments are in reversed order
|
||||
size_t rev = reversed ? alignSoft.size() - t - 1 : t;
|
||||
for(size_t s = 0; s < alignSoft[0].size(); ++s) {
|
||||
if(alignSoft[rev][s] > threshold) {
|
||||
if(alignSoft[t][s] > threshold) {
|
||||
align.push_back(s, t);
|
||||
}
|
||||
}
|
||||
|
@ -70,5 +65,21 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||
return align;
|
||||
}
|
||||
|
||||
std::string SoftAlignToString(SoftAlignment align) {
|
||||
std::stringstream str;
|
||||
bool first = true;
|
||||
for(size_t t = 0; t < align.size(); ++t) {
|
||||
if(!first)
|
||||
str << " ";
|
||||
for(size_t s = 0; s < align[t].size(); ++s) {
|
||||
if(s != 0)
|
||||
str << ",";
|
||||
str << align[t][s];
|
||||
}
|
||||
first = false;
|
||||
}
|
||||
return str.str();
|
||||
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace marian
|
||||
|
|
|
@ -50,9 +50,9 @@ public:
|
|||
typedef std::vector<std::vector<float>> SoftAlignment;
|
||||
|
||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||
float threshold = 1.f,
|
||||
bool reversed = true,
|
||||
bool skipEOS = false);
|
||||
float threshold = 1.f);
|
||||
|
||||
std::string SoftAlignToString(SoftAlignment align);
|
||||
|
||||
} // namespace data
|
||||
} // namespace marian
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
#include "training/training_state.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
namespace data {
|
||||
|
||||
template <class DataSet>
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
#include "data/vocab.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
namespace data {
|
||||
|
||||
class BatchStats {
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
#include "3rd_party/exception.h"
|
||||
#include "common/definitions.h"
|
||||
|
||||
// Parent namespace for the Marian project
|
||||
namespace marian {
|
||||
|
||||
#define NodeOp(op) [=]() { op; }
|
||||
|
|
|
@ -156,7 +156,8 @@ public:
|
|||
params_->clear();
|
||||
}
|
||||
|
||||
void setDevice(DeviceId deviceId = {0, DeviceType::gpu}, Ptr<Device> device = nullptr);
|
||||
void setDevice(DeviceId deviceId = {0, DeviceType::gpu},
|
||||
Ptr<Device> device = nullptr);
|
||||
|
||||
DeviceId getDeviceId() { return backend_->getDeviceId(); }
|
||||
|
||||
|
|
|
@ -97,9 +97,7 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
void setWorkspace(uint8_t* data, size_t size) {
|
||||
device_->set(data, size);
|
||||
}
|
||||
void setWorkspace(uint8_t* data, size_t size) { device_->set(data, size); }
|
||||
|
||||
QSNBestBatch decode(const QSBatch& qsBatch,
|
||||
size_t maxLength,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "encoder_decoder.h"
|
||||
#include "common/cli_helper.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
|
@ -91,7 +92,7 @@ Config::YamlNode EncoderDecoder::getModelParameters() {
|
|||
std::string EncoderDecoder::getModelParametersAsString() {
|
||||
auto yaml = getModelParameters();
|
||||
YAML::Emitter out;
|
||||
OutputYaml(yaml, out);
|
||||
cli::OutputYaml(yaml, out);
|
||||
return std::string(out.c_str());
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
#include "models/encoder_decoder.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
namespace models {
|
||||
|
||||
class EncoderFactory : public Factory {
|
||||
|
|
|
@ -54,7 +54,7 @@ public:
|
|||
? std::static_pointer_cast<CorpusBase>(
|
||||
New<CorpusNBest>(options_))
|
||||
: std::static_pointer_cast<CorpusBase>(New<Corpus>(options_))) {
|
||||
ABORT_IF(options_->has("summary") && options_->get<float>("alignment", .0f),
|
||||
ABORT_IF(options_->has("summary") && options_->has("alignment"),
|
||||
"Alignments can not be produced with summarized score");
|
||||
|
||||
corpus_->prepare();
|
||||
|
@ -96,9 +96,9 @@ public:
|
|||
Ptr<ScoreCollector> output = options_->get<bool>("n-best")
|
||||
? std::static_pointer_cast<ScoreCollector>(
|
||||
New<ScoreCollectorNBest>(options_))
|
||||
: New<ScoreCollector>();
|
||||
: New<ScoreCollector>(options_);
|
||||
|
||||
float alignment = options_->get<float>("alignment", .0f);
|
||||
std::string alignment = options_->get<std::string>("alignment", "");
|
||||
bool summarize = options_->has("summary");
|
||||
std::string summary
|
||||
= summarize ? options_->get<std::string>("summary") : "cross-entropy";
|
||||
|
@ -134,7 +134,7 @@ public:
|
|||
|
||||
// soft alignments for each sentence in the batch
|
||||
std::vector<data::SoftAlignment> aligns(batch->size());
|
||||
if(alignment > .0f) {
|
||||
if(!alignment.empty()) {
|
||||
getAlignmentsForBatch(builder->getAlignment(), batch, aligns);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
#include "rescorer/score_collector.h"
|
||||
|
||||
#include "common/logging.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace marian {
|
||||
|
||||
ScoreCollector::ScoreCollector(const Ptr<Config>& options)
|
||||
: nextId_(0),
|
||||
outStrm_(new OutputFileStream(std::cout)),
|
||||
alignment_(options->get<std::string>("alignment", "")),
|
||||
alignmentThreshold_(getAlignmentThreshold(alignment_)) {}
|
||||
|
||||
void ScoreCollector::Write(long id, const std::string& message) {
|
||||
boost::mutex::scoped_lock lock(mutex_);
|
||||
if(id == nextId_) {
|
||||
((std::ostream&)*outStrm_) << message << std::endl;
|
||||
|
||||
++nextId_;
|
||||
|
||||
typename Outputs::const_iterator iter, iterNext;
|
||||
iter = outputs_.begin();
|
||||
while(iter != outputs_.end()) {
|
||||
long currId = iter->first;
|
||||
|
||||
if(currId == nextId_) {
|
||||
// 1st element in the map is the next
|
||||
((std::ostream&)*outStrm_) << iter->second << std::endl;
|
||||
|
||||
++nextId_;
|
||||
|
||||
// delete current record, move iter on 1
|
||||
iterNext = iter;
|
||||
++iterNext;
|
||||
outputs_.erase(iter);
|
||||
iter = iterNext;
|
||||
} else {
|
||||
// not the next. stop iterating
|
||||
assert(nextId_ < currId);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
// save for later
|
||||
outputs_[id] = message;
|
||||
}
|
||||
}
|
||||
|
||||
void ScoreCollector::Write(long id,
|
||||
float score,
|
||||
const data::SoftAlignment& align) {
|
||||
auto msg = std::to_string(score);
|
||||
if(!alignment_.empty() && !align.empty())
|
||||
msg += " ||| " + getAlignment(align);
|
||||
Write(id, msg);
|
||||
}
|
||||
|
||||
std::string ScoreCollector::getAlignment(const data::SoftAlignment& align) {
|
||||
if(alignment_ == "soft") {
|
||||
return data::SoftAlignToString(align);
|
||||
} else if(alignment_ == "hard") {
|
||||
return data::ConvertSoftAlignToHardAlign(align, 1.f).toString();
|
||||
} else if(alignmentThreshold_ > 0.f) {
|
||||
return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_)
|
||||
.toString();
|
||||
} else {
|
||||
ABORT("Unrecognized word alignment type");
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
ScoreCollectorNBest::ScoreCollectorNBest(const Ptr<Config>& options)
|
||||
: ScoreCollector(options),
|
||||
nBestList_(options->get<std::vector<std::string>>("train-sets").back()),
|
||||
fname_(options->get<std::string>("n-best-feature")) {
|
||||
file_.reset(new InputFileStream(nBestList_));
|
||||
}
|
||||
|
||||
void ScoreCollectorNBest::Write(long id,
|
||||
float score,
|
||||
const data::SoftAlignment& align) {
|
||||
std::string line;
|
||||
{
|
||||
boost::mutex::scoped_lock lock(mutex_);
|
||||
auto iter = buffer_.find(id);
|
||||
if(iter == buffer_.end()) {
|
||||
ABORT_IF(lastRead_ >= id,
|
||||
"Entry {} < {} already read but not in buffer",
|
||||
id,
|
||||
lastRead_);
|
||||
std::string line;
|
||||
while(lastRead_ < id && utils::GetLine((std::istream&)*file_, line)) {
|
||||
lastRead_++;
|
||||
iter = buffer_.emplace(lastRead_, line).first;
|
||||
}
|
||||
}
|
||||
|
||||
line = iter->second;
|
||||
buffer_.erase(iter);
|
||||
}
|
||||
|
||||
ScoreCollector::Write(id, addToNBest(line, fname_, score, align));
|
||||
}
|
||||
|
||||
std::string ScoreCollectorNBest::addToNBest(const std::string nbest,
|
||||
const std::string feature,
|
||||
float score,
|
||||
const data::SoftAlignment& align) {
|
||||
std::vector<std::string> fields;
|
||||
utils::Split(nbest, fields, "|||");
|
||||
std::stringstream ss;
|
||||
if(!alignment_.empty() && !align.empty())
|
||||
ss << " " << getAlignment(align) << " |||";
|
||||
ss << fields[2] << feature << "= " << score << " ";
|
||||
fields[2] = ss.str();
|
||||
return utils::Join(fields, "|||");
|
||||
}
|
||||
|
||||
} // namespace marian
|
|
@ -1,69 +1,23 @@
|
|||
#pragma once
|
||||
|
||||
#include <boost/thread/mutex.hpp>
|
||||
#include <boost/unordered_map.hpp>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
#include "common/config.h"
|
||||
#include "common/definitions.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/utils.h"
|
||||
#include "data/alignment.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
class ScoreCollector {
|
||||
public:
|
||||
ScoreCollector() : nextId_(0), outStrm_(new OutputFileStream(std::cout)){};
|
||||
|
||||
virtual void Write(long id, const std::string& message) {
|
||||
boost::mutex::scoped_lock lock(mutex_);
|
||||
if(id == nextId_) {
|
||||
((std::ostream&)*outStrm_) << message << std::endl;
|
||||
|
||||
++nextId_;
|
||||
|
||||
typename Outputs::const_iterator iter, iterNext;
|
||||
iter = outputs_.begin();
|
||||
while(iter != outputs_.end()) {
|
||||
long currId = iter->first;
|
||||
|
||||
if(currId == nextId_) {
|
||||
// 1st element in the map is the next
|
||||
((std::ostream&)*outStrm_) << iter->second << std::endl;
|
||||
|
||||
++nextId_;
|
||||
|
||||
// delete current record, move iter on 1
|
||||
iterNext = iter;
|
||||
++iterNext;
|
||||
outputs_.erase(iter);
|
||||
iter = iterNext;
|
||||
} else {
|
||||
// not the next. stop iterating
|
||||
assert(nextId_ < currId);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
// save for later
|
||||
outputs_[id] = message;
|
||||
}
|
||||
}
|
||||
ScoreCollector(const Ptr<Config>& options);
|
||||
|
||||
virtual void Write(long id, const std::string& message);
|
||||
virtual void Write(long id,
|
||||
float score,
|
||||
const data::SoftAlignment& align = {}) {
|
||||
auto msg = std::to_string(score);
|
||||
if(!align.empty()) {
|
||||
auto wordAlign
|
||||
= data::ConvertSoftAlignToHardAlign(align, 1.f, false, true);
|
||||
msg += " ||| " + wordAlign.toString();
|
||||
}
|
||||
Write(id, msg);
|
||||
}
|
||||
const data::SoftAlignment& align = {});
|
||||
|
||||
protected:
|
||||
long nextId_{0};
|
||||
|
@ -72,69 +26,42 @@ protected:
|
|||
|
||||
typedef std::map<long, std::string> Outputs;
|
||||
Outputs outputs_;
|
||||
|
||||
std::string alignment_;
|
||||
float alignmentThreshold_{0.f};
|
||||
|
||||
std::string getAlignment(const data::SoftAlignment& align);
|
||||
|
||||
float getAlignmentThreshold(const std::string& str) {
|
||||
try {
|
||||
return std::max(std::stof(str), 0.f);
|
||||
} catch(...) {
|
||||
return 0.f;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ScoreCollectorNBest : public ScoreCollector {
|
||||
private:
|
||||
Ptr<Config> options_;
|
||||
public:
|
||||
ScoreCollectorNBest() = delete;
|
||||
|
||||
ScoreCollectorNBest(const Ptr<Config>& options);
|
||||
ScoreCollectorNBest(const ScoreCollectorNBest&) = delete;
|
||||
|
||||
virtual void Write(long id,
|
||||
float score,
|
||||
const data::SoftAlignment& align = {});
|
||||
|
||||
private:
|
||||
std::string nBestList_;
|
||||
std::string fname_;
|
||||
long lastRead_{-1};
|
||||
UPtr<InputFileStream> file_;
|
||||
std::map<long, std::string> buffer_;
|
||||
|
||||
public:
|
||||
ScoreCollectorNBest() = delete;
|
||||
|
||||
ScoreCollectorNBest(const Ptr<Config>& options) : options_(options) {
|
||||
auto paths = options_->get<std::vector<std::string>>("train-sets");
|
||||
nBestList_ = paths.back();
|
||||
fname_ = options_->get<std::string>("n-best-feature");
|
||||
file_.reset(new InputFileStream(nBestList_));
|
||||
}
|
||||
|
||||
ScoreCollectorNBest(const ScoreCollectorNBest&) = delete;
|
||||
|
||||
std::string addToNBest(const std::string nbest,
|
||||
const std::string feature,
|
||||
float score,
|
||||
const data::SoftAlignment& align = {}) {
|
||||
std::vector<std::string> fields;
|
||||
utils::Split(nbest, fields, "|||");
|
||||
std::stringstream ss;
|
||||
if(!align.empty()) {
|
||||
auto wordAlign
|
||||
= data::ConvertSoftAlignToHardAlign(align, 1.f, false, true);
|
||||
ss << " " << wordAlign.toString() << " |||";
|
||||
}
|
||||
ss << fields[2] << feature << "= " << score << " ";
|
||||
fields[2] = ss.str();
|
||||
return utils::Join(fields, "|||");
|
||||
}
|
||||
|
||||
virtual void Write(long id, float score, const data::SoftAlignment& align) {
|
||||
std::string line;
|
||||
{
|
||||
boost::mutex::scoped_lock lock(mutex_);
|
||||
auto iter = buffer_.find(id);
|
||||
if(iter == buffer_.end()) {
|
||||
ABORT_IF(lastRead_ >= id,
|
||||
"Entry {} < {} already read but not in buffer",
|
||||
id,
|
||||
lastRead_);
|
||||
std::string line;
|
||||
while(lastRead_ < id && utils::GetLine((std::istream&)*file_, line)) {
|
||||
lastRead_++;
|
||||
iter = buffer_.emplace(lastRead_, line).first;
|
||||
}
|
||||
}
|
||||
|
||||
line = iter->second;
|
||||
buffer_.erase(iter);
|
||||
}
|
||||
|
||||
ScoreCollector::Write(id, addToNBest(line, fname_, score, align));
|
||||
}
|
||||
const data::SoftAlignment& align = {});
|
||||
};
|
||||
} // namespace marian
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
#include "rnn/types.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
namespace rnn {
|
||||
|
||||
Expr attOps(Expr va, Expr context, Expr state);
|
||||
|
|
|
@ -173,10 +173,7 @@ public:
|
|||
size_t bytes,
|
||||
size_t step,
|
||||
size_t alignment = 256)
|
||||
: device_(device),
|
||||
available_(0),
|
||||
step_(step),
|
||||
alignment_(alignment) {
|
||||
: device_(device), available_(0), step_(step), alignment_(alignment) {
|
||||
reserve(bytes);
|
||||
}
|
||||
|
||||
|
|
|
@ -70,9 +70,11 @@ public:
|
|||
|
||||
// doesn't allocate anything, just checks size.
|
||||
void reserve(size_t size) {
|
||||
ABORT_IF(size > size_, "Requested size {} is larger than pre-allocated size {}", size, size_);
|
||||
ABORT_IF(size > size_,
|
||||
"Requested size {} is larger than pre-allocated size {}",
|
||||
size,
|
||||
size_);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cpu
|
||||
|
|
|
@ -177,8 +177,8 @@ public:
|
|||
}
|
||||
#ifdef CUDA_FOUND
|
||||
else {
|
||||
std::vector<int> outputs
|
||||
= gpu::lower_bounds(indices(), values, size(), backend_->getDeviceId());
|
||||
std::vector<int> outputs = gpu::lower_bounds(
|
||||
indices(), values, size(), backend_->getDeviceId());
|
||||
|
||||
startOffset = outputs[0];
|
||||
endOffset = outputs[1];
|
||||
|
|
|
@ -41,10 +41,10 @@ public:
|
|||
Ptr<data::CorpusBatch> batch) {
|
||||
Beams newBeams(beams.size());
|
||||
|
||||
std::vector<float> alignments;
|
||||
if(options_->get<float>("alignment", 0.f))
|
||||
std::vector<float> align;
|
||||
if(options_->has("alignment"))
|
||||
// Use alignments from the first scorer, even if ensemble
|
||||
alignments = scorers_[0]->getAlignment();
|
||||
align = scorers_[0]->getAlignment();
|
||||
|
||||
for(size_t i = 0; i < keys.size(); ++i) {
|
||||
// Keys contains indices to vocab items in the entire beam.
|
||||
|
@ -93,10 +93,9 @@ public:
|
|||
}
|
||||
|
||||
// Set alignments
|
||||
if(!alignments.empty()) {
|
||||
auto align = getAlignmentsForHypothesis(
|
||||
alignments, batch, beamSize, beamHypIdx, beamIdx);
|
||||
hyp->SetAlignment(align);
|
||||
if(!align.empty()) {
|
||||
hyp->SetAlignment(
|
||||
getAlignmentsForHypothesis(align, batch, beamHypIdx, beamIdx));
|
||||
}
|
||||
|
||||
newBeam.push_back(hyp);
|
||||
|
@ -106,9 +105,8 @@ public:
|
|||
}
|
||||
|
||||
std::vector<float> getAlignmentsForHypothesis(
|
||||
const std::vector<float> alignments,
|
||||
const std::vector<float> alignAll,
|
||||
Ptr<data::CorpusBatch> batch,
|
||||
int beamSize,
|
||||
int beamHypIdx,
|
||||
int beamIdx) {
|
||||
// Let's B be the beam size, N be the number of batched sentences,
|
||||
|
@ -136,7 +134,7 @@ public:
|
|||
size_t a = ((batchWidth * beamHypIdx) + beamIdx) + (batchSize * w);
|
||||
size_t m = a % batchWidth;
|
||||
if(batch->front()->mask()[m] != 0)
|
||||
align.emplace_back(alignments[a]);
|
||||
align.emplace_back(alignAll[a]);
|
||||
}
|
||||
|
||||
return align;
|
||||
|
|
|
@ -2,18 +2,30 @@
|
|||
|
||||
namespace marian {
|
||||
|
||||
data::WordAlignment OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp,
|
||||
float threshold) {
|
||||
data::SoftAlignment aligns;
|
||||
// Skip EOS
|
||||
auto last = hyp->GetPrevHyp();
|
||||
// Get soft alignments for each target word
|
||||
std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
|
||||
data::SoftAlignment align;
|
||||
auto last = hyp;
|
||||
// get soft alignments for each target word starting from the last one
|
||||
while(last->GetPrevHyp().get() != nullptr) {
|
||||
aligns.push_back(last->GetAlignment());
|
||||
align.push_back(last->GetAlignment());
|
||||
last = last->GetPrevHyp();
|
||||
}
|
||||
|
||||
return data::ConvertSoftAlignToHardAlign(aligns, threshold, true);
|
||||
// reverse alignments
|
||||
std::reverse(align.begin(), align.end());
|
||||
|
||||
if(alignment_ == "soft") {
|
||||
return data::SoftAlignToString(align);
|
||||
} else if(alignment_ == "hard") {
|
||||
return data::ConvertSoftAlignToHardAlign(align, 1.f).toString();
|
||||
} else if(alignmentThreshold_ > 0.f) {
|
||||
return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_)
|
||||
.toString();
|
||||
} else {
|
||||
ABORT("Unrecognized word alignment type");
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
} // namespace marian
|
||||
|
|
|
@ -19,7 +19,8 @@ public:
|
|||
nbest_(options->get<bool>("n-best", false)
|
||||
? options->get<size_t>("beam-size")
|
||||
: 0),
|
||||
alignment_(options->get<float>("alignment", 0.f)) {}
|
||||
alignment_(options->get<std::string>("alignment", "")),
|
||||
alignmentThreshold_(getAlignmentThreshold(alignment_)) {}
|
||||
|
||||
template <class OStream>
|
||||
void print(Ptr<History> history, OStream& best1, OStream& bestn) {
|
||||
|
@ -33,12 +34,10 @@ public:
|
|||
std::string translation = utils::Join((*vocab_)(words), " ", reverse_);
|
||||
bestn << history->GetLineNum() << " ||| " << translation;
|
||||
|
||||
if(alignment_ > 0.f) {
|
||||
bestn << " ||| " << getAlignment(hypo, alignment_).toString();
|
||||
}
|
||||
if(!alignment_.empty())
|
||||
bestn << " ||| " << getAlignment(hypo);
|
||||
|
||||
bestn << " |||";
|
||||
|
||||
if(hypo->GetCostBreakdown().empty()) {
|
||||
bestn << " F0=" << hypo->GetCost();
|
||||
} else {
|
||||
|
@ -62,9 +61,9 @@ public:
|
|||
std::string translation = utils::Join((*vocab_)(words), " ", reverse_);
|
||||
|
||||
best1 << translation;
|
||||
if(alignment_ > 0.f) {
|
||||
if(!alignment_.empty()) {
|
||||
const auto& hypo = std::get<1>(result);
|
||||
best1 << " ||| " << getAlignment(hypo, alignment_).toString();
|
||||
best1 << " ||| " << getAlignment(hypo);
|
||||
}
|
||||
best1 << std::flush;
|
||||
}
|
||||
|
@ -73,8 +72,17 @@ private:
|
|||
Ptr<Vocab> vocab_;
|
||||
bool reverse_{false};
|
||||
size_t nbest_{0};
|
||||
float alignment_{0.f};
|
||||
std::string alignment_;
|
||||
float alignmentThreshold_{0.f};
|
||||
|
||||
data::WordAlignment getAlignment(const Ptr<Hypothesis>& hyp, float threshold);
|
||||
std::string getAlignment(const Ptr<Hypothesis>& hyp);
|
||||
|
||||
float getAlignmentThreshold(const std::string& str) {
|
||||
try {
|
||||
return std::max(std::stof(str), 0.f);
|
||||
} catch(...) {
|
||||
return 0.f;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace marian
|
||||
|
|
Загрузка…
Ссылка в новой задаче