# Mirror of https://github.com/mozilla/DSAlign.git
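"""
Exports aligned speech samples as training sets.

Reads (audio, alignment) file pairs - either given directly via --audio and
--aligned or referenced by a --catalog file - rates and filters the aligned
fragments, optionally de-biases and partitions them, splits them into
train/dev/test sub-sets and writes the result as CSV + WAV files, as a tar
archive, or as Sample DBs (SDB).
"""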
import os
import io
import sys
import csv
import math
import time
import json
import wave
import pickle
import random
import tarfile
import logging
import argparse
import statistics
import os.path as path

from datetime import timedelta
from collections import Counter
from multiprocessing import Pool
from audio import AUDIO_TYPE_PCM, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, \
    ensure_wav_with_format, extract_audio, change_audio_types, write_audio_format_to_wav_file, verify_wav_file
from sample_collections import SortingSDBWriter, LabeledSample
from utils import parse_file_size, log_progress

UNKNOWN = '<UNKNOWN>'
AUDIO_TYPE_LOOKUP = {
    'wav': AUDIO_TYPE_WAV,
    'opus': AUDIO_TYPE_OPUS
}
SET_NAMES = ['train', 'dev', 'test']


class Fragment:
    """A single aligned utterance, identified by its catalog entry and its
    index within that entry's alignment file."""
    def __init__(self, catalog_index, alignment_index, quality=0, duration=0):
        self.catalog_index = catalog_index
        self.alignment_index = alignment_index
        self.quality = quality
        self.duration = duration
        self.meta = {}
        self.partition = 'other'
        self.list_name = 'other'


def progress(it=None, desc=None, total=None):
    if desc is not None:
        logging.info(desc)
    return it if CLI_ARGS.no_progress else log_progress(it, interval=CLI_ARGS.progress_interval, total=total)


def fail(message, code=1):
    logging.fatal(message)
    sys.exit(code)


def check_path(target_path, fs_type='file'):
    if not (path.isfile(target_path) if fs_type == 'file' else path.isdir(target_path)):
        fail('{} does not exist: "{}"'.format(fs_type[0].upper() + fs_type[1:], target_path))
    return path.abspath(target_path)


def make_absolute(base_path, spec_path):
    if not path.isabs(spec_path):
        spec_path = path.join(base_path, spec_path)
    spec_path = path.abspath(spec_path)
    return spec_path if path.isfile(spec_path) else None


def engroup(lst, get_key):
    groups = {}
    for obj in lst:
        key = get_key(obj)
        if key in groups:
            groups[key].append(obj)
        else:
            groups[key] = [obj]
    return groups


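# Example of engroup() above (illustrative): engroup([1, 2, 3, 4], lambda n: n % 2)
# returns {1: [1, 3], 0: [2, 4]}.

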
def get_sample_size(population_size):
    """Calculates the dev/test set size for a given population size, based on
    Cochran's sample size formula with finite population correction
    (99% confidence level, 1% margin of error)."""
    margin_of_error = 0.01
    fraction_picking = 0.50
    z_score = 2.58  # Corresponds to confidence level 99%
    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
        margin_of_error ** 2
    )
    sample_size = 0
    for train_size in range(population_size, 0, -1):
        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
            margin_of_error ** 2 * train_size
        )
        sample_size = int(numerator / denominator)
        # Shrink the train set until it can accommodate both a dev and a test
        # set of the computed size.
        if 2 * sample_size + train_size <= population_size:
            break
    return sample_size


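# Illustration of get_sample_size() above (approximate numbers, not from the
# original source): with z = 2.58 and a 1% margin of error the uncorrected
# sample size is z^2 * 0.25 / 0.01^2 ~= 16641; the finite population
# correction then shrinks this, e.g. a population of 10,000 fragments ends up
# with dev/test sets of roughly 3,100 samples each and about 3,800 train
# samples.

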
def load_catalog():
    catalog_entries = []
    if CLI_ARGS.audio:
        if CLI_ARGS.aligned:
            catalog_entries.append((check_path(CLI_ARGS.audio), check_path(CLI_ARGS.aligned)))
        else:
            fail('If you specify "--audio", you also have to specify "--aligned"')
    elif CLI_ARGS.aligned:
        fail('If you specify "--aligned", you also have to specify "--audio"')
    elif CLI_ARGS.catalog:
        catalog = check_path(CLI_ARGS.catalog)
        catalog_dir = path.dirname(catalog)
        with open(catalog, 'r', encoding='utf-8') as catalog_file:
            catalog_file_entries = json.load(catalog_file)
        for entry in progress(catalog_file_entries, desc='Reading catalog'):
            audio = make_absolute(catalog_dir, entry['audio'])
            aligned = make_absolute(catalog_dir, entry['aligned'])
            if audio is None or aligned is None:
                if CLI_ARGS.ignore_missing:
                    continue
                if audio is None:
                    fail('Problem loading catalog "{}": Missing referenced audio file "{}"'
                         .format(CLI_ARGS.catalog, entry['audio']))
                if aligned is None:
                    fail('Problem loading catalog "{}": Missing referenced alignment file "{}"'
                         .format(CLI_ARGS.catalog, entry['aligned']))
            catalog_entries.append((audio, aligned))
    else:
        fail('You have to either specify "--audio" and "--aligned" or "--catalog"')
    return catalog_entries


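# A catalog file read by load_catalog() above is a JSON array of entries with
# "audio" and "aligned" keys; relative paths are resolved against the
# catalog's own directory. A hypothetical example:
#
#     [
#         {"audio": "audio/episode1.mp3", "aligned": "aligned/episode1.aligned"},
#         {"audio": "audio/episode2.mp3", "aligned": "aligned/episode2.aligned"}
#     ]

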
def load_fragments(catalog_entries):
    def get_meta_list(ae, mf):
        if 'meta' in ae:
            meta_fields = ae['meta']
            if isinstance(meta_fields, dict) and mf in meta_fields:
                mf = meta_fields[mf]
                return mf if isinstance(mf, list) else [mf]
        return []

    required_metas = {}
    if CLI_ARGS.debias is not None:
        for debias_meta_field in CLI_ARGS.debias:
            required_metas[debias_meta_field] = True
    if CLI_ARGS.split and CLI_ARGS.split_field is not None:
        required_metas[CLI_ARGS.split_field] = True

    fragments = []
    reasons = Counter()
    for catalog_index, catalog_entry in enumerate(progress(catalog_entries, desc='Loading alignments')):
        audio_path, aligned_path = catalog_entry
        with open(aligned_path, 'r', encoding='utf-8') as aligned_file:
            aligned = json.load(aligned_file)
        for alignment_index, alignment in enumerate(aligned):
            # --criteria and --filter are Python expressions that get
            # evaluated against the fields of each alignment entry (by design
            # - only use them with trusted input).
            quality = eval(CLI_ARGS.criteria, {'math': math}, alignment)
            alignment['quality'] = quality
            if eval(CLI_ARGS.filter, {'math': math}, alignment):
                reasons['Filter'] += 1
                continue
            meta = {}
            keep = True
            for meta_field in required_metas.keys():
                meta_list = get_meta_list(alignment, meta_field)
                if CLI_ARGS.split and CLI_ARGS.split_field == meta_field:
                    if CLI_ARGS.split_drop_multiple and len(meta_list) > 1:
                        reasons['Split drop multiple'] += 1
                        keep = False
                        break
                    elif CLI_ARGS.split_drop_unknown and len(meta_list) == 0:
                        reasons['Split drop unknown'] += 1
                        keep = False
                        break
                meta[meta_field] = meta_list[0] if meta_list else UNKNOWN
            if keep:
                duration = alignment['end'] - alignment['start']
                fragment = Fragment(catalog_index, alignment_index, quality=quality, duration=duration)
                fragment.meta = meta
                for minimum, partition_name in CLI_ARGS.partition:
                    if quality >= minimum:
                        fragment.partition = partition_name
                        break
                fragments.append(fragment)

    if len(fragments) == 0:
        fail('No samples left for export')

    if len(reasons.keys()) > 0:
        logging.info('Excluded number of samples (for each reason):')
        for reason, count in reasons.most_common():
            logging.info(' - "{}": {}'.format(reason, count))
    return fragments


def debias(fragments):
    if CLI_ARGS.debias is not None:
        for debias_field in CLI_ARGS.debias:
            grouped = engroup(fragments, lambda f: f.meta[debias_field])
            if UNKNOWN in grouped:
                fragments = grouped[UNKNOWN]
                del grouped[UNKNOWN]
            else:
                fragments = []
            counts = list(map(len, grouped.values()))
            mean = statistics.mean(counts)
            sigma = statistics.pstdev(counts, mu=mean)
            # Cap each group at mean + sigma-factor * standard deviation.
            cap = int(mean + CLI_ARGS.debias_sigma_factor * sigma)
            counter = Counter()
            for group, group_fragments in progress(grouped.items(), desc='De-biasing "{}"'.format(debias_field)):
                if len(group_fragments) > cap:
                    # Keep the cap highest-quality fragments of the group.
                    group_fragments.sort(key=lambda f: f.quality)
                    counter[group] += len(group_fragments) - cap
                    group_fragments = group_fragments[-cap:]
                fragments.extend(group_fragments)
            if len(counter.keys()) > 0:
                logging.info('Dropped for de-biasing "{}":'.format(debias_field))
                for group, count in counter.most_common():
                    logging.info(' - "{}": {}'.format(group, count))
    return fragments


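# Example of the capping in debias() above (made-up numbers): with
# --debias speaker and per-speaker counts having mean 100 and standard
# deviation 50, the default --debias-sigma-factor of 3.0 caps every speaker
# at 100 + 3 * 50 = 250 fragments; over-represented speakers keep their 250
# highest-quality fragments and the rest are dropped.

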
def parse_set_assignments():
    set_assignments = {}
    for set_index, set_name in enumerate(SET_NAMES):
        attr_name = 'assign_' + set_name
        if hasattr(CLI_ARGS, attr_name):
            set_entities = getattr(CLI_ARGS, attr_name)
            if set_entities is not None:
                for entity_id in filter(None, str(set_entities).split(',')):
                    if entity_id in set_assignments:
                        fail('Unable to assign entity "{}" to set "{}", as it is already assigned to set "{}"'
                             .format(entity_id, set_name, SET_NAMES[set_assignments[entity_id]]))
                    set_assignments[entity_id] = set_index
    return set_assignments


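# Example for parse_set_assignments() above (hypothetical values):
# --assign-test "speaker123,speaker456" pins all fragments of those
# --split-field values to the test set, regardless of the statistical split.

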
def check_targets():
    if CLI_ARGS.target_dir is not None and CLI_ARGS.target_tar is not None:
        fail('Only one allowed: --target-dir or --target-tar')
    elif CLI_ARGS.target_dir is not None:
        CLI_ARGS.target_dir = check_path(CLI_ARGS.target_dir, fs_type='directory')
    elif CLI_ARGS.target_tar is not None:
        if CLI_ARGS.sdb:
            fail('Option --sdb not supported for --target-tar output. Use --target-dir instead.')
        CLI_ARGS.target_tar = path.abspath(CLI_ARGS.target_tar)
        if path.isfile(CLI_ARGS.target_tar):
            if not CLI_ARGS.force:
                fail('Target tar-file already exists - use --force to overwrite')
        elif path.exists(CLI_ARGS.target_tar):
            fail('Target tar-file path exists, but is not a file')
        elif not path.isdir(path.dirname(CLI_ARGS.target_tar)):
            fail('Unable to write tar-file: Parent directory does not exist')
    else:
        fail('Either --target-dir or --target-tar has to be provided')


def split(fragments, set_assignments):
    lists = []

    def assign_fragments(frags, name):
        lists.append(name)
        duration = 0
        for f in frags:
            f.list_name = name
            duration += f.duration
        logging.info('Built set "{}" (samples: {}, duration: {})'
                     .format(name, len(frags), timedelta(milliseconds=duration)))

    if CLI_ARGS.split_seed is not None:
        random.seed(CLI_ARGS.split_seed)

    if CLI_ARGS.split and CLI_ARGS.split_field is not None:
        # Split-field aware splitting: all fragments sharing one split-field
        # value (e.g. the same speaker) end up in the same sub-set.
        fragments = list(fragments)
        metas = engroup(fragments, lambda f: f.meta[CLI_ARGS.split_field]).items()
        metas = sorted(metas, key=lambda meta_frags: len(meta_frags[1]))
        metas = list(map(lambda meta_frags: meta_frags[0], metas))
        partitions = engroup(fragments, lambda f: f.partition)
        partitions = list(map(lambda part_frags: (part_frags[0],
                                                  get_sample_size(len(part_frags[1])),
                                                  engroup(part_frags[1], lambda f: f.meta[CLI_ARGS.split_field]),
                                                  [[], [], []]),
                              partitions.items()))
        # Honor explicit --assign-<set> assignments first.
        remaining_metas = []
        for meta in metas:
            if meta in set_assignments:
                set_index = set_assignments[meta]
                for _, _, partition_portions, sample_sets in partitions:
                    if meta in partition_portions:
                        sample_sets[set_index].extend(partition_portions[meta])
                        del partition_portions[meta]
            else:
                remaining_metas.append(meta)
        metas = remaining_metas
        # Fill dev and test with whole split-field groups until they reach
        # their computed sample sizes; whatever remains becomes train.
        for _, sample_size, _, sample_sets in partitions:
            while len(metas) > 0 and (len(sample_sets[1]) < sample_size or len(sample_sets[2]) < sample_size):
                for sample_set_index in [1, 2]:
                    if len(metas) > 0 and sample_size > len(sample_sets[sample_set_index]):
                        meta = metas.pop(0)
                        for _, _, partition_portions, other_sample_sets in partitions:
                            if meta in partition_portions:
                                other_sample_sets[sample_set_index].extend(partition_portions[meta])
                                del partition_portions[meta]
        for partition, sample_size, partition_portions, sample_sets in partitions:
            for portion in partition_portions.values():
                sample_sets[0].extend(portion)
            for set_index, set_name in enumerate(SET_NAMES):
                assign_fragments(sample_sets[set_index], partition + '-' + set_name)
    else:
        partitions = engroup(fragments, lambda f: f.partition)
        for partition, partition_fragments in partitions.items():
            if CLI_ARGS.split:
                sample_size = get_sample_size(len(partition_fragments))
                random.shuffle(partition_fragments)
                test_set = partition_fragments[:sample_size]
                partition_fragments = partition_fragments[sample_size:]
                dev_set = partition_fragments[:sample_size]
                train_set = partition_fragments[sample_size:]
                sample_sets = [train_set, dev_set, test_set]
                for set_index, set_name in enumerate(SET_NAMES):
                    assign_fragments(sample_sets[set_index], partition + '-' + set_name)
            else:
                assign_fragments(partition_fragments, partition)
    return lists


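# Split strategy of split() above in a nutshell: split-field groups are
# processed from smallest to largest; dev and test are filled with whole
# groups in round-robin fashion until each reaches the statistically derived
# sample size, and all remaining groups become train. Because whole groups
# are moved at once, no split-field value (e.g. a speaker) ever spans more
# than one sub-set.

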
def check_overwrite(lists):
    if CLI_ARGS.target_dir is not None and not CLI_ARGS.force:
        for name in lists:
            suffixes = ['.meta'] + (['.sdb', '.sdb.tmp'] if CLI_ARGS.sdb else ['', '.csv'])
            for s in suffixes:
                p = path.join(CLI_ARGS.target_dir, name + s)
                if path.exists(p):
                    fail('"{}" already exists - use --force to overwrite'.format(p))


def parse_args():
    parser = argparse.ArgumentParser(description='Export aligned speech samples.')

    parser.add_argument('--plan', type=str,
                        help='Export plan (preparation-cache) to load and/or store')
    parser.add_argument('--audio', type=str,
                        help='Take audio file as input (requires "--aligned <file>")')
    parser.add_argument('--aligned', type=str,
                        help='Take alignment file ("<...>.aligned") as input (requires "--audio <file>")')
    parser.add_argument('--catalog', type=str,
                        help='Take alignment and audio file references of provided catalog ("<...>.catalog") as input')
    parser.add_argument('--ignore-missing', action="store_true",
                        help='Ignore catalog entries with missing files')

    parser.add_argument('--filter', type=str, default='False',
                        help='Python expression that computes a boolean value from sample data fields. '
                             'If the result is True, the sample will be dropped.')

    parser.add_argument('--criteria', type=str, default='100',
                        help='Python expression that computes a number as quality indicator from sample data fields.')

    parser.add_argument('--debias', type=str, action='append',
                        help='Sample meta field to group samples for debiasing (e.g. "speaker"). '
                             'Group sizes will be capped according to --debias-sigma-factor')
    parser.add_argument('--debias-sigma-factor', type=float, default=3.0,
                        help='Standard deviation (sigma) factor that determines '
                             'the maximum number of samples per group (see --debias).')

    parser.add_argument('--partition', type=str, action='append',
                        help='Expression of the form "<number>:<partition>" where all samples with a quality indicator '
                             '(--criteria) greater than or equal to the given number and below the next higher one '
                             'are assigned to the specified partition. Samples below the lowest partition criteria '
                             'are assigned to partition "other".')

    parser.add_argument('--split', action="store_true",
                        help='Split each partition except "other" into train/dev/test sub-sets.')
    parser.add_argument('--split-field', type=str,
                        help='Sample meta field that should be used for splitting (e.g. "speaker")')
    parser.add_argument('--split-drop-multiple', action="store_true",
                        help='Drop all samples with multiple --split-field assignments.')
    parser.add_argument('--split-drop-unknown', action="store_true",
                        help='Drop all samples with no --split-field assignment.')
    for sub_set in SET_NAMES:
        parser.add_argument('--assign-' + sub_set,
                            help='Comma separated list of --split-field values that are to be assigned to sub-set '
                                 '"{}"'.format(sub_set))
    parser.add_argument('--split-seed', type=int,
                        help='Random seed for set splitting')

    parser.add_argument('--target-dir', type=str, required=False,
                        help='Existing target directory for storing generated sets (files and directories)')
    parser.add_argument('--target-tar', type=str, required=False,
                        help='Target tar-file for storing generated sets (files and directories)')
    parser.add_argument('--sdb', action="store_true",
                        help='Writes Sample DBs instead of CSV and .wav files (requires --target-dir)')
    parser.add_argument('--sdb-bucket-size', default='1GB',
                        help='Memory bucket size for external sorting of SDBs')
    parser.add_argument('--sdb-workers', type=int, default=None,
                        help='Number of SDB encoding workers')
    parser.add_argument('--sdb-buffered-samples', type=int, default=None,
                        help='Number of samples per bucket buffer during finalization')
    parser.add_argument('--sdb-audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
                        help='Audio representation inside target SDBs')
    parser.add_argument('--tmp-dir', type=str, default=None,
                        help='Directory for temporary files - defaults to system one')
    parser.add_argument('--buffer', default='1MB',
                        help='Buffer size for writing files (default: 1MB)')
    parser.add_argument('--force', action="store_true",
                        help='Overwrite existing files')
    parser.add_argument('--skip-damaged', action="store_true",
                        help='Skip damaged audio files and their samples instead of failing the export')
    parser.add_argument('--no-meta', action="store_true",
                        help='Do not write meta data files')
    parser.add_argument('--rate', type=int, default=16000,
                        help='Export wav-files with this sample rate')
    parser.add_argument('--channels', type=int, default=1,
                        help='Export wav-files with this number of channels')
    parser.add_argument('--width', type=int, default=2,
                        help='Export wav-files with this sample width (bytes)')

    parser.add_argument('--workers', type=int, default=2,
                        help='Number of workers for loading and re-sampling audio files. Default: 2')
    parser.add_argument('--dry-run', action="store_true",
                        help='Simulates export without writing or creating any file or directory')
    parser.add_argument('--dry-run-fast', action="store_true",
                        help='Simulates export without writing or creating any file or directory. '
                             'In contrast to --dry-run this faster simulation will not load samples.')
    parser.add_argument('--loglevel', type=int, default=20,
                        help='Log level (between 0 and 50) - default: 20')
    parser.add_argument('--no-progress', action="store_true",
                        help='Prevents showing progress indication')
    parser.add_argument('--progress-interval', type=float, default=1.0,
                        help='Progress indication interval in seconds')

    args = parser.parse_args()

    args.buffer = parse_file_size(args.buffer)
    args.sdb_bucket_size = parse_file_size(args.sdb_bucket_size)
    args.dry_run = args.dry_run or args.dry_run_fast
    partition_specs = []
    if args.partition is not None:
        for partition_expr in args.partition:
            parts = partition_expr.split(':')
            if len(parts) != 2:
                fail('Wrong partition specification: "{}"'.format(partition_expr))
            partition_specs.append((float(parts[0]), str(parts[1])))
    # Sort descending, so load_fragments assigns the highest matching partition.
    partition_specs.sort(key=lambda p: p[0], reverse=True)
    args.partition = partition_specs
    return args


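# Example invocation of this tool (hypothetical paths and thresholds; the
# flags are the ones defined in parse_args() above):
#
#     python export.py --catalog data/all.catalog --target-dir data/sets \
#         --partition 90:good --partition 75:fair \
#         --split --split-field speaker --split-seed 42 --force

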
def load_sample(entry):
    catalog_index, catalog_entry = entry
    audio_path, aligned_path = catalog_entry
    with open(aligned_path, 'r', encoding='utf-8') as aligned_file:
        aligned = json.load(aligned_file)
    # Conversion can fail transiently (e.g. on busy storage) - retry once.
    tries = 2
    while tries > 0:
        wav_path, wav_is_temp = ensure_wav_with_format(audio_path, audio_format, tmp_dir=CLI_ARGS.tmp_dir)
        if wav_is_temp:
            time.sleep(1)
        if wav_path is not None:
            if verify_wav_file(wav_path):
                return catalog_index, wav_path, wav_is_temp, aligned
            if wav_is_temp:
                os.remove(wav_path)
        logging.warning('Problem converting "{}" into required format - retrying...'.format(audio_path))
        time.sleep(10)
        tries -= 1
    return catalog_index, None, False, aligned


def load_sample_dry(entry):
    catalog_index, catalog_entry = entry
    audio_path, aligned_path = catalog_entry
    if path.isfile(audio_path):
        logging.info('Would load file "{}"'.format(audio_path))
    else:
        fail('Audio file not found: "{}"'.format(audio_path))
    if path.isfile(aligned_path):
        logging.info('Would load file "{}"'.format(aligned_path))
    else:
        fail('Alignment file not found: "{}"'.format(aligned_path))
    return catalog_index, '', False, []


def load_samples(catalog_entries, fragments):
    catalog_index_wise = engroup(fragments, lambda f: f.catalog_index)
    pool = Pool(CLI_ARGS.workers)
    ls = load_sample_dry if CLI_ARGS.dry_run_fast else load_sample
    indexed_entries = map(lambda ci: (ci, catalog_entries[ci]), catalog_index_wise.keys())
    for catalog_index, wav_path, wav_is_temp, aligned in pool.imap_unordered(ls, indexed_entries):
        if wav_path is None:
            src_audio_path = catalog_entries[catalog_index][0]
            message = 'Unable to convert audio file "{}" to required format - skipping'.format(src_audio_path)
            if CLI_ARGS.skip_damaged:
                logging.warning(message)
                continue
            else:
                raise RuntimeError(message)
        file_fragments = catalog_index_wise[catalog_index]
        file_fragments.sort(key=lambda f: f.alignment_index)
        if CLI_ARGS.dry_run_fast:
            for fragment in file_fragments:
                yield b'', fragment, ''
        else:
            with wave.open(wav_path, 'rb') as source_wav_file:
                # Duration of the WAV file in milliseconds.
                wav_duration = source_wav_file.getnframes() / source_wav_file.getframerate() * 1000
                for fragment in file_fragments:
                    aligned_entry = aligned[fragment.alignment_index]
                    try:
                        start, end = aligned_entry['start'], aligned_entry['end']
                        assert start < end <= wav_duration
                        fragment_audio = extract_audio(source_wav_file, start / 1000.0, end / 1000.0)
                    except Exception as ae:
                        message = 'Problem extracting audio for alignment entry {} in catalog entry {}'\
                            .format(fragment.alignment_index, fragment.catalog_index)
                        if CLI_ARGS.skip_damaged:
                            logging.warning(message)
                            break
                        else:
                            raise RuntimeError(message) from ae
                    yield fragment_audio, fragment, aligned_entry['aligned']
        if wav_is_temp:
            os.remove(wav_path)


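# load_samples() above is a generator yielding (pcm_data, fragment, transcript)
# triples; audio conversion runs in a --workers sized process pool while
# fragment extraction happens in the consuming process.

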
def write_meta(file, catalog_entries, id_plus_fragment_iter, total=None):
    writer = csv.writer(file)
    writer.writerow(['sample', 'split_entity', 'catalog_index', 'source_audio_file', 'aligned_file', 'alignment_index'])
    has_split_entity = CLI_ARGS.split and CLI_ARGS.split_field is not None
    for sample_id, fragment in progress(id_plus_fragment_iter, total=total):
        split_entity = fragment.meta[CLI_ARGS.split_field] if has_split_entity else ''
        source_audio_file, aligned_file = catalog_entries[fragment.catalog_index]
        writer.writerow([sample_id,
                         split_entity,
                         fragment.catalog_index,
                         source_audio_file,
                         aligned_file,
                         fragment.alignment_index])


def write_csvs_and_samples(catalog_entries, lists, fragments):
    created_directories = {}
    tar = None
    if CLI_ARGS.target_tar is not None:
        if CLI_ARGS.dry_run:
            logging.info('Would create tar-file "{}"'.format(CLI_ARGS.target_tar))
        else:
            base_tar = open(CLI_ARGS.target_tar, 'wb', buffering=CLI_ARGS.buffer)
            tar = tarfile.open(fileobj=base_tar, mode='w')

    class TargetFile:
        """Context manager that writes a file either into the target
        directory or into the target tar archive, creating parent
        directories (or tar directory entries) on demand."""
        def __init__(self, data_path, mode):
            self.data_path = data_path
            self.mode = mode
            self.open_file = None

        def __enter__(self):
            parts = self.data_path.split('/')
            dirs = ([CLI_ARGS.target_dir] if CLI_ARGS.target_dir is not None else []) + parts[:-1]
            for i in range(1, len(dirs)):
                vp = '/'.join(dirs[:i + 1])
                if vp not in created_directories:
                    if tar is None:
                        dir_path = path.join(*dirs[:i + 1])
                        if not path.isdir(dir_path):
                            if CLI_ARGS.dry_run:
                                logging.info('Would create directory "{}"'.format(dir_path))
                            else:
                                os.mkdir(dir_path)
                    else:
                        tdir = tarfile.TarInfo(vp)
                        tdir.type = tarfile.DIRTYPE
                        tar.addfile(tdir)
                    created_directories[vp] = True
            if CLI_ARGS.target_tar is None:
                file_path = path.join(CLI_ARGS.target_dir, *self.data_path.split('/'))
                if CLI_ARGS.dry_run:
                    logging.info('Would write file "{}"'.format(file_path))
                    self.open_file = io.BytesIO() if 'b' in self.mode else io.StringIO()
                else:
                    self.open_file = open(file_path, self.mode)
            else:
                # Tar members are buffered in memory first, as their size has
                # to be known before they can be added.
                self.open_file = io.BytesIO() if 'b' in self.mode else io.StringIO()
            return self.open_file

        def __exit__(self, *args):
            if tar is not None:
                if isinstance(self.open_file, io.StringIO):
                    sfile = self.open_file
                    sfile.seek(0)
                    self.open_file = io.BytesIO(sfile.read().encode('utf8'))
                    self.open_file.seek(0, 2)
                    sfile.close()
                tfile = tarfile.TarInfo(self.data_path)
                tfile.size = self.open_file.tell()
                self.open_file.seek(0)
                tar.addfile(tfile, self.open_file)
                # Keep tarfile from accumulating a reference to every member.
                tar.members = []
            if self.open_file is not None:
                self.open_file.close()

    group_lists = {}
    for list_name in lists:
        group_lists[list_name] = []

    for pcm_data, fragment, transcript in progress(load_samples(catalog_entries, fragments),
                                                   desc='Exporting samples', total=len(fragments)):
        group_list = group_lists[fragment.list_name]
        sample_path = '{}/sample-{:010d}.wav'.format(fragment.list_name, len(group_list))
        with TargetFile(sample_path, "wb") as base_wav_file:
            with wave.open(base_wav_file, 'wb') as wav_file:
                write_audio_format_to_wav_file(wav_file)
                wav_file.writeframes(pcm_data)
            file_size = base_wav_file.tell()
        group_list.append((sample_path, file_size, fragment, transcript))

    for list_name, group_list in group_lists.items():
        csv_filename = list_name + '.csv'
        logging.info('Writing "{}"'.format(csv_filename))
        with TargetFile(csv_filename, 'w') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['wav_filename', 'wav_filesize', 'transcript'])
            for rel_path, file_size, fragment, transcript in progress(group_list):
                writer.writerow([rel_path, file_size, transcript])
        if not CLI_ARGS.no_meta:
            meta_filename = list_name + '.meta'
            logging.info('Writing "{}"'.format(meta_filename))
            with TargetFile(meta_filename, 'w') as meta_file:
                path_fragment_list = map(lambda gi: (gi[0], gi[2]), group_list)
                write_meta(meta_file, catalog_entries, path_fragment_list, total=len(group_list))

    if tar is not None:
        tar.close()


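# Note: the CSV column layout written above (wav_filename, wav_filesize,
# transcript) is the one expected by DeepSpeech-style training importers.

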
def write_sdbs(catalog_entries, lists, fragments):
    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.sdb_audio_type]
    sdbs = {}
    for list_name in lists:
        sdb_path = os.path.join(CLI_ARGS.target_dir, list_name + '.sdb')
        if CLI_ARGS.dry_run:
            logging.info('Would create SDB "{}"'.format(sdb_path))
        else:
            logging.info('Creating SDB "{}"'.format(sdb_path))
            sdbs[list_name] = SortingSDBWriter(sdb_path,
                                               audio_type=audio_type,
                                               buffering=CLI_ARGS.buffer,
                                               cache_size=CLI_ARGS.sdb_bucket_size,
                                               buffered_samples=CLI_ARGS.sdb_buffered_samples)

    def to_samples():
        for pcm_data, fragment, transcript in load_samples(catalog_entries, fragments):
            cs = LabeledSample(AUDIO_TYPE_PCM, pcm_data, transcript, audio_format=audio_format)
            cs.meta = fragment
            yield cs

    samples = change_audio_types(to_samples(),
                                 audio_type=audio_type,
                                 processes=CLI_ARGS.sdb_workers) if not CLI_ARGS.dry_run_fast else to_samples()
    set_counter = Counter()
    for sample in progress(samples, desc='Exporting samples', total=len(fragments)):
        list_name = sample.meta.list_name
        if not CLI_ARGS.dry_run:
            set_counter[list_name] += 1
            sdb = sdbs[list_name]
            sdb.add(sample)
    for list_name, sdb in sdbs.items():
        meta_path = os.path.join(CLI_ARGS.target_dir, list_name + '.meta')
        if CLI_ARGS.dry_run:
            if not CLI_ARGS.no_meta:
                logging.info('Would write meta file "{}"'.format(meta_path))
        else:
            sdb_path = os.path.join(CLI_ARGS.target_dir, list_name + '.sdb')
            for _ in progress(sdb.finalize(), desc='Finalizing "{}"'.format(sdb_path), total=set_counter[list_name]):
                pass
            if not CLI_ARGS.no_meta:
                logging.info('Writing "{}"'.format(meta_path))
                with open(meta_path, 'w') as meta_file:
                    write_meta(meta_file, catalog_entries, enumerate(sdb.meta_list), total=len(sdb.meta_list))


def load_plan():
    if CLI_ARGS.plan is not None and os.path.isfile(CLI_ARGS.plan):
        try:
            logging.info('Loading export-plan from "{}"'.format(CLI_ARGS.plan))
            with open(CLI_ARGS.plan, 'rb') as plan_file:
                catalog_entries, lists, fragments = pickle.load(plan_file)
            return True, catalog_entries, lists, fragments
        except pickle.PickleError:
            logging.warning('Unable to load export-plan "{}" - rebuilding'.format(CLI_ARGS.plan))
            os.remove(CLI_ARGS.plan)
    return False, None, None, None


def save_plan(catalog_entries, lists, fragments):
    if CLI_ARGS.plan is not None:
        logging.info('Saving export-plan to "{}"'.format(CLI_ARGS.plan))
        with open(CLI_ARGS.plan, 'wb') as plan_file:
            pickle.dump((catalog_entries, lists, fragments), plan_file)


def main():
    check_targets()
    has_plan, catalog_entries, lists, fragments = load_plan()
    if not has_plan:
        set_assignments = parse_set_assignments()
        catalog_entries = load_catalog()
        fragments = load_fragments(catalog_entries)
        fragments = debias(fragments)
        lists = split(fragments, set_assignments)
        save_plan(catalog_entries, lists, fragments)
    check_overwrite(lists)
    if CLI_ARGS.sdb:
        write_sdbs(catalog_entries, lists, fragments)
    else:
        write_csvs_and_samples(catalog_entries, lists, fragments)


if __name__ == '__main__':
    CLI_ARGS = parse_args()
    audio_format = (CLI_ARGS.rate, CLI_ARGS.channels, CLI_ARGS.width)
    logging.basicConfig(stream=sys.stderr, level=CLI_ARGS.loglevel)
    logging.getLogger('sox').setLevel(logging.ERROR)
    main()