Improve error handling around Scorer loading

This commit is contained in:
Reuben Morais 2020-01-21 22:27:06 +01:00
Родитель 3b54f54524
Коммит efbed73d5c
2 изменённых файлов: 19 добавлений и 22 удалений

Просмотреть файл

@ -31,10 +31,8 @@ int
Scorer::init(const std::string& lm_path,
const Alphabet& alphabet)
{
alphabet_ = alphabet;
setup_char_map();
load_lm(lm_path);
return 0;
set_alphabet(alphabet);
return load_lm(lm_path);
}
int
@ -46,8 +44,7 @@ Scorer::init(const std::string& lm_path,
return err;
}
setup_char_map();
load_lm(lm_path);
return 0;
return load_lm(lm_path);
}
void
@ -72,15 +69,14 @@ void Scorer::setup_char_map()
}
}
void Scorer::load_lm(const std::string& lm_path)
int Scorer::load_lm(const std::string& lm_path)
{
// load language model
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
max_order_ = language_model_->Order();
uint64_t package_size;
{
@ -88,26 +84,25 @@ void Scorer::load_lm(const std::string& lm_path)
package_size = util::SizeFile(fd.get());
}
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
bool has_trie = package_size > trie_offset;
if (has_trie) {
// Read metadata and trie from file
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
load_trie(fin, lm_path);
if (package_size <= trie_offset) {
// File ends without a trie structure
return 1;
}
max_order_ = language_model_->Order();
// Read metadata and trie from file
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
return load_trie(fin, lm_path);
}
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
{
int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
"your trie file." << std::endl;
throw 1;
return 1;
}
int version;
@ -122,7 +117,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
}
std::cerr << std::endl;
throw 1;
return 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
@ -137,6 +132,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt));
return 0;
}
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)

Просмотреть файл

@ -12,6 +12,7 @@
#include "path_trie.h"
#include "alphabet.h"
#include "deepspeech.h"
const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
@ -85,7 +86,7 @@ public:
void fill_dictionary(const std::vector<std::string> &vocabulary);
// load language model from given path
void load_lm(const std::string &lm_path);
int load_lm(const std::string &lm_path);
// language model weight
double alpha = 0.;
@ -99,7 +100,7 @@ protected:
// necessary setup after setting alphabet
void setup_char_map();
void load_trie(std::ifstream& fin, const std::string& file_path);
int load_trie(std::ifstream& fin, const std::string& file_path);
private:
std::unique_ptr<lm::base::Model> language_model_;