Revert "Merge pull request #3420 from CatalinVoss/remote-io"

This reverts commit 08d18d7328, reversing
changes made to 12badcce1f.
This commit is contained in:
Reuben Morais 2020-11-19 16:58:21 +02:00
Родитель f5cbda694a
Коммит 88f7297215
10 изменённых файлов: 53 добавлений и 191 удалений

Просмотреть файл

@ -35,7 +35,6 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote
check_ctcdecoder_version()
@ -513,10 +512,9 @@ def train():
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
if not is_remote_path(FLAGS.save_checkpoint_dir):
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open_remote(flags_file, 'w') as fout:
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
@ -543,7 +541,7 @@ def train():
feature_cache_index = FLAGS.feature_cache + '.index'
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
log_info('Invalidating feature cache')
remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files
os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
@ -811,14 +809,14 @@ def export():
load_graph_for_evaluation(session)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_remote_export:
if isdir_remote(FLAGS.export_dir):
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
remove_remote(FLAGS.export_dir)
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
@ -831,7 +829,7 @@ def export():
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open_remote(output_graph_path, 'wb') as fout:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
@ -842,7 +840,7 @@ def export():
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open_remote(output_tflite_path, 'wb') as fout:
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
@ -853,7 +851,7 @@ def export():
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open_remote(metadata_fname, 'w') as f:
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
@ -875,12 +873,8 @@ def export():
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
if is_remote_path(export_dir):
log_error("Cannot package remote path zip %s. Please do this manually." % export_dir)
return
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
@ -965,7 +959,7 @@ def main(_):
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if listdir_remote(FLAGS.export_dir):
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)

Просмотреть файл

@ -8,7 +8,6 @@ import numpy as np
from .helpers import LimitingPool
from collections import namedtuple
from .io import open_remote, remove_remote, copy_remote, is_remote_path
AudioFormat = namedtuple('AudioFormat', 'rate channels width')
@ -169,45 +168,29 @@ class AudioFile:
self.audio_format = audio_format
self.as_path = as_path
self.open_file = None
self.open_wav = None
self.tmp_file_path = None
self.tmp_src_file_path = None
def __enter__(self):
if self.audio_path.endswith('.wav'):
self.open_file = open_remote(self.audio_path, 'r')
self.open_wav = wave.open(self.open_file)
if read_audio_format_from_wav_file(self.open_wav) == self.audio_format:
self.open_file = wave.open(self.audio_path, 'r')
if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
if self.as_path:
self.open_wav.close()
self.open_file.close()
return self.audio_path
return self.open_wav
self.open_wav.close()
return self.open_file
self.open_file.close()
# If the format isn't right, copy the file to local tmp dir and do the conversion on disk
if is_remote_path(self.audio_path):
_, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav')
copy_remote(self.audio_path, self.tmp_src_file_path)
self.audio_path = self.tmp_file_path
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
if self.as_path:
return self.tmp_file_path
self.open_wav = wave.open(self.tmp_file_path, 'r')
return self.open_wav
self.open_file = wave.open(self.tmp_file_path, 'r')
return self.open_file
def __exit__(self, *args):
if not self.as_path:
self.open_wav.close()
if self.open_file:
self.open_file.close()
self.open_file.close()
if self.tmp_file_path is not None:
os.remove(self.tmp_file_path)
if self.tmp_src_file_path is not None:
os.remove(self.tmp_src_file_path)
def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
@ -337,7 +320,6 @@ def read_opus(opus_file):
def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
# wav_file is already a file-pointer here
with wave.open(wav_file, 'wb') as wav_file_writer:
wav_file_writer.setframerate(audio_format.rate)
wav_file_writer.setnchannels(audio_format.channels)

Просмотреть файл

@ -150,12 +150,6 @@ def _init_augmentation_worker(preparation_context):
AUGMENTATION_CONTEXT = preparation_context
def _load_and_augment_sample(timed_sample, context=None):
sample, clock = timed_sample
realized_sample = sample.unpack()
return _augment_sample((realized_sample, clock), context)
def _augment_sample(timed_sample, context=None):
context = AUGMENTATION_CONTEXT if context is None else context
sample, clock = timed_sample
@ -219,12 +213,12 @@ def apply_sample_augmentations(samples,
context = AugmentationContext(audio_type, augmentations)
if process_ahead == 0:
for timed_sample in timed_samples():
yield _load_and_augment_sample(timed_sample, context=context)
yield _augment_sample(timed_sample, context=context)
else:
with LimitingPool(process_ahead=process_ahead,
initializer=_init_augmentation_worker,
initargs=(context,)) as pool:
yield from pool.imap(_load_and_augment_sample, timed_samples())
yield from pool.imap(_augment_sample, timed_samples())
finally:
for augmentation in augmentations:
augmentation.stop()

Просмотреть файл

@ -19,7 +19,6 @@ import csv
import os
import sys
import unicodedata
from .io import open_remote
def main():
parser = argparse.ArgumentParser()
@ -28,14 +27,14 @@ def main():
parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true")
parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true")
args = parser.parse_args()
in_files = args.csv_files.split(",")
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
print("### Reading in the following transcript files: ###")
print("### {} ###".format(in_files))
all_text = set()
for in_file in in_files:
with open_remote(in_file, "r") as csv_file:
with open(in_file, "r") as csv_file:
reader = csv.reader(csv_file)
try:
next(reader, None) # skip the file header (i.e. "transcript")

Просмотреть файл

@ -13,7 +13,7 @@ from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .helpers import parse_file_size
from .augmentations import parse_augmentations
from .io import path_exists_remote
class ConfigSingleton:
_config = None
@ -139,7 +139,7 @@ def initialize_globals():
c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
if FLAGS.one_shot_infer:
if not path_exists_remote(FLAGS.one_shot_infer):
if not os.path.exists(FLAGS.one_shot_infer):
log_error('Path specified in --one_shot_infer is not a valid file.')
sys.exit(1)

Просмотреть файл

@ -2,7 +2,6 @@ import requests
import progressbar
from os import path, makedirs
from .io import open_remote, path_exists_remote, is_remote_path
SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed']
@ -10,16 +9,16 @@ def maybe_download(archive_name, target_dir, archive_url):
# If archive file does not exist, download it...
archive_path = path.join(target_dir, archive_name)
if not is_remote_path(target_dir) and not path.exists(target_dir):
if not path.exists(target_dir):
print('No path "%s" - creating ...' % target_dir)
makedirs(target_dir)
if not path_exists_remote(archive_path):
if not path.exists(archive_path):
print('No archive "%s" - downloading...' % archive_path)
req = requests.get(archive_url, stream=True)
total_size = int(req.headers.get('content-length', 0))
done = 0
with open_remote(archive_path, 'wb') as f:
with open(archive_path, 'wb') as f:
bar = progressbar.ProgressBar(max_value=total_size, widgets=SIMPLE_BAR)
for data in req.iter_content(1024*1024):
done += len(data)

Просмотреть файл

@ -10,7 +10,7 @@ from attrdict import AttrDict
from .flags import FLAGS
from .text import levenshtein
from .io import open_remote
def pmap(fun, iterable):
pool = Pool()
@ -124,5 +124,5 @@ def save_samples_json(samples, output_path):
We set ensure_ascii=True to prevent json from escaping non-ASCII chars
in the texts.
'''
with open_remote(output_path, 'w') as fout:
with open(output_path, 'w') as fout:
json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)

Просмотреть файл

@ -1,81 +0,0 @@
"""
A set of I/O utils that allow us to open files on remote storage as if they were present locally and access
into HDFS storage using Tensorflow's C++ FileStream API.
Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets.
"""
import os
from tensorflow.io import gfile
def is_remote_path(path):
"""
Returns True iff the path is one of the remote formats that this
module supports
"""
return path.startswith('gs://') or path.startswith('hdfs://')
def path_exists_remote(path):
"""
Wrapper that allows existance check of local and remote paths like
`gs://...`
"""
if is_remote_path(path):
return gfile.exists(path)
return os.path.exists(path)
def copy_remote(src, dst, overwrite=False):
"""
Allows us to copy a file from local to remote or vice versa
"""
return gfile.copy(src, dst, overwrite)
def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None):
"""
Wrapper around open() method that can handle remote paths like `gs://...`
off Google Cloud using Tensorflow's IO helpers.
buffering, encoding, newline, closefd, and opener are ignored for remote files
This enables us to do:
with open_remote('gs://.....', mode='w+') as f:
do something with the file f, whether or not we have local access to it
"""
if is_remote_path(path):
return gfile.GFile(path, mode=mode)
return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener)
def isdir_remote(path):
"""
Wrapper to check if remote and local paths are directories
"""
if is_remote_path(path):
return gfile.isdir(path)
return os.path.isdir(path)
def listdir_remote(path):
"""
Wrapper to list paths in local dirs (alternative to using a glob, I suppose)
"""
if is_remote_path(path):
return gfile.listdir(path)
return os.listdir(path)
def glob_remote(filename):
"""
Wrapper that provides globs on local and remote paths like `gs://...`
"""
return gfile.glob(filename)
def remove_remote(filename):
"""
Wrapper that can remove local and remote files like `gs://...`
"""
# Conditional import
return gfile.remove_remote(filename)

Просмотреть файл

@ -18,7 +18,6 @@ from .audio import (
get_audio_type_from_extension,
write_wav
)
from .io import open_remote, is_remote_path
BIG_ENDIAN = 'big'
INT_SIZE = 4
@ -60,25 +59,6 @@ class LabeledSample(Sample):
self.transcript = transcript
class PackedSample:
"""
A wrapper that we can carry around in an iterator and pass to a child process in order to
have the child process do the loading/unpacking of the sample, allowing for parallel file
I/O.
"""
def __init__(self, filename, audio_type, label):
self.filename = filename
self.audio_type = audio_type
self.label = label
def unpack(self):
with open_remote(self.filename, 'rb') as audio_file:
data = audio_file.read()
if self.label is None:
s = Sample(self.audio_type, data, sample_id=self.filename)
s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
return s
def load_sample(filename, label=None):
"""
Loads audio-file as a (labeled or unlabeled) sample
@ -89,19 +69,21 @@ def load_sample(filename, label=None):
Filename of the audio-file to load as sample
label : str
Label (transcript) of the sample.
If None: returned result.unpack() will return util.audio.Sample instance
Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance
If None: return util.audio.Sample instance
Otherwise: return util.sample_collections.LabeledSample instance
Returns
-------
util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
"""
ext = os.path.splitext(filename)[1].lower()
audio_type = get_audio_type_from_extension(ext)
if audio_type is None:
raise ValueError('Unknown audio type extension "{}"'.format(ext))
return PackedSample(filename, audio_type, label)
with open(filename, 'rb') as audio_file:
if label is None:
return Sample(audio_type, audio_file.read(), sample_id=filename)
return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename)
class DirectSDBWriter:
@ -137,7 +119,7 @@ class DirectSDBWriter:
raise ValueError('Audio type "{}" not supported'.format(audio_type))
self.audio_type = audio_type
self.bitrate = bitrate
self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering)
self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
self.offsets = []
self.num_samples = 0
@ -233,7 +215,7 @@ class SDB: # pylint: disable=too-many-instance-attributes
"""
self.sdb_filename = sdb_filename
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
self.offsets = []
if self.sdb_file.read(len(MAGIC)) != MAGIC:
raise RuntimeError('No Sample Database')
@ -350,8 +332,6 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
labeled : bool or None
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
Currently only works with local files (not gs:// or hdfs://...)
"""
self.csv_filename = Path(csv_filename)
self.csv_base_dir = self.csv_filename.parent.resolve().absolute()
@ -365,7 +345,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
self.labeled = labeled
if labeled:
fieldnames.append('transcript')
self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='')
self.csv_file = open(csv_filename, 'w', encoding='utf-8', newline='')
self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
self.csv_writer.writeheader()
self.counter = 0
@ -400,7 +380,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
class TarWriter: # pylint: disable=too-many-instance-attributes
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file."""
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file"""
def __init__(self,
tar_filename,
gz=False,
@ -418,8 +398,6 @@ class TarWriter: # pylint: disable=too-many-instance-attributes
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
include : str[]
List of files to include into tar root.
Currently only works with local files (not gs:// or hdfs://...)
"""
self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
samples_dir = tarfile.TarInfo('samples')
@ -520,7 +498,8 @@ class CSV(SampleList):
If the order of the samples should be reversed
"""
rows = []
with open_remote(csv_filename, 'r', encoding='utf8') as csv_file:
csv_dir = Path(csv_filename).parent
with open(csv_filename, 'r', encoding='utf8') as csv_file:
reader = csv.DictReader(csv_file)
if 'transcript' in reader.fieldnames:
if labeled is None:
@ -529,12 +508,9 @@ class CSV(SampleList):
raise RuntimeError('No transcript data (missing CSV column)')
for row in reader:
wav_filename = Path(row['wav_filename'])
if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']):
wav_filename = Path(csv_filename).parent / wav_filename
wav_filename = str(wav_filename)
else:
# Pathlib otherwise removes a / from filenames like hdfs://
wav_filename = row['wav_filename']
if not wav_filename.is_absolute():
wav_filename = csv_dir / wav_filename
wav_filename = str(wav_filename)
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
if labeled:
rows.append((wav_filename, wav_filesize, row['transcript']))

Просмотреть файл

@ -14,7 +14,6 @@ import sys
from pkg_resources import parse_version
from .io import isdir_remote, open_remote, is_remote_path
DEFAULT_SCHEMES = {
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
@ -43,13 +42,13 @@ def maybe_download_tc(target_dir, tc_url, progress=True):
assert target_dir is not None
if not is_remote_path(target_dir):
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
target_dir = os.path.abspath(target_dir)
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
tc_filename = os.path.basename(tc_url)
target_file = os.path.join(target_dir, tc_filename)
@ -62,7 +61,7 @@ def maybe_download_tc(target_dir, tc_url, progress=True):
print('File already exists: %s' % target_file)
if is_gzip:
with open_remote(target_file, "r+b") as frw:
with open(target_file, "r+b") as frw:
decompressed = gzip.decompress(frw.read())
frw.seek(0)
frw.write(decompressed)