* Redo remote I/O changes once more; this time without messing with taskcluster

* Add bin changes

* Fix merge-induced issue?

* For the interleaved case with multiple collections, unpack audio on the fly

To reproduce the previous failure

rm data/smoke_test/ldc93s1.csv
rm data/smoke_test/ldc93s1.sdb
rm -rf /tmp/ldc93s1_cache_sdb_csv
rm -rf /tmp/ckpt_sdb_csv
rm -rf /tmp/train_sdb_csv

./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 16000
python -u DeepSpeech.py --noshow_progressbar --noearly_stop --train_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --train_batch_size 1 --feature_cache /tmp/ldc93s1_cache_sdb_csv --dev_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --dev_batch_size 1 --test_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --test_batch_size 1 --n_hidden 100 --epochs 109 --max_to_keep 1 --checkpoint_dir /tmp/ckpt_sdb_csv --learning_rate 0.001 --dropout_rate 0.05 --export_dir /tmp/train_sdb_csv --scorer_path data/smoke_test/pruned_lm.scorer --audio_sample_rate 16000

* Attempt to preserve length information with a wrapper around `map()`… this gets pretty python-y

* Call the right `__next__()`

* Properly implement the rest of the map wrappers here……

* Fix trailing whitespace situation and other linter complaints

* Remove data accidentally checked in

* Fix overlay augmentations

* Wavs must be open in rb mode if we're passing in an external file pointer -- this confused me

* Lint whitespace

* Revert "Fix trailing whitespace situation and other linter complaints"

This reverts commit c3c45397a2.

* Fix linter issue but without such an aggressive diff

* Move unpack_maybe into sample_collections

* Use unpack_maybe in place of duplicate lambda

* Fix confusing comment

* Add clarifying comment for on-the-fly unpacking
This commit is contained in:
Catalin Voss 2020-12-07 04:07:34 -08:00 коммит произвёл GitHub
Родитель 18b66adf46
Коммит 6640cf2341
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
11 изменённых файлов: 249 добавлений и 55 удалений

Просмотреть файл

@ -15,8 +15,8 @@ def fail(message):
def compare_samples():
sample1 = load_sample(CLI_ARGS.sample1)
sample2 = load_sample(CLI_ARGS.sample2)
sample1 = load_sample(CLI_ARGS.sample1).unpack()
sample2 = load_sample(CLI_ARGS.sample2).unpack()
if sample1.audio_format != sample2.audio_format:
fail('Samples differ on: audio-format ({} and {})'.format(sample1.audio_format, sample2.audio_format))
if sample1.duration != sample2.duration:

Просмотреть файл

@ -35,6 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote
check_ctcdecoder_version()
@ -512,9 +513,10 @@ def train():
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
if not is_remote_path(FLAGS.save_checkpoint_dir):
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
with open_remote(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
@ -541,7 +543,7 @@ def train():
feature_cache_index = FLAGS.feature_cache + '.index'
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
log_info('Invalidating feature cache')
os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files
remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
@ -810,13 +812,13 @@ def export():
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
if isdir_remote(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
remove_remote(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
@ -829,7 +831,7 @@ def export():
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
with open_remote(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
@ -840,7 +842,7 @@ def export():
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
with open_remote(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
@ -851,7 +853,7 @@ def export():
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
with open_remote(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
@ -873,8 +875,12 @@ def export():
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
if is_remote_path(export_dir):
log_error("Cannot package remote path zip %s. Please do this manually." % export_dir)
return
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
@ -959,7 +965,7 @@ def main(_):
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
if listdir_remote(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)

Просмотреть файл

@ -8,6 +8,7 @@ import numpy as np
from .helpers import LimitingPool
from collections import namedtuple
from .io import open_remote, remove_remote, copy_remote, is_remote_path
AudioFormat = namedtuple('AudioFormat', 'rate channels width')
@ -117,15 +118,19 @@ class Sample:
self.audio_type = new_audio_type
def _change_audio_type(sample_and_audio_type):
sample, audio_type, bitrate = sample_and_audio_type
def _unpack_and_change_audio_type(sample_and_audio_type):
packed_sample, audio_type, bitrate = sample_and_audio_type
if hasattr(packed_sample, 'unpack'):
sample = packed_sample.unpack()
else:
sample = packed_sample
sample.change_audio_type(audio_type, bitrate=bitrate)
return sample
def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None):
def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None):
with LimitingPool(processes=processes, process_ahead=process_ahead) as pool:
yield from pool.imap(_change_audio_type, map(lambda s: (s, audio_type, bitrate), samples))
yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples))
def get_audio_type_from_extension(ext):
@ -168,29 +173,45 @@ class AudioFile:
self.audio_format = audio_format
self.as_path = as_path
self.open_file = None
self.open_wav = None
self.tmp_file_path = None
self.tmp_src_file_path = None
def __enter__(self):
if self.audio_path.endswith('.wav'):
self.open_file = wave.open(self.audio_path, 'r')
if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
self.open_file = open_remote(self.audio_path, 'rb')
self.open_wav = wave.open(self.open_file)
if read_audio_format_from_wav_file(self.open_wav) == self.audio_format:
if self.as_path:
self.open_wav.close()
self.open_file.close()
return self.audio_path
return self.open_file
return self.open_wav
self.open_wav.close()
self.open_file.close()
# If the format isn't right, copy the file to local tmp dir and do the conversion on disk
if is_remote_path(self.audio_path):
_, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav')
copy_remote(self.audio_path, self.tmp_src_file_path)
self.audio_path = self.tmp_file_path
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
if self.as_path:
return self.tmp_file_path
self.open_file = wave.open(self.tmp_file_path, 'r')
return self.open_file
self.open_wav = wave.open(self.tmp_file_path, 'rb')
return self.open_wav
def __exit__(self, *args):
if not self.as_path:
self.open_file.close()
self.open_wav.close()
if self.open_file:
self.open_file.close()
if self.tmp_file_path is not None:
os.remove(self.tmp_file_path)
if self.tmp_src_file_path is not None:
os.remove(self.tmp_src_file_path)
def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
@ -320,6 +341,7 @@ def read_opus(opus_file):
def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
# wav_file is already a file-pointer here
with wave.open(wav_file, 'wb') as wav_file_writer:
wav_file_writer.setframerate(audio_format.rate)
wav_file_writer.setnchannels(audio_format.channels)

Просмотреть файл

@ -8,7 +8,7 @@ import numpy as np
from multiprocessing import Queue, Process
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
from .sample_collections import samples_from_source
from .sample_collections import samples_from_source, unpack_maybe
BUFFER_SIZE = 1 * MEGABYTE
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context):
AUGMENTATION_CONTEXT = preparation_context
def _load_and_augment_sample(timed_sample, context=None):
sample, clock = timed_sample
realized_sample = unpack_maybe(sample)
return _augment_sample((realized_sample, clock), context)
def _augment_sample(timed_sample, context=None):
context = AUGMENTATION_CONTEXT if context is None else context
sample, clock = timed_sample
@ -213,12 +219,12 @@ def apply_sample_augmentations(samples,
context = AugmentationContext(audio_type, augmentations)
if process_ahead == 0:
for timed_sample in timed_samples():
yield _augment_sample(timed_sample, context=context)
yield _load_and_augment_sample(timed_sample, context=context)
else:
with LimitingPool(process_ahead=process_ahead,
initializer=_init_augmentation_worker,
initargs=(context,)) as pool:
yield from pool.imap(_augment_sample, timed_samples())
yield from pool.imap(_load_and_augment_sample, timed_samples())
finally:
for augmentation in augmentations:
augmentation.stop()
@ -256,6 +262,7 @@ class Overlay(SampleAugmentation):
self.enqueue_process.start()
def apply(self, sample, clock=0.0):
sample = unpack_maybe(sample)
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
n_layers = pick_value_from_range(self.layers, clock=clock)
audio = sample.audio
@ -265,6 +272,7 @@ class Overlay(SampleAugmentation):
while overlay_offset < len(audio):
if self.current_sample is None:
next_overlay_sample = self.queue.get()
next_overlay_sample = unpack_maybe(next_overlay_sample)
next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
self.current_sample = next_overlay_sample.audio
n_required = len(audio) - overlay_offset

Просмотреть файл

@ -19,6 +19,7 @@ import csv
import os
import sys
import unicodedata
from .io import open_remote
def main():
parser = argparse.ArgumentParser()
@ -27,14 +28,14 @@ def main():
parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true")
parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true")
args = parser.parse_args()
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
in_files = args.csv_files.split(",")
print("### Reading in the following transcript files: ###")
print("### {} ###".format(in_files))
all_text = set()
for in_file in in_files:
with open(in_file, "r") as csv_file:
with open_remote(in_file, "r") as csv_file:
reader = csv.reader(csv_file)
try:
next(reader, None) # skip the file header (i.e. "transcript")

Просмотреть файл

@ -13,7 +13,7 @@ from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .helpers import parse_file_size
from .augmentations import parse_augmentations
from .io import path_exists_remote
class ConfigSingleton:
_config = None
@ -139,7 +139,7 @@ def initialize_globals():
c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
if FLAGS.one_shot_infer:
if not os.path.exists(FLAGS.one_shot_infer):
if not path_exists_remote(FLAGS.one_shot_infer):
log_error('Path specified in --one_shot_infer is not a valid file.')
sys.exit(1)

Просмотреть файл

@ -2,6 +2,7 @@ import requests
import progressbar
from os import path, makedirs
from .io import open_remote, path_exists_remote, is_remote_path
SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed']
@ -9,17 +10,18 @@ def maybe_download(archive_name, target_dir, archive_url):
# If archive file does not exist, download it...
archive_path = path.join(target_dir, archive_name)
if not path.exists(target_dir):
if not is_remote_path(target_dir) and not path.exists(target_dir):
print('No path "%s" - creating ...' % target_dir)
makedirs(target_dir)
if not path.exists(archive_path):
if not path_exists_remote(archive_path):
print('No archive "%s" - downloading...' % archive_path)
req = requests.get(archive_url, stream=True)
total_size = int(req.headers.get('content-length', 0))
done = 0
with open(archive_path, 'wb') as f:
with open_remote(archive_path, 'wb') as f:
bar = progressbar.ProgressBar(max_value=total_size if total_size > 0 else progressbar.UnknownLength, widgets=SIMPLE_BAR)
for data in req.iter_content(1024*1024):
done += len(data)
f.write(data)

Просмотреть файл

@ -10,7 +10,7 @@ from attrdict import AttrDict
from .flags import FLAGS
from .text import levenshtein
from .io import open_remote
def pmap(fun, iterable):
pool = Pool()
@ -124,5 +124,5 @@ def save_samples_json(samples, output_path):
We set ensure_ascii=True to prevent json from escaping non-ASCII chars
in the texts.
'''
with open(output_path, 'w') as fout:
with open_remote(output_path, 'w') as fout:
json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)

Просмотреть файл

@ -78,6 +78,32 @@ class Interleaved:
return self.len
class LenMap:
"""
Wrapper around python map() output object that preserves the original collection length
by implementing __len__.
"""
def __init__(self, fn, iterable):
try:
self.length = len(iterable)
except TypeError:
self.length = None
self.mapobj = map(fn, iterable)
def __iter__(self):
self.mapobj = self.mapobj.__iter__()
return self
def __next__(self):
return self.mapobj.__next__()
def __getitem__(self, key):
return self.mapobj.__getitem__(key)
def __len__(self):
return self.length
class LimitingPool:
"""Limits unbound ahead-processing of multiprocessing.Pool's imap method
before items get consumed by the iteration caller.

Просмотреть файл

@ -0,0 +1,81 @@
"""
A set of I/O utils that allow us to open files on remote storage as if they were present locally and access
into HDFS storage using Tensorflow's C++ FileStream API.
Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets.
"""
import os
from tensorflow.io import gfile
def is_remote_path(path):
"""
Returns True iff the path is one of the remote formats that this
module supports
"""
return path.startswith('gs://') or path.startswith('hdfs://')
def path_exists_remote(path):
"""
Wrapper that allows existance check of local and remote paths like
`gs://...`
"""
if is_remote_path(path):
return gfile.exists(path)
return os.path.exists(path)
def copy_remote(src, dst, overwrite=False):
"""
Allows us to copy a file from local to remote or vice versa
"""
return gfile.copy(src, dst, overwrite)
def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None):
"""
Wrapper around open() method that can handle remote paths like `gs://...`
off Google Cloud using Tensorflow's IO helpers.
buffering, encoding, newline, closefd, and opener are ignored for remote files
This enables us to do:
with open_remote('gs://.....', mode='w+') as f:
do something with the file f, whether or not we have local access to it
"""
if is_remote_path(path):
return gfile.GFile(path, mode=mode)
return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener)
def isdir_remote(path):
"""
Wrapper to check if remote and local paths are directories
"""
if is_remote_path(path):
return gfile.isdir(path)
return os.path.isdir(path)
def listdir_remote(path):
"""
Wrapper to list paths in local dirs (alternative to using a glob, I suppose)
"""
if is_remote_path(path):
return gfile.listdir(path)
return os.listdir(path)
def glob_remote(filename):
"""
Wrapper that provides globs on local and remote paths like `gs://...`
"""
return gfile.glob(filename)
def remove_remote(filename):
"""
Wrapper that can remove local and remote files like `gs://...`
"""
# Conditional import
return gfile.remove_remote(filename)

Просмотреть файл

@ -8,7 +8,7 @@ import tarfile
from pathlib import Path
from functools import partial
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
from .audio import (
Sample,
DEFAULT_FORMAT,
@ -18,6 +18,7 @@ from .audio import (
get_audio_type_from_extension,
write_wav
)
from .io import open_remote, is_remote_path
BIG_ENDIAN = 'big'
INT_SIZE = 4
@ -59,6 +60,37 @@ class LabeledSample(Sample):
self.transcript = transcript
class PackedSample:
"""
A wrapper that we can carry around in an iterator and pass to a child process in order to
have the child process do the loading/unpacking of the sample, allowing for parallel file
I/O.
"""
def __init__(self, filename, audio_type, label):
self.filename = filename
self.audio_type = audio_type
self.label = label
def unpack(self):
with open_remote(self.filename, 'rb') as audio_file:
data = audio_file.read()
if self.label is None:
s = Sample(self.audio_type, data, sample_id=self.filename)
s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
return s
def unpack_maybe(sample):
"""
Loads the supplied sample from disk (or the network) if the audio isn't loaded in to memory already.
"""
if hasattr(sample, 'unpack'):
realized_sample = sample.unpack()
else:
realized_sample = sample
return realized_sample
def load_sample(filename, label=None):
"""
Loads audio-file as a (labeled or unlabeled) sample
@ -69,21 +101,19 @@ def load_sample(filename, label=None):
Filename of the audio-file to load as sample
label : str
Label (transcript) of the sample.
If None: return util.audio.Sample instance
Otherwise: return util.sample_collections.LabeledSample instance
If None: returned result.unpack() will return util.audio.Sample instance
Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance
Returns
-------
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
"""
ext = os.path.splitext(filename)[1].lower()
audio_type = get_audio_type_from_extension(ext)
if audio_type is None:
raise ValueError('Unknown audio type extension "{}"'.format(ext))
with open(filename, 'rb') as audio_file:
if label is None:
return Sample(audio_type, audio_file.read(), sample_id=filename)
return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename)
return PackedSample(filename, audio_type, label)
class DirectSDBWriter:
@ -119,7 +149,7 @@ class DirectSDBWriter:
raise ValueError('Audio type "{}" not supported'.format(audio_type))
self.audio_type = audio_type
self.bitrate = bitrate
self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering)
self.offsets = []
self.num_samples = 0
@ -215,7 +245,7 @@ class SDB: # pylint: disable=too-many-instance-attributes
"""
self.sdb_filename = sdb_filename
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
self.offsets = []
if self.sdb_file.read(len(MAGIC)) != MAGIC:
raise RuntimeError('No Sample Database')
@ -332,6 +362,8 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
labeled : bool or None
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
Currently only works with local files (not gs:// or hdfs://...)
"""
self.csv_filename = Path(csv_filename)
self.csv_base_dir = self.csv_filename.parent.resolve().absolute()
@ -345,7 +377,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
self.labeled = labeled
if labeled:
fieldnames.append('transcript')
self.csv_file = open(csv_filename, 'w', encoding='utf-8', newline='')
self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='')
self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
self.csv_writer.writeheader()
self.counter = 0
@ -380,7 +412,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
class TarWriter: # pylint: disable=too-many-instance-attributes
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file"""
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file."""
def __init__(self,
tar_filename,
gz=False,
@ -398,6 +430,8 @@ class TarWriter: # pylint: disable=too-many-instance-attributes
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
include : str[]
List of files to include into tar root.
Currently only works with local files (not gs:// or hdfs://...)
"""
self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
samples_dir = tarfile.TarInfo('samples')
@ -498,8 +532,7 @@ class CSV(SampleList):
If the order of the samples should be reversed
"""
rows = []
csv_dir = Path(csv_filename).parent
with open(csv_filename, 'r', encoding='utf8') as csv_file:
with open_remote(csv_filename, 'r', encoding='utf8') as csv_file:
reader = csv.DictReader(csv_file)
if 'transcript' in reader.fieldnames:
if labeled is None:
@ -508,9 +541,12 @@ class CSV(SampleList):
raise RuntimeError('No transcript data (missing CSV column)')
for row in reader:
wav_filename = Path(row['wav_filename'])
if not wav_filename.is_absolute():
wav_filename = csv_dir / wav_filename
wav_filename = str(wav_filename)
if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']):
wav_filename = Path(csv_filename).parent / wav_filename
wav_filename = str(wav_filename)
else:
# Pathlib otherwise removes a / from filenames like hdfs://
wav_filename = row['wav_filename']
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
if labeled:
rows.append((wav_filename, wav_filesize, row['transcript']))
@ -554,6 +590,11 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re
Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
keep default sample order from shortest to longest.
Note that when using distributed training, it is much faster to call this function with single pre-
sorted sample source, because this allows for parallelization of the file I/O. (If this function is
called with multiple sources, the samples have to be unpacked on a single parent process to allow
for reading their durations.)
Parameters
----------
sample_sources : list of str
@ -570,13 +611,20 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re
Returns
-------
iterable of util.sample_collections.LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len
iterable of util.sample_collections.PackedSample if a single collection is provided, wrapping
LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len
or LabeledSample / util.audio.Sample directly, if multiple collections are provided
"""
sample_sources = list(sample_sources)
if len(sample_sources) == 0:
raise ValueError('No files')
if len(sample_sources) == 1:
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
for source in sample_sources]
# If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking should
# be done lazily onn the fly so that it respects the LimitingPool logic used in the feeding code.
cols = [LenMap(
unpack_maybe, samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse))
for source in sample_sources]
return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)