зеркало из https://github.com/mozilla/DeepSpeech.git
Improve error handling around Scorer loading
This commit is contained in:
Родитель
3b54f54524
Коммит
efbed73d5c
|
@ -31,10 +31,8 @@ int
|
|||
Scorer::init(const std::string& lm_path,
|
||||
const Alphabet& alphabet)
|
||||
{
|
||||
alphabet_ = alphabet;
|
||||
setup_char_map();
|
||||
load_lm(lm_path);
|
||||
return 0;
|
||||
set_alphabet(alphabet);
|
||||
return load_lm(lm_path);
|
||||
}
|
||||
|
||||
int
|
||||
|
@ -46,8 +44,7 @@ Scorer::init(const std::string& lm_path,
|
|||
return err;
|
||||
}
|
||||
setup_char_map();
|
||||
load_lm(lm_path);
|
||||
return 0;
|
||||
return load_lm(lm_path);
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -72,15 +69,14 @@ void Scorer::setup_char_map()
|
|||
}
|
||||
}
|
||||
|
||||
void Scorer::load_lm(const std::string& lm_path)
|
||||
int Scorer::load_lm(const std::string& lm_path)
|
||||
{
|
||||
// load language model
|
||||
const char* filename = lm_path.c_str();
|
||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
||||
|
||||
lm::ngram::Config config;
|
||||
config.load_method = util::LoadMethod::LAZY;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||
max_order_ = language_model_->Order();
|
||||
|
||||
uint64_t package_size;
|
||||
{
|
||||
|
@ -88,26 +84,25 @@ void Scorer::load_lm(const std::string& lm_path)
|
|||
package_size = util::SizeFile(fd.get());
|
||||
}
|
||||
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
|
||||
bool has_trie = package_size > trie_offset;
|
||||
|
||||
if (has_trie) {
|
||||
// Read metadata and trie from file
|
||||
std::ifstream fin(lm_path, std::ios::binary);
|
||||
fin.seekg(trie_offset);
|
||||
load_trie(fin, lm_path);
|
||||
if (package_size <= trie_offset) {
|
||||
// File ends without a trie structure
|
||||
return 1;
|
||||
}
|
||||
|
||||
max_order_ = language_model_->Order();
|
||||
// Read metadata and trie from file
|
||||
std::ifstream fin(lm_path, std::ios::binary);
|
||||
fin.seekg(trie_offset);
|
||||
return load_trie(fin, lm_path);
|
||||
}
|
||||
|
||||
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||
int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||
{
|
||||
int magic;
|
||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != MAGIC) {
|
||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||
"your trie file." << std::endl;
|
||||
throw 1;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int version;
|
||||
|
@ -122,7 +117,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
|||
std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
throw 1;
|
||||
return 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
@ -137,6 +132,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
|||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = file_path;
|
||||
dictionary.reset(FstType::Read(fin, opt));
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "path_trie.h"
|
||||
#include "alphabet.h"
|
||||
#include "deepspeech.h"
|
||||
|
||||
const double OOV_SCORE = -1000.0;
|
||||
const std::string START_TOKEN = "<s>";
|
||||
|
@ -85,7 +86,7 @@ public:
|
|||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||
|
||||
// load language model from given path
|
||||
void load_lm(const std::string &lm_path);
|
||||
int load_lm(const std::string &lm_path);
|
||||
|
||||
// language model weight
|
||||
double alpha = 0.;
|
||||
|
@ -99,7 +100,7 @@ protected:
|
|||
// necessary setup after setting alphabet
|
||||
void setup_char_map();
|
||||
|
||||
void load_trie(std::ifstream& fin, const std::string& file_path);
|
||||
int load_trie(std::ifstream& fin, const std::string& file_path);
|
||||
|
||||
private:
|
||||
std::unique_ptr<lm::base::Model> language_model_;
|
||||
|
|
Загрузка…
Ссылка в новой задаче