зеркало из https://github.com/mozilla/DeepSpeech.git
Disable caching features to memory
This commit is contained in:
Родитель
271e3639a7
Коммит
e3b1b5fd42
|
@ -433,7 +433,8 @@ def train():
|
|||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
cache_path=FLAGS.feature_cache if do_cache_dataset else None,
|
||||
enable_cache=FLAGS.feature_cache and do_cache_dataset,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
|
|
|
@ -94,7 +94,7 @@ def to_sparse_tuple(sequence):
|
|||
return indices, sequence, shape
|
||||
|
||||
|
||||
def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
|
||||
def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False):
|
||||
df = read_csvs(csvs)
|
||||
df.sort_values(by='wav_filesize', inplace=True)
|
||||
|
||||
|
@ -126,7 +126,7 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
|
|||
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
|
||||
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
|
||||
|
||||
if cache_path is not None:
|
||||
if enable_cache:
|
||||
dataset = dataset.cache(cache_path)
|
||||
|
||||
dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
|
||||
|
|
Загрузка…
Ссылка в новой задаче