зеркало из https://github.com/mozilla/DSAlign.git
Fixes around custom scorer generation
This commit is contained in:
Родитель
e9ae8b7903
Коммит
bcdbfee856
|
@ -1,11 +1,42 @@
|
|||
import shutil
|
||||
import sys
|
||||
|
||||
import ds_ctcdecoder
|
||||
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
|
||||
import struct
|
||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||
|
||||
|
||||
class Alphabet(object):
|
||||
def __init__(self, config_file):
|
||||
self._config_file = config_file
|
||||
self._label_to_str = {}
|
||||
self._str_to_label = {}
|
||||
self._size = 0
|
||||
if config_file:
|
||||
with open(config_file, 'r', encoding='utf-8') as fin:
|
||||
for line in fin:
|
||||
if line[0:2] == '\\#':
|
||||
line = '#\n'
|
||||
elif line[0] == '#':
|
||||
continue
|
||||
self._label_to_str[self._size] = line[:-1] # remove the line ending
|
||||
self._str_to_label[line[:-1]] = self._size
|
||||
self._size += 1
|
||||
|
||||
def serialize(self):
|
||||
# Serialization format is a sequence of (key, value) pairs, where key is
|
||||
# a uint16_t and value is a uint16_t length followed by `length` UTF-8
|
||||
# encoded bytes with the label.
|
||||
res = bytearray()
|
||||
|
||||
# We start by writing the number of pairs in the buffer as uint16_t.
|
||||
res += struct.pack('<H', self._size)
|
||||
for key, value in self._label_to_str.items():
|
||||
value = value.encode('utf-8')
|
||||
# struct.pack only takes fixed length strings/buffers, so we have to
|
||||
# construct the correct format string with the length of the encoded
|
||||
# label.
|
||||
res += struct.pack('<HH{}s'.format(len(value)), key, len(value), value)
|
||||
return bytes(res)
|
||||
|
||||
|
||||
def create_bundle(
|
||||
alphabet_path,
|
||||
lm_path,
|
||||
|
@ -16,30 +47,14 @@ def create_bundle(
|
|||
default_beta,
|
||||
):
|
||||
words = set()
|
||||
vocab_looks_char_based = True
|
||||
with open(vocab_path) as fin:
|
||||
for line in fin:
|
||||
for word in line.split():
|
||||
words.add(word.encode("utf-8"))
|
||||
if len(word) > 1:
|
||||
vocab_looks_char_based = False
|
||||
print("{} unique words read from vocabulary file.".format(len(words)))
|
||||
|
||||
cbm = "Looks" if vocab_looks_char_based else "Doesn't look"
|
||||
print("{} like a character based model.".format(cbm))
|
||||
|
||||
if force_utf8 != None: # pylint: disable=singleton-comparison
|
||||
use_utf8 = force_utf8
|
||||
else:
|
||||
use_utf8 = vocab_looks_char_based
|
||||
print("Using detected UTF-8 mode: {}".format(use_utf8))
|
||||
|
||||
if use_utf8:
|
||||
serialized_alphabet = UTF8Alphabet().serialize()
|
||||
else:
|
||||
if not alphabet_path:
|
||||
raise RuntimeError("No --alphabet path specified, can't continue.")
|
||||
serialized_alphabet = Alphabet(alphabet_path).serialize()
|
||||
if not alphabet_path:
|
||||
raise RuntimeError("No --alphabet path specified, can't continue.")
|
||||
serialized_alphabet = Alphabet(alphabet_path).serialize()
|
||||
|
||||
alphabet = NativeAlphabet()
|
||||
err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
|
||||
|
@ -48,7 +63,6 @@ def create_bundle(
|
|||
|
||||
scorer = Scorer()
|
||||
scorer.set_alphabet(alphabet)
|
||||
scorer.set_utf8_mode(use_utf8)
|
||||
scorer.reset_params(default_alpha, default_beta)
|
||||
scorer.load_lm(lm_path)
|
||||
# TODO: Why is this not working?
|
||||
|
|
Загрузка…
Ссылка в новой задаче