Disable caching features to memory

Reuben Morais 2019-11-28 13:51:33 +01:00
Parent 271e3639a7
Commit e3b1b5fd42
2 changed files with 4 additions and 3 deletions

View file

@@ -433,7 +433,8 @@ def train():
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
-                               cache_path=FLAGS.feature_cache if do_cache_dataset else None,
+                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
+                               cache_path=FLAGS.feature_cache,
                                train_phase=True)
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
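For context on the commit title: `tf.data.Dataset.cache('')` is equivalent to `cache()` with no argument, which keeps cached elements in RAM, while a non-empty filename writes the cache to disk. Under the old call site, an unset `--feature_cache` flag fell through to `cache('')` and an implicit in-memory cache. A minimal standalone sketch of that distinction (not code from this repository):

import tensorflow as tf

ds = tf.data.Dataset.range(5)

# With an empty filename (the old default for cache_path), tf.data caches
# every element in memory -- the behavior this commit disables by default.
in_memory = ds.cache('')        # same as ds.cache()

# With a real path, the cache is written to files on disk instead.
on_disk = ds.cache('/tmp/feature_cache')   # hypothetical path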

View file

@@ -94,7 +94,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape

-def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
+def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)
@@ -126,7 +126,7 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
                               output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
                .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))

-    if cache_path is not None:
+    if enable_cache:
         dataset = dataset.cache(cache_path)

     dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
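A minimal sketch of the new gating logic, with the FLAGS values stubbed out as plain variables (the helper name and paths here are illustrative, not from the repository):

import tensorflow as tf

def apply_feature_cache(dataset, enable_cache=False, cache_path=None):
    # Mirrors the updated guard: caching only happens when explicitly enabled,
    # so an unset feature_cache flag no longer falls through to an implicit
    # in-memory cache via dataset.cache('').
    if enable_cache:
        dataset = dataset.cache(cache_path)
    return dataset

feature_cache = '/tmp/feature_cache'   # stand-in for FLAGS.feature_cache
do_cache_dataset = True                # stand-in for the train() local

ds = tf.data.Dataset.range(10)
ds = apply_feature_cache(ds,
                         enable_cache=feature_cache and do_cache_dataset,
                         cache_path=feature_cache)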