From efbed73d5cd780af76e3bd9a312aec541b374ead Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 21 Jan 2020 22:27:06 +0100 Subject: [PATCH] Improve error handling around Scorer loading --- native_client/ctcdecode/scorer.cpp | 36 +++++++++++++----------------- native_client/ctcdecode/scorer.h | 5 +++-- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index d53fe917..c5ae54a2 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -31,10 +31,8 @@ int Scorer::init(const std::string& lm_path, const Alphabet& alphabet) { - alphabet_ = alphabet; - setup_char_map(); - load_lm(lm_path); - return 0; + set_alphabet(alphabet); + return load_lm(lm_path); } int @@ -46,8 +44,7 @@ Scorer::init(const std::string& lm_path, return err; } setup_char_map(); - load_lm(lm_path); - return 0; + return load_lm(lm_path); } void @@ -72,15 +69,14 @@ void Scorer::setup_char_map() } } -void Scorer::load_lm(const std::string& lm_path) +int Scorer::load_lm(const std::string& lm_path) { // load language model const char* filename = lm_path.c_str(); - VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path"); - lm::ngram::Config config; config.load_method = util::LoadMethod::LAZY; language_model_.reset(lm::ngram::LoadVirtual(filename, config)); + max_order_ = language_model_->Order(); uint64_t package_size; { @@ -88,26 +84,25 @@ void Scorer::load_lm(const std::string& lm_path) package_size = util::SizeFile(fd.get()); } uint64_t trie_offset = language_model_->GetEndOfSearchOffset(); - bool has_trie = package_size > trie_offset; - - if (has_trie) { - // Read metadata and trie from file - std::ifstream fin(lm_path, std::ios::binary); - fin.seekg(trie_offset); - load_trie(fin, lm_path); + if (package_size <= trie_offset) { + // File ends without a trie structure + return 1; } - max_order_ = language_model_->Order(); + // Read metadata and trie from file + std::ifstream fin(lm_path, std::ios::binary); + fin.seekg(trie_offset); + return load_trie(fin, lm_path); } -void Scorer::load_trie(std::ifstream& fin, const std::string& file_path) +int Scorer::load_trie(std::ifstream& fin, const std::string& file_path) { int magic; fin.read(reinterpret_cast(&magic), sizeof(magic)); if (magic != MAGIC) { std::cerr << "Error: Can't parse trie file, invalid header. Try updating " "your trie file." << std::endl; - throw 1; + return 1; } int version; @@ -122,7 +117,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path) std::cerr << "Downgrade your trie file or update your version of DeepSpeech."; } std::cerr << std::endl; - throw 1; + return 1; } fin.read(reinterpret_cast(&is_utf8_mode_), sizeof(is_utf8_mode_)); @@ -137,6 +132,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path) opt.mode = fst::FstReadOptions::MAP; opt.source = file_path; dictionary.reset(FstType::Read(fin, opt)); + return 0; } void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite) diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index b2e5c817..55f337ed 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -12,6 +12,7 @@ #include "path_trie.h" #include "alphabet.h" +#include "deepspeech.h" const double OOV_SCORE = -1000.0; const std::string START_TOKEN = ""; @@ -85,7 +86,7 @@ public: void fill_dictionary(const std::vector &vocabulary); // load language model from given path - void load_lm(const std::string &lm_path); + int load_lm(const std::string &lm_path); // language model weight double alpha = 0.; @@ -99,7 +100,7 @@ protected: // necessary setup after setting alphabet void setup_char_map(); - void load_trie(std::ifstream& fin, const std::string& file_path); + int load_trie(std::ifstream& fin, const std::string& file_path); private: std::unique_ptr language_model_;