diff --git a/bin/compare_samples.py b/bin/compare_samples.py index 94108a7a..27898cd1 100755 --- a/bin/compare_samples.py +++ b/bin/compare_samples.py @@ -15,8 +15,8 @@ def fail(message): def compare_samples(): - sample1 = load_sample(CLI_ARGS.sample1) - sample2 = load_sample(CLI_ARGS.sample2) + sample1 = load_sample(CLI_ARGS.sample1).unpack() + sample2 = load_sample(CLI_ARGS.sample2).unpack() if sample1.audio_format != sample2.audio_format: fail('Samples differ on: audio-format ({} and {})'.format(sample1.audio_format, sample2.audio_format)) if sample1.duration != sample2.duration: diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index 8bf7a354..94ca7c04 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -35,6 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur from .util.flags import create_flags, FLAGS from .util.helpers import check_ctcdecoder_version, ExceptionBox from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn +from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote check_ctcdecoder_version() @@ -512,9 +513,10 @@ def train(): best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev') # Save flags next to checkpoints - os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) + if not is_remote_path(FLAGS.save_checkpoint_dir): + os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') - with open(flags_file, 'w') as fout: + with open_remote(flags_file, 'w') as fout: fout.write(FLAGS.flags_into_string()) with tfv1.Session(config=Config.session_config) as session: @@ -541,7 +543,7 @@ def train(): feature_cache_index = FLAGS.feature_cache + '.index' if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index): log_info('Invalidating feature cache') - os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files + remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files # Setup progress bar class LossWidget(progressbar.widgets.FormatLabel): @@ -810,13 +812,13 @@ def export(): output_filename = FLAGS.export_file_name + '.pb' if FLAGS.remove_export: - if os.path.isdir(FLAGS.export_dir): + if isdir_remote(FLAGS.export_dir): log_info('Removing old export') - shutil.rmtree(FLAGS.export_dir) + remove_remote(FLAGS.export_dir) output_graph_path = os.path.join(FLAGS.export_dir, output_filename) - if not os.path.isdir(FLAGS.export_dir): + if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir): os.makedirs(FLAGS.export_dir) frozen_graph = tfv1.graph_util.convert_variables_to_constants( @@ -829,7 +831,7 @@ def export(): dest_nodes=output_names) if not FLAGS.export_tflite: - with open(output_graph_path, 'wb') as fout: + with open_remote(output_graph_path, 'wb') as fout: fout.write(frozen_graph.SerializeToString()) else: output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) @@ -840,7 +842,7 @@ def export(): converter.allow_custom_ops = True tflite_model = converter.convert() - with open(output_tflite_path, 'wb') as fout: + with open_remote(output_tflite_path, 'wb') as fout: fout.write(tflite_model) log_info('Models exported at %s' % (FLAGS.export_dir)) @@ -851,7 +853,7 @@ def export(): FLAGS.export_model_version)) model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow' - with open(metadata_fname, 'w') as f: + with open_remote(metadata_fname, 'w') as f: f.write('---\n') f.write('author: {}\n'.format(FLAGS.export_author_id)) f.write('model_name: {}\n'.format(FLAGS.export_model_name)) @@ -873,8 +875,12 @@ def export(): def package_zip(): # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/' - zip_filename = os.path.dirname(export_dir) + if is_remote_path(export_dir): + log_error("Cannot package remote path zip %s. Please do this manually." % export_dir) + return + zip_filename = os.path.dirname(export_dir) + shutil.copy(FLAGS.scorer_path, export_dir) archive = shutil.make_archive(zip_filename, 'zip', export_dir) @@ -959,7 +965,7 @@ def main(_): tfv1.reset_default_graph() FLAGS.export_tflite = True - if os.listdir(FLAGS.export_dir): + if listdir_remote(FLAGS.export_dir): log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) sys.exit(1) diff --git a/training/deepspeech_training/util/audio.py b/training/deepspeech_training/util/audio.py index 031f13ed..79318268 100644 --- a/training/deepspeech_training/util/audio.py +++ b/training/deepspeech_training/util/audio.py @@ -8,6 +8,7 @@ import numpy as np from .helpers import LimitingPool from collections import namedtuple +from .io import open_remote, remove_remote, copy_remote, is_remote_path AudioFormat = namedtuple('AudioFormat', 'rate channels width') @@ -117,15 +118,19 @@ class Sample: self.audio_type = new_audio_type -def _change_audio_type(sample_and_audio_type): - sample, audio_type, bitrate = sample_and_audio_type +def _unpack_and_change_audio_type(sample_and_audio_type): + packed_sample, audio_type, bitrate = sample_and_audio_type + if hasattr(packed_sample, 'unpack'): + sample = packed_sample.unpack() + else: + sample = packed_sample sample.change_audio_type(audio_type, bitrate=bitrate) return sample -def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None): +def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None): with LimitingPool(processes=processes, process_ahead=process_ahead) as pool: - yield from pool.imap(_change_audio_type, map(lambda s: (s, audio_type, bitrate), samples)) + yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples)) def get_audio_type_from_extension(ext): @@ -168,29 +173,45 @@ class AudioFile: self.audio_format = audio_format self.as_path = as_path self.open_file = None + self.open_wav = None self.tmp_file_path = None + self.tmp_src_file_path = None def __enter__(self): if self.audio_path.endswith('.wav'): - self.open_file = wave.open(self.audio_path, 'r') - if read_audio_format_from_wav_file(self.open_file) == self.audio_format: + self.open_file = open_remote(self.audio_path, 'rb') + self.open_wav = wave.open(self.open_file) + if read_audio_format_from_wav_file(self.open_wav) == self.audio_format: if self.as_path: + self.open_wav.close() self.open_file.close() return self.audio_path - return self.open_file + return self.open_wav + self.open_wav.close() self.open_file.close() + + # If the format isn't right, copy the file to local tmp dir and do the conversion on disk + if is_remote_path(self.audio_path): + _, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav') + copy_remote(self.audio_path, self.tmp_src_file_path) + self.audio_path = self.tmp_file_path + _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav') convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format) if self.as_path: return self.tmp_file_path - self.open_file = wave.open(self.tmp_file_path, 'r') - return self.open_file + self.open_wav = wave.open(self.tmp_file_path, 'rb') + return self.open_wav def __exit__(self, *args): if not self.as_path: - self.open_file.close() + self.open_wav.close() + if self.open_file: + self.open_file.close() if self.tmp_file_path is not None: os.remove(self.tmp_file_path) + if self.tmp_src_file_path is not None: + os.remove(self.tmp_src_file_path) def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False): @@ -320,6 +341,7 @@ def read_opus(opus_file): def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): + # wav_file is already a file-pointer here with wave.open(wav_file, 'wb') as wav_file_writer: wav_file_writer.setframerate(audio_format.rate) wav_file_writer.setnchannels(audio_format.channels) diff --git a/training/deepspeech_training/util/augmentations.py b/training/deepspeech_training/util/augmentations.py index 941c17f2..2422582c 100644 --- a/training/deepspeech_training/util/augmentations.py +++ b/training/deepspeech_training/util/augmentations.py @@ -8,7 +8,7 @@ import numpy as np from multiprocessing import Queue, Process from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE -from .sample_collections import samples_from_source +from .sample_collections import samples_from_source, unpack_maybe BUFFER_SIZE = 1 * MEGABYTE SPEC_PARSER = re.compile(r'^(?P[a-z_]+)(\[(?P.*)\])?$') @@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context): AUGMENTATION_CONTEXT = preparation_context +def _load_and_augment_sample(timed_sample, context=None): + sample, clock = timed_sample + realized_sample = unpack_maybe(sample) + return _augment_sample((realized_sample, clock), context) + + def _augment_sample(timed_sample, context=None): context = AUGMENTATION_CONTEXT if context is None else context sample, clock = timed_sample @@ -213,12 +219,12 @@ def apply_sample_augmentations(samples, context = AugmentationContext(audio_type, augmentations) if process_ahead == 0: for timed_sample in timed_samples(): - yield _augment_sample(timed_sample, context=context) + yield _load_and_augment_sample(timed_sample, context=context) else: with LimitingPool(process_ahead=process_ahead, initializer=_init_augmentation_worker, initargs=(context,)) as pool: - yield from pool.imap(_augment_sample, timed_samples()) + yield from pool.imap(_load_and_augment_sample, timed_samples()) finally: for augmentation in augmentations: augmentation.stop() @@ -256,6 +262,7 @@ class Overlay(SampleAugmentation): self.enqueue_process.start() def apply(self, sample, clock=0.0): + sample = unpack_maybe(sample) sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) n_layers = pick_value_from_range(self.layers, clock=clock) audio = sample.audio @@ -265,6 +272,7 @@ class Overlay(SampleAugmentation): while overlay_offset < len(audio): if self.current_sample is None: next_overlay_sample = self.queue.get() + next_overlay_sample = unpack_maybe(next_overlay_sample) next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) self.current_sample = next_overlay_sample.audio n_required = len(audio) - overlay_offset diff --git a/training/deepspeech_training/util/check_characters.py b/training/deepspeech_training/util/check_characters.py index f155b4ac..7e6cdd0b 100644 --- a/training/deepspeech_training/util/check_characters.py +++ b/training/deepspeech_training/util/check_characters.py @@ -19,6 +19,7 @@ import csv import os import sys import unicodedata +from .io import open_remote def main(): parser = argparse.ArgumentParser() @@ -27,14 +28,14 @@ def main(): parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true") parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true") args = parser.parse_args() - in_files = [os.path.abspath(i) for i in args.csv_files.split(",")] + in_files = args.csv_files.split(",") print("### Reading in the following transcript files: ###") print("### {} ###".format(in_files)) all_text = set() for in_file in in_files: - with open(in_file, "r") as csv_file: + with open_remote(in_file, "r") as csv_file: reader = csv.reader(csv_file) try: next(reader, None) # skip the file header (i.e. "transcript") diff --git a/training/deepspeech_training/util/config.py b/training/deepspeech_training/util/config.py index 0b9929e5..18da6eed 100755 --- a/training/deepspeech_training/util/config.py +++ b/training/deepspeech_training/util/config.py @@ -13,7 +13,7 @@ from .gpu import get_available_gpus from .logging import log_error, log_warn from .helpers import parse_file_size from .augmentations import parse_augmentations - +from .io import path_exists_remote class ConfigSingleton: _config = None @@ -139,7 +139,7 @@ def initialize_globals(): c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000) if FLAGS.one_shot_infer: - if not os.path.exists(FLAGS.one_shot_infer): + if not path_exists_remote(FLAGS.one_shot_infer): log_error('Path specified in --one_shot_infer is not a valid file.') sys.exit(1) diff --git a/training/deepspeech_training/util/downloader.py b/training/deepspeech_training/util/downloader.py index a6d57e3e..c527eb9b 100644 --- a/training/deepspeech_training/util/downloader.py +++ b/training/deepspeech_training/util/downloader.py @@ -2,6 +2,7 @@ import requests import progressbar from os import path, makedirs +from .io import open_remote, path_exists_remote, is_remote_path SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed'] @@ -9,17 +10,18 @@ def maybe_download(archive_name, target_dir, archive_url): # If archive file does not exist, download it... archive_path = path.join(target_dir, archive_name) - if not path.exists(target_dir): + if not is_remote_path(target_dir) and not path.exists(target_dir): print('No path "%s" - creating ...' % target_dir) makedirs(target_dir) - if not path.exists(archive_path): + if not path_exists_remote(archive_path): print('No archive "%s" - downloading...' % archive_path) req = requests.get(archive_url, stream=True) total_size = int(req.headers.get('content-length', 0)) done = 0 - with open(archive_path, 'wb') as f: + with open_remote(archive_path, 'wb') as f: bar = progressbar.ProgressBar(max_value=total_size if total_size > 0 else progressbar.UnknownLength, widgets=SIMPLE_BAR) + for data in req.iter_content(1024*1024): done += len(data) f.write(data) diff --git a/training/deepspeech_training/util/evaluate_tools.py b/training/deepspeech_training/util/evaluate_tools.py index 66fc8293..68d29f3e 100644 --- a/training/deepspeech_training/util/evaluate_tools.py +++ b/training/deepspeech_training/util/evaluate_tools.py @@ -10,7 +10,7 @@ from attrdict import AttrDict from .flags import FLAGS from .text import levenshtein - +from .io import open_remote def pmap(fun, iterable): pool = Pool() @@ -124,5 +124,5 @@ def save_samples_json(samples, output_path): We set ensure_ascii=True to prevent json from escaping non-ASCII chars in the texts. ''' - with open(output_path, 'w') as fout: + with open_remote(output_path, 'w') as fout: json.dump(samples, fout, default=float, ensure_ascii=False, indent=2) diff --git a/training/deepspeech_training/util/helpers.py b/training/deepspeech_training/util/helpers.py index 195c117e..7545c8ee 100644 --- a/training/deepspeech_training/util/helpers.py +++ b/training/deepspeech_training/util/helpers.py @@ -78,6 +78,32 @@ class Interleaved: return self.len +class LenMap: + """ + Wrapper around python map() output object that preserves the original collection length + by implementing __len__. + """ + def __init__(self, fn, iterable): + try: + self.length = len(iterable) + except TypeError: + self.length = None + self.mapobj = map(fn, iterable) + + def __iter__(self): + self.mapobj = self.mapobj.__iter__() + return self + + def __next__(self): + return self.mapobj.__next__() + + def __getitem__(self, key): + return self.mapobj.__getitem__(key) + + def __len__(self): + return self.length + + class LimitingPool: """Limits unbound ahead-processing of multiprocessing.Pool's imap method before items get consumed by the iteration caller. diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py new file mode 100644 index 00000000..947b43af --- /dev/null +++ b/training/deepspeech_training/util/io.py @@ -0,0 +1,81 @@ +""" +A set of I/O utils that allow us to open files on remote storage as if they were present locally and access +into HDFS storage using Tensorflow's C++ FileStream API. +Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets. +""" +import os +from tensorflow.io import gfile + + +def is_remote_path(path): + """ + Returns True iff the path is one of the remote formats that this + module supports + """ + return path.startswith('gs://') or path.startswith('hdfs://') + + +def path_exists_remote(path): + """ + Wrapper that allows existance check of local and remote paths like + `gs://...` + """ + if is_remote_path(path): + return gfile.exists(path) + return os.path.exists(path) + + +def copy_remote(src, dst, overwrite=False): + """ + Allows us to copy a file from local to remote or vice versa + """ + return gfile.copy(src, dst, overwrite) + + +def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None): + """ + Wrapper around open() method that can handle remote paths like `gs://...` + off Google Cloud using Tensorflow's IO helpers. + + buffering, encoding, newline, closefd, and opener are ignored for remote files + + This enables us to do: + with open_remote('gs://.....', mode='w+') as f: + do something with the file f, whether or not we have local access to it + """ + if is_remote_path(path): + return gfile.GFile(path, mode=mode) + return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener) + + +def isdir_remote(path): + """ + Wrapper to check if remote and local paths are directories + """ + if is_remote_path(path): + return gfile.isdir(path) + return os.path.isdir(path) + + +def listdir_remote(path): + """ + Wrapper to list paths in local dirs (alternative to using a glob, I suppose) + """ + if is_remote_path(path): + return gfile.listdir(path) + return os.listdir(path) + + +def glob_remote(filename): + """ + Wrapper that provides globs on local and remote paths like `gs://...` + """ + return gfile.glob(filename) + + +def remove_remote(filename): + """ + Wrapper that can remove local and remote files like `gs://...` + """ + # Conditional import + return gfile.remove_remote(filename) \ No newline at end of file diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index 3f1b55ea..085439c9 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -8,7 +8,7 @@ import tarfile from pathlib import Path from functools import partial -from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved +from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap from .audio import ( Sample, DEFAULT_FORMAT, @@ -18,6 +18,7 @@ from .audio import ( get_audio_type_from_extension, write_wav ) +from .io import open_remote, is_remote_path BIG_ENDIAN = 'big' INT_SIZE = 4 @@ -59,6 +60,37 @@ class LabeledSample(Sample): self.transcript = transcript +class PackedSample: + """ + A wrapper that we can carry around in an iterator and pass to a child process in order to + have the child process do the loading/unpacking of the sample, allowing for parallel file + I/O. + """ + def __init__(self, filename, audio_type, label): + self.filename = filename + self.audio_type = audio_type + self.label = label + + def unpack(self): + with open_remote(self.filename, 'rb') as audio_file: + data = audio_file.read() + if self.label is None: + s = Sample(self.audio_type, data, sample_id=self.filename) + s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename) + return s + + +def unpack_maybe(sample): + """ + Loads the supplied sample from disk (or the network) if the audio isn't loaded in to memory already. + """ + if hasattr(sample, 'unpack'): + realized_sample = sample.unpack() + else: + realized_sample = sample + return realized_sample + + def load_sample(filename, label=None): """ Loads audio-file as a (labeled or unlabeled) sample @@ -69,21 +101,19 @@ def load_sample(filename, label=None): Filename of the audio-file to load as sample label : str Label (transcript) of the sample. - If None: return util.audio.Sample instance - Otherwise: return util.sample_collections.LabeledSample instance + If None: returned result.unpack() will return util.audio.Sample instance + Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance Returns ------- - util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance + util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return + util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance """ ext = os.path.splitext(filename)[1].lower() audio_type = get_audio_type_from_extension(ext) if audio_type is None: raise ValueError('Unknown audio type extension "{}"'.format(ext)) - with open(filename, 'rb') as audio_file: - if label is None: - return Sample(audio_type, audio_file.read(), sample_id=filename) - return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename) + return PackedSample(filename, audio_type, label) class DirectSDBWriter: @@ -119,7 +149,7 @@ class DirectSDBWriter: raise ValueError('Audio type "{}" not supported'.format(audio_type)) self.audio_type = audio_type self.bitrate = bitrate - self.sdb_file = open(sdb_filename, 'wb', buffering=buffering) + self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering) self.offsets = [] self.num_samples = 0 @@ -215,7 +245,7 @@ class SDB: # pylint: disable=too-many-instance-attributes """ self.sdb_filename = sdb_filename self.id_prefix = sdb_filename if id_prefix is None else id_prefix - self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering) + self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering) self.offsets = [] if self.sdb_file.read(len(MAGIC)) != MAGIC: raise RuntimeError('No Sample Database') @@ -332,6 +362,8 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes labeled : bool or None If True: Writes labeled samples (util.sample_collections.LabeledSample) only. If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. + + Currently only works with local files (not gs:// or hdfs://...) """ self.csv_filename = Path(csv_filename) self.csv_base_dir = self.csv_filename.parent.resolve().absolute() @@ -345,7 +377,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes self.labeled = labeled if labeled: fieldnames.append('transcript') - self.csv_file = open(csv_filename, 'w', encoding='utf-8', newline='') + self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='') self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) self.csv_writer.writeheader() self.counter = 0 @@ -380,7 +412,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes class TarWriter: # pylint: disable=too-many-instance-attributes - """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file""" + """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file.""" def __init__(self, tar_filename, gz=False, @@ -398,6 +430,8 @@ class TarWriter: # pylint: disable=too-many-instance-attributes If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. include : str[] List of files to include into tar root. + + Currently only works with local files (not gs:// or hdfs://...) """ self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w') samples_dir = tarfile.TarInfo('samples') @@ -498,8 +532,7 @@ class CSV(SampleList): If the order of the samples should be reversed """ rows = [] - csv_dir = Path(csv_filename).parent - with open(csv_filename, 'r', encoding='utf8') as csv_file: + with open_remote(csv_filename, 'r', encoding='utf8') as csv_file: reader = csv.DictReader(csv_file) if 'transcript' in reader.fieldnames: if labeled is None: @@ -508,9 +541,12 @@ class CSV(SampleList): raise RuntimeError('No transcript data (missing CSV column)') for row in reader: wav_filename = Path(row['wav_filename']) - if not wav_filename.is_absolute(): - wav_filename = csv_dir / wav_filename - wav_filename = str(wav_filename) + if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']): + wav_filename = Path(csv_filename).parent / wav_filename + wav_filename = str(wav_filename) + else: + # Pathlib otherwise removes a / from filenames like hdfs:// + wav_filename = row['wav_filename'] wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0 if labeled: rows.append((wav_filename, wav_filesize, row['transcript'])) @@ -554,6 +590,11 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re Loads and combines samples from a list of source files. Sources are combined in an interleaving way to keep default sample order from shortest to longest. + Note that when using distributed training, it is much faster to call this function with single pre- + sorted sample source, because this allows for parallelization of the file I/O. (If this function is + called with multiple sources, the samples have to be unpacked on a single parent process to allow + for reading their durations.) + Parameters ---------- sample_sources : list of str @@ -570,13 +611,20 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re Returns ------- - iterable of util.sample_collections.LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len + iterable of util.sample_collections.PackedSample if a single collection is provided, wrapping + LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len + or LabeledSample / util.audio.Sample directly, if multiple collections are provided """ sample_sources = list(sample_sources) if len(sample_sources) == 0: raise ValueError('No files') if len(sample_sources) == 1: return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse) - cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse) - for source in sample_sources] + + # If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking should + # be done lazily onn the fly so that it respects the LimitingPool logic used in the feeding code. + cols = [LenMap( + unpack_maybe, samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)) + for source in sample_sources] + return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)