Tilman Kamp 2020-02-07 18:29:30 +01:00
Parent fba9c8971d
Commit c3e879a642
4 changed files with 405 additions and 2 deletions

View file

@@ -16,7 +16,10 @@ from tqdm import tqdm
from datetime import timedelta
from collections import Counter
from multiprocessing import Pool
from audio import DEFAULT_FORMAT, ensure_wav_with_format, extract_audio, write_audio_format_to_wav_file
from audio import DEFAULT_FORMAT, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS,\
ensure_wav_with_format, extract_audio, convert_samples, write_audio_format_to_wav_file
from sdb import SortingSDBWriter, CollectionSample
from utils import MEGABYTE, parse_file_size
audio_format = DEFAULT_FORMAT
unknown = '<unknown>'
@@ -114,7 +117,13 @@ def main(args):
help='Existing target directory for storing generated sets (files and directories)')
parser.add_argument('--target-tar', type=str, required=False,
help='Target tar-file for storing generated sets (files and directories)')
parser.add_argument('--buffer', type=int, default=2 << 23,
parser.add_argument('--sdb', action="store_true",
help='Writes Sample DBs instead of CSV and .wav files (requires --target-dir)')
parser.add_argument('--sdb-bucket-size', default='1GB',
help='Memory bucket size for external sorting of SDBs')
parser.add_argument('--sdb-worker-factor', type=float, default=1.0,
help='CPU core factor for the number of Opus encoding workers (0 -> 1 worker)')
parser.add_argument('--buffer', default='1MB',
help='Buffer size for writing files (default: 1MB)')
parser.add_argument('--force', action="store_true",
help='Overwrite existing files')
@@ -143,6 +152,9 @@ def main(args):
args = parser.parse_args()
args.buffer = parse_file_size(args.buffer)
args.sdb_bucket_size = parse_file_size(args.sdb_bucket_size)
logging.basicConfig(stream=sys.stderr, level=args.loglevel if args.loglevel else 20)
logging.getLogger('sox').setLevel(logging.ERROR)
@@ -171,6 +183,8 @@ def main(args):
elif args.target_dir is not None:
target_dir = check_path(args.target_dir, fs_type='directory')
elif args.target_tar is not None:
if args.sdb:
fail('Option --sdb not supported for --target-tar output. Use --target-dir instead.')
target_tar = path.abspath(args.target_tar)
if path.isfile(target_tar):
if not args.force:
@@ -364,6 +378,25 @@ def main(args):
for fragment in file_fragments:
yield b'', fragment
if args.sdb:
for list_name in lists.keys():
sdb_path = os.path.join(target_dir, list_name + '.sdb')
lists[list_name] = SortingSDBWriter(sdb_path, buffering=args.buffer, cache_size=args.sdb_bucket_size)
def to_samples():
for s, f in list_fragments():
yield CollectionSample(f['list-name'], AUDIO_TYPE_PCM, s, f['aligned'], audio_format=audio_format)
sdb_processes = max(1, int(args.sdb_worker_factor * os.cpu_count()))
for sample in progress(convert_samples(to_samples(), audio_type=AUDIO_TYPE_OPUS, processes=sdb_processes),
desc='Exporting samples', total=len(fragments)):
list_name = sample.sample_id
sdb = lists[list_name]
sdb.add(sample)
for sdb in lists.values():
sdb.close()
return
created_directories = {}
tar = None
if target_tar is not None:

align/sdb.py (Normal file, 344 additions)
View file

@@ -0,0 +1,344 @@
# -*- coding: utf-8 -*-
import os
import csv
import json
from pathlib import Path
from functools import partial
from utils import MEGABYTE, GIGABYTE
from audio import Sample, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, DEFAULT_FORMAT, LOADABLE_FILE_FORMATS
FILE_EXTENSION_CSV = '.csv'
FILE_EXTENSION_SDB = '.sdb'
BIG_ENDIAN = 'big'
INT_SIZE = 4
BIGINT_SIZE = 2 * INT_SIZE
MAGIC = b'SAMPLEDB'
BUFFER_SIZE = 1 * MEGABYTE
CACHE_SIZE = 1 * GIGABYTE
SCHEMA_KEY = 'schema'
CONTENT_KEY = 'content'
MIME_TYPE_KEY = 'mime-type'
MIME_TYPE_TEXT = 'text/plain'
CONTENT_TYPE_SPEECH = 'speech'
CONTENT_TYPE_TRANSCRIPT = 'transcript'
META = {
SCHEMA_KEY: [
{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: AUDIO_TYPE_OPUS},
{CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
]
}
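
The META dict above is the schema header written into every SDB file: DirectSDBWriter (below) serializes it as JSON and stores it right after the MAGIC bytes, prefixed with its length as an 8-byte big-endian integer. A small framing sketch, not part of the commit itself:

# Illustration only: how the file header is framed on disk.
meta_data = json.dumps(META).encode()
header = MAGIC + len(meta_data).to_bytes(BIGINT_SIZE, BIG_ENDIAN) + meta_data
print(header[:len(MAGIC)])   # b'SAMPLEDB'
print(len(meta_data))        # size of the JSON schema blob that follows
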
COLUMN_FILENAME = 'wav_filename'
COLUMN_FILESIZE = 'wav_filesize'
COLUMN_TRANSCRIPT = 'transcript'
class CollectionSample(Sample):
def __init__(self, sample_id, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT):
super().__init__(audio_type, raw_data, audio_format=audio_format)
self.sample_id = sample_id
self.transcript = transcript
class DirectSDBWriter:
def __init__(self, sdb_filename, buffering=BUFFER_SIZE):
self.sdb_filename = sdb_filename
self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
self.offsets = []
self.num_samples = 0
self.sdb_file.write(MAGIC)
meta_data = json.dumps(META).encode()
self.write_big_int(len(meta_data))
self.sdb_file.write(meta_data)
self.offset_samples = self.sdb_file.tell()
self.sdb_file.seek(2 * BIGINT_SIZE, 1)
def write_int(self, n):
return self.sdb_file.write(n.to_bytes(INT_SIZE, BIG_ENDIAN))
def write_big_int(self, n):
return self.sdb_file.write(n.to_bytes(BIGINT_SIZE, BIG_ENDIAN))
def __enter__(self):
return self
def add(self, sample):
def to_bytes(n):
return n.to_bytes(INT_SIZE, BIG_ENDIAN)
sample.convert(AUDIO_TYPE_OPUS)
opus = sample.audio.getbuffer()
opus_len = to_bytes(len(opus))
transcript = sample.transcript.encode()
transcript_len = to_bytes(len(transcript))
entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
self.offsets.append(self.sdb_file.tell())
self.sdb_file.write(buffer)
self.num_samples += 1
def close(self):
if self.sdb_file is None:
return
offset_index = self.sdb_file.tell()
self.sdb_file.seek(self.offset_samples)
self.write_big_int(offset_index - self.offset_samples - BIGINT_SIZE)
self.write_big_int(self.num_samples)
self.sdb_file.seek(offset_index + BIGINT_SIZE)
self.write_big_int(self.num_samples)
for offset in self.offsets:
self.write_big_int(offset)
offset_end = self.sdb_file.tell()
self.sdb_file.seek(offset_index)
self.write_big_int(offset_end - offset_index - BIGINT_SIZE)
self.sdb_file.close()
def __len__(self):
return len(self.offsets)
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
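
DirectSDBWriter streams samples straight into an .sdb file: the constructor writes the MAGIC bytes and the JSON schema and reserves two big-endian length fields, each add() appends a length-prefixed Opus/transcript record, and close() backfills the chunk sizes and writes the offset index. A minimal usage sketch, not part of the commit itself; the output path and WAV inputs are assumptions, and Opus encoding relies on the audio module (opuslib):

# Illustration only: writing two samples to a new SDB.
with DirectSDBWriter('/tmp/demo.sdb') as writer:
    for i, (wav_path, text) in enumerate([('a.wav', 'hello'), ('b.wav', 'world')]):
        with open(wav_path, 'rb') as wav_file:
            # add() converts the sample to Opus before appending it
            writer.add(CollectionSample('demo:{}'.format(i), AUDIO_TYPE_WAV, wav_file.read(), text))
# __exit__ calls close(), which backfills the sample chunk size and the offset index.
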
class SortingSDBWriter: # pylint: disable=too-many-instance-attributes
def __init__(self, sdb_filename, tmp_sdb_filename=None, cache_size=CACHE_SIZE, buffering=BUFFER_SIZE):
self.sdb_filename = sdb_filename
self.buffering = buffering
self.tmp_sdb_filename = (sdb_filename + '.tmp') if tmp_sdb_filename is None else tmp_sdb_filename
self.tmp_sdb = DirectSDBWriter(self.tmp_sdb_filename, buffering=buffering)
self.cache_size = cache_size
self.buckets = []
self.bucket = []
self.bucket_offset = 0
self.bucket_size = 0
self.overall_size = 0
def __enter__(self):
return self
def finish_bucket(self):
if len(self.bucket) == 0:
return
self.bucket.sort(key=lambda s: s.duration)
for sample in self.bucket:
self.tmp_sdb.add(sample)
self.buckets.append((self.bucket_offset, len(self.bucket)))
self.bucket_offset += len(self.bucket)
self.bucket = []
self.overall_size += self.bucket_size
self.bucket_size = 0
def add(self, sample):
sample.convert(AUDIO_TYPE_OPUS)
self.bucket.append(sample)
self.bucket_size += len(sample.audio.getbuffer())
if self.bucket_size > self.cache_size:
self.finish_bucket()
def close(self):
if self.tmp_sdb is None:
return
self.finish_bucket()
num_samples = len(self.tmp_sdb)
self.tmp_sdb.close()
avg_sample_size = self.overall_size / num_samples
max_cached_samples = self.cache_size / avg_sample_size
buffer_size = max(1, int(max_cached_samples / len(self.buckets)))
sdb_reader = SDB(self.tmp_sdb_filename, buffering=self.buffering)
def buffered_view(start, end):
buffer = []
current_offset = start
while current_offset < end:
while len(buffer) < buffer_size and current_offset < end:
buffer.insert(0, sdb_reader[current_offset])
current_offset += 1
while len(buffer) > 0:
yield buffer.pop(-1)
bucket_views = list(map(lambda b: buffered_view(b[0], b[0] + b[1]), self.buckets))
interleaved = Interleaved(*bucket_views)
with DirectSDBWriter(self.sdb_filename, buffering=self.buffering) as sdb_writer:
for sample in interleaved:
sdb_writer.add(sample)
os.unlink(self.tmp_sdb_filename)
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
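
SortingSDBWriter is an external sort by sample duration: add() converts each sample to Opus and collects it in an in-memory bucket of at most cache_size bytes, finish_bucket() sorts a full bucket and spills it to a temporary SDB, and close() merges the sorted buckets with Interleaved into the final file before deleting the temporary one. A usage sketch, not part of the commit itself; the input CSV name is an assumption:

# Illustration only: building a duration-sorted SDB from an existing CSV set.
with SortingSDBWriter('/tmp/train.sdb', cache_size=512 * MEGABYTE) as writer:
    for sample in CSV('train.csv'):
        writer.add(sample)   # full buckets are spilled to /tmp/train.sdb.tmp
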
class SDB: # pylint: disable=too-many-instance-attributes
def __init__(self, sdb_filename, buffering=BUFFER_SIZE):
self.meta = {}
self.schema = []
self.offsets = []
self.sdb_filename = sdb_filename
self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
if self.sdb_file.read(len(MAGIC)) != MAGIC:
raise RuntimeError('No Sample Database')
meta_chunk_len = self.read_big_int()
self.meta = json.loads(self.sdb_file.read(meta_chunk_len))
if SCHEMA_KEY not in self.meta:
raise RuntimeError('Missing schema')
self.schema = self.meta[SCHEMA_KEY]
self.speech_index = self.find_column(content=CONTENT_TYPE_SPEECH)
if self.speech_index == -1:
raise RuntimeError('No speech data (missing in schema)')
self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
if self.audio_type not in LOADABLE_FILE_FORMATS:
raise RuntimeError('Unsupported audio format: {}'.format(self.audio_type))
self.transcript_index = self.find_column(content=CONTENT_TYPE_TRANSCRIPT)
if self.transcript_index == -1:
raise RuntimeError('No transcript data (missing in schema)')
text_type = self.schema[self.transcript_index][MIME_TYPE_KEY]
if text_type != MIME_TYPE_TEXT:
raise RuntimeError('Unsupported text type: {}'.format(text_type))
sample_chunk_len = self.read_big_int()
self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
num_samples = self.read_big_int()
for _ in range(num_samples):
self.offsets.append(self.read_big_int())
def read_int(self):
return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
def read_big_int(self):
return int.from_bytes(self.sdb_file.read(BIGINT_SIZE), BIG_ENDIAN)
def find_column(self, content=None, mime_type=None):
criteria = []
if content is not None:
criteria.append((CONTENT_KEY, content))
if mime_type is not None:
criteria.append((MIME_TYPE_KEY, mime_type))
if len(criteria) == 0:
raise ValueError('At least one of "content" or "mime-type" has to be provided')
for index, column in enumerate(self.schema):
matched = 0
for field, value in criteria:
if column[field] == value:
matched += 1
if matched == len(criteria):
return index
return -1
def read_row(self, row_index, *columns):
columns = list(columns)
column_data = [None] * len(columns)
found = 0
if not 0 <= row_index < len(self.offsets):
raise ValueError('Wrong sample index: {} - has to be between 0 and {}'
.format(row_index, len(self.offsets) - 1))
self.sdb_file.seek(self.offsets[row_index] + INT_SIZE)
for index in range(len(self.schema)):
chunk_len = self.read_int()
if index in columns:
column_data[columns.index(index)] = self.sdb_file.read(chunk_len)
found += 1
if found == len(columns):
return tuple(column_data)
else:
self.sdb_file.seek(chunk_len, 1)
return tuple(column_data)
def __getitem__(self, i):
audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
transcript = transcript.decode()
sample_id = self.sdb_filename + ':' + str(i)
return CollectionSample(sample_id, self.audio_type, audio_data, transcript)
def __iter__(self):
for i in range(len(self.offsets)):
yield self[i]
def __len__(self):
return len(self.offsets)
def close(self):
if self.sdb_file is not None:
self.sdb_file.close()
def __del__(self):
self.close()
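
The SDB class gives random access to an existing database: the constructor validates the MAGIC bytes and the schema, loads the per-sample offset index, and __getitem__ seeks to one record and returns a CollectionSample with Opus audio and the decoded transcript. A short reading sketch, not part of the commit itself; the file name is an assumption:

# Illustration only: random access and iteration over an SDB file.
db = SDB('/tmp/train.sdb')
print(len(db), 'samples')
first = db[0]                       # sample_id has the form '<filename>:<index>'
print(first.sample_id, first.transcript)
for sample in db:                   # samples come back in on-disk order
    pass
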
class CSV:
def __init__(self, csv_filename):
self.csv_filename = csv_filename
self.rows = []
csv_dir = Path(csv_filename).parent
with open(csv_filename, 'r') as csv_file:
reader = csv.DictReader(csv_file)
for row in reader:
wav_filename = Path(row[COLUMN_FILENAME])
if not wav_filename.is_absolute():
wav_filename = csv_dir / wav_filename
self.rows.append((str(wav_filename), int(row[COLUMN_FILESIZE]), row[COLUMN_TRANSCRIPT]))
self.rows.sort(key=lambda r: r[1])
def __getitem__(self, i):
wav_filename, _, transcript = self.rows[i]
with open(wav_filename, 'rb') as wav_file:
return CollectionSample(wav_filename, AUDIO_TYPE_WAV, wav_file.read(), transcript)
def __iter__(self):
for i in range(len(self.rows)):
yield self[i]
def __len__(self):
return len(self.rows)
class Interleaved:
def __init__(self, *cols):
self.cols = cols
def __iter__(self):
firsts = []
for index, col in enumerate(self.cols):
try:
it = iter(col)
except TypeError:
it = col
try:
first = next(it)
firsts.append((index, it, first))
except StopIteration:
continue
while len(firsts) > 0:
firsts.sort(key=lambda it_first: it_first[2].duration)
index, it, first = firsts.pop(0)
yield first
try:
first = next(it)
except StopIteration:
continue
firsts.append((index, it, first))
def __len__(self):
return sum(map(len, self.cols))
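
Interleaved is a k-way merge over inputs that are each already sorted by duration: it keeps the head element of every input and repeatedly yields the one with the smallest duration. A toy sketch, not part of the commit itself; the namedtuple only stands in for real sample objects:

# Illustration only: merging two duration-sorted lists.
from collections import namedtuple
Clip = namedtuple('Clip', ['duration'])
merged = Interleaved([Clip(1.0), Clip(3.0)], [Clip(2.0), Clip(4.0)])
print([clip.duration for clip in merged])   # [1.0, 2.0, 3.0, 4.0]
print(len(merged))                          # 4
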
def samples_from_file(filename, buffering=BUFFER_SIZE):
ext = os.path.splitext(filename)[1].lower()
if ext == FILE_EXTENSION_SDB:
return SDB(filename, buffering=buffering)
if ext == FILE_EXTENSION_CSV:
return CSV(filename)
raise ValueError('Unknown file type: "{}"'.format(ext))
def samples_from_files(filenames, buffering=BUFFER_SIZE):
if len(filenames) == 0:
raise ValueError('No files')
if len(filenames) == 1:
return samples_from_file(filenames[0], buffering=buffering)
cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
return Interleaved(*cols)
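
samples_from_files is the convenience entry point: a single .csv or .sdb path yields its samples directly, while several paths are wrapped in Interleaved so that duration-sorted inputs stay roughly sorted when mixed. A sketch with hypothetical paths, not part of the commit itself:

# Illustration only: mixing set files of both formats into one sample stream.
for sample in samples_from_files(['old-set.csv', 'new-set.sdb']):
    print(sample.sample_id, len(sample.transcript))
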

View file

@@ -1,6 +1,31 @@
import os
import time
from multiprocessing.dummy import Pool as ThreadPool
KILO = 1024
KILOBYTE = 1 * KILO
MEGABYTE = KILO * KILOBYTE
GIGABYTE = KILO * MEGABYTE
TERABYTE = KILO * GIGABYTE
SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE}
def parse_file_size(file_size):
file_size = file_size.lower().strip()
if len(file_size) == 0:
return 0
n = int(keep_only_digits(file_size))
if file_size[-1] == 'b':
file_size = file_size[:-1]
e = file_size[-1]
return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
def keep_only_digits(txt):
return ''.join(filter(str.isdigit, txt))
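
parse_file_size turns human-readable sizes into byte counts using the 1024-based prefix table above; a trailing 'b' is optional and bare numbers pass through unchanged. A few worked examples, not part of the commit itself:

# Illustration only: parse_file_size with and without unit suffixes.
assert parse_file_size('1GB') == 1073741824   # 1 * GIGABYTE
assert parse_file_size('512m') == 536870912   # 512 * MEGABYTE
assert parse_file_size('4096') == 4096        # plain byte counts pass through
assert parse_file_size('') == 0               # empty input is treated as zero
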
def circulate(items, center=None):
count = len(list(items))
if count > 0:

View file

@@ -5,3 +5,4 @@ webrtcvad
tqdm
textdistance
pydub
opuslib