# Mirror of https://github.com/mozilla/DeepSpeech.git
# 303 lines | 12 KiB | Python
import os
|
|
import csv
|
|
import sys
|
|
import math
|
|
import urllib
|
|
import logging
|
|
import argparse
|
|
import subprocess
|
|
from os import path
|
|
from pathlib import Path
|
|
|
|
import swifter
|
|
import pandas as pd
|
|
from sox import Transformer
|
|
|
|
from util.text import validate_label
|
|
|
|
|
|
# Version string reported by the ``--version`` command line flag.
__version__ = "0.1.0"

# Module-level logger used by every class and function of this importer.
_logger = logging.getLogger(__name__)

# Clips lasting longer than this many seconds are rejected as invalid.
MAX_SECS = 10

# Target wav format handed to sox when converting the downloaded mp3 files.
BITDEPTH = 16
N_CHANNELS = 1
SAMPLE_RATE = 16000

# Fractions of the valid samples assigned to the dev and train sets; whatever
# is left over after these two becomes the test set.
DEV_PERCENTAGE = 0.10
TRAIN_PERCENTAGE = 0.80
|
|
|
|
|
|
def parse_args(args):
    """Turn the raw command line into an :obj:`argparse.Namespace`.

    Args:
      args ([str]): Command line parameters as list of strings

    Returns:
      :obj:`argparse.Namespace`: parsed options exposing ``loglevel``,
      ``csv_filename`` and ``target_dir``
    """
    cli = argparse.ArgumentParser(
        description="Imports GramVaani data for Deep Speech"
    )

    version_banner = "GramVaaniImporter {ver}".format(ver=__version__)
    cli.add_argument("--version", action="version", version=version_banner)

    # Both verbosity switches store a constant into the same ``loglevel``
    # destination; they differ only in the level they store.
    verbosity_switches = (
        ("-v", "--verbose", "set loglevel to INFO", logging.INFO),
        ("-vv", "--very-verbose", "set loglevel to DEBUG", logging.DEBUG),
    )
    for short_flag, long_flag, help_text, level in verbosity_switches:
        cli.add_argument(
            short_flag,
            long_flag,
            action="store_const",
            required=False,
            help=help_text,
            dest="loglevel",
            const=level,
        )

    # The two mandatory path options share the same shape as well.
    mandatory_paths = (
        ("-c", "--csv_filename", "Path to the GramVaani csv", "csv_filename"),
        (
            "-t",
            "--target_dir",
            "Directory in which to save the importer GramVaani data",
            "target_dir",
        ),
    )
    for short_flag, long_flag, help_text, destination in mandatory_paths:
        cli.add_argument(
            short_flag,
            long_flag,
            required=True,
            help=help_text,
            dest=destination,
        )

    return cli.parse_args(args)
|
|
|
|
def setup_logging(level):
    """Configure basic logging so that records are written to stdout.

    Args:
      level (int): minimum log level for emitting messages
    """
    logging.basicConfig(
        level=level,
        stream=sys.stdout,
        # timestamp, severity, logger name, then the message itself
        format="[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
|
|
|
|
class GramVaaniCSV:
    """In-memory view of a GramVaani csv file.

    Args:
      csv_filename (str): Path to the GramVaani csv

    Attributes:
        data (:class:`pandas.DataFrame`): the ``audio_url``, ``transcript``
            and ``audio_length`` columns of the csv, without the rows that
            have a missing value in any of them
    """

    def __init__(self, csv_filename):
        self.data = self._parse_csv(csv_filename)

    def _parse_csv(self, csv_filename):
        """Read the csv at ``csv_filename`` into a `pandas.DataFrame`.

        The first line of the file is skipped and the columns are named
        explicitly; only the three columns the importer needs are kept.

        Args:
          csv_filename (str): Path to the GramVaani csv

        Returns:
          :class:`pandas.DataFrame`: the parsed rows with incomplete rows dropped
        """
        csv_path = os.path.abspath(csv_filename)
        _logger.info("Parsing csv file...%s", csv_path)

        all_columns = [
            "piece_id",
            "audio_url",
            "transcript_labelled",
            "transcript",
            "labels",
            "content_filename",
            "audio_length",
            "user_id",
        ]
        kept_columns = ["audio_url", "transcript", "audio_length"]

        frame = pd.read_csv(
            csv_path,
            names=all_columns,
            usecols=kept_columns,
            skiprows=[0],
            engine="python",
            encoding="utf-8",
            quotechar='"',
            quoting=csv.QUOTE_ALL,
        )
        # Rows lacking any of the kept values cannot be imported.
        frame.dropna(inplace=True)

        _logger.info("Parsed %d lines csv file.", len(frame))
        return frame
|
|
|
|
class GramVaaniDownloader:
    """GramVaaniDownloader downloads a GramVaani dataset.

    Args:
      gram_vaani_csv (GramVaaniCSV): A GramVaaniCSV representing the data to download
      target_dir (str): The path to download the data to

    Attributes:
        target_dir (str): The path to download the data to
        data (:class:`pandas.DataFrame`): `pandas.DataFrame` Containing the GramVaani csv data
    """

    def __init__(self, gram_vaani_csv, target_dir):
        self.target_dir = target_dir
        self.data = gram_vaani_csv.data

    def download(self):
        """Downloads the data associated with this instance

        Return:
          mp3_directory (os.path): The directory into which the associated mp3's were downloaded
        """
        mp3_directory = self._pre_download()
        # Each row is handed over as a raw array of
        # (audio_url, transcript, audio_length) and unpacked positionally.
        self.data.swifter.apply(
            func=lambda arg: self._download(*arg, mp3_directory),
            axis=1,
            raw=True,
        )
        return mp3_directory

    def _pre_download(self):
        """Make sure the target directory and its ``mp3`` sub-directory exist.

        Return:
          mp3_directory (os.path): ``<target_dir>/mp3``
        """
        mp3_directory = path.join(self.target_dir, "mp3")
        for directory in (self.target_dir, mp3_directory):
            if not path.exists(directory):
                _logger.info("Creating directory...%s", directory)
                # makedirs also creates missing parents, and exist_ok keeps
                # this safe if the directory appears between check and create.
                os.makedirs(directory, exist_ok=True)
        return mp3_directory

    def _download(self, audio_url, transcript, audio_length, mp3_directory):
        """Download one mp3 into ``mp3_directory`` unless it is already there.

        Args:
          audio_url (str): URL of the mp3; its basename becomes the local file name
          transcript: unused here, present because whole csv rows are unpacked
          audio_length: unused here, present because whole csv rows are unpacked
          mp3_directory (os.path): directory to store the mp3 in
        """
        # ``import urllib`` alone does not load the ``request`` submodule, so
        # import it explicitly before using ``urllib.request``.
        import urllib.request

        # Skip a csv header row that was read as data.
        if audio_url == "audio_url":
            return

        mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
        if path.exists(mp3_filename):
            _logger.debug("Already downloaded mp3 file...%s", audio_url)
            return

        _logger.debug("Downloading mp3 file...%s", audio_url)
        urllib.request.urlretrieve(audio_url, mp3_filename)
|
|
|
|
class GramVaaniConverter:
    """GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.

    Args:
      target_dir (str): The directory under which the ``wav`` directory is created
      mp3_directory (os.path): The path containing the GramVaani mp3's

    Attributes:
        target_dir (str): The target directory passed as a command line argument
        mp3_directory (pathlib.Path): The path containing the GramVaani mp3's
    """

    def __init__(self, target_dir, mp3_directory):
        self.target_dir = target_dir
        self.mp3_directory = Path(mp3_directory)

    def convert(self):
        """Converts the mp3's associated with this instance to wav's

        Every ``*.mp3`` found recursively below the mp3 directory is converted
        into ``<target_dir>/wav/<name>.wav``; existing wav files are kept.

        Return:
          wav_directory (os.path): The directory into which the associated wav's were written
        """
        wav_directory = self._pre_convert()
        for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
            wav_filename = path.join(wav_directory, mp3_filename.stem + ".wav")
            if path.exists(wav_filename):
                _logger.debug(
                    "Already converted mp3 file %s to wav file %s",
                    mp3_filename,
                    wav_filename,
                )
                continue

            _logger.debug(
                "Converting mp3 file %s to wav file %s", mp3_filename, wav_filename
            )
            transformer = Transformer()
            transformer.convert(
                samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
            )
            transformer.build(str(mp3_filename), str(wav_filename))
        return wav_directory

    def _pre_convert(self):
        """Make sure the target directory and its ``wav`` sub-directory exist.

        Return:
          wav_directory (os.path): ``<target_dir>/wav``
        """
        wav_directory = path.join(self.target_dir, "wav")
        for directory in (self.target_dir, wav_directory):
            if not path.exists(directory):
                _logger.info("Creating directory...%s", directory)
                # makedirs also creates missing parents, and exist_ok keeps
                # this safe if the directory appears between check and create.
                os.makedirs(directory, exist_ok=True)
        return wav_directory
|
|
|
|
class GramVaaniDataSets:
    """GramVaaniDataSets builds and saves the train/dev/test csv's of a GramVaani dataset.

    Args:
      target_dir (str): The directory containing the ``wav`` directory; the csv's are saved here
      wav_directory (os.path): The directory containing the converted wav's
      gram_vaani_csv (GramVaaniCSV): A GramVaaniCSV representing the imported data

    Attributes:
        raw (:class:`pandas.DataFrame`): one row per csv entry (file name, file size, transcript)
        valid (:class:`pandas.DataFrame`): the shuffled rows of ``raw`` that passed validation
        train (:class:`pandas.DataFrame`): training split of ``valid``
        dev (:class:`pandas.DataFrame`): dev split of ``valid``
        test (:class:`pandas.DataFrame`): test split of ``valid``
    """

    # Column layout shared by every data frame managed by this class.
    _COLUMNS = ["wav_filename", "wav_filesize", "transcript"]

    def __init__(self, target_dir, wav_directory, gram_vaani_csv):
        self.target_dir = target_dir
        self.wav_directory = wav_directory
        self.csv_data = gram_vaani_csv.data
        self.raw = pd.DataFrame(columns=self._COLUMNS)
        self.valid = pd.DataFrame(columns=self._COLUMNS)
        self.train = pd.DataFrame(columns=self._COLUMNS)
        self.dev = pd.DataFrame(columns=self._COLUMNS)
        self.test = pd.DataFrame(columns=self._COLUMNS)

    def create(self):
        """Populate ``raw``, ``valid``, ``train``, ``dev`` and ``test``."""
        self._convert_csv_data_to_raw_data()
        # Renumber the rows 0..n-1 so they line up with the boolean mask below.
        self.raw.index = range(len(self.raw.index))
        self.valid = self.raw[self._is_valid_raw_rows()]
        # Shuffle before splitting.
        self.valid = self.valid.sample(frac=1).reset_index(drop=True)
        train_size, dev_size, test_size = self._calculate_data_set_sizes()
        self._split_valid_data(train_size, dev_size, test_size)

    def _split_valid_data(self, train_size, dev_size, test_size):
        """Cut ``valid`` into consecutive, non-overlapping train/dev/test parts.

        Positional ``iloc`` slices are used on purpose: label based ``loc``
        slices include their end label, which would put the boundary rows
        into two data sets at once.

        Args:
          train_size (int): number of rows for the train set
          dev_size (int): number of rows for the dev set
          test_size (int): number of rows for the test set
        """
        dev_end = train_size + dev_size
        self.train = self.valid.iloc[:train_size]
        self.dev = self.valid.iloc[train_size:dev_end]
        self.test = self.valid.iloc[dev_end:dev_end + test_size]

    def _convert_csv_data_to_raw_data(self):
        """Fill ``raw`` by converting every csv row with the row-wise helper."""
        self.raw[["wav_filename","wav_filesize","transcript"]] = self.csv_data[
            ["audio_url","transcript","audio_length"]
        ].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True)

    def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
        """Convert one csv row into (relative wav path, wav size, clean transcript).

        Args:
          audio_url (str): URL of the original mp3; its basename names the wav
          transcript (str): transcript as found in the csv
          audio_length: unused, present because whole csv rows are unpacked

        Returns:
          :class:`pandas.Series`: ``[wav_filename, wav_filesize, transcript]``,
          where the transcript is ``""`` if ``validate_label`` rejected it
        """
        # A csv header row that was read as data maps onto the output header.
        if audio_url == "audio_url":
            return pd.Series(["wav_filename", "wav_filesize", "transcript"])

        wav_name = os.path.splitext(os.path.basename(audio_url))[0] + ".wav"
        wav_relative_filename = path.join("wav", wav_name)
        wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))

        transcript = validate_label(transcript)
        if transcript is None:
            # An empty transcript marks the row as invalid later on.
            transcript = ""
        return pd.Series([wav_relative_filename, wav_filesize, transcript])

    def _is_valid_raw_rows(self):
        """Return a boolean mask over ``raw``: transcript AND audio are valid."""
        transcript_flags = self._is_valid_raw_transcripts()
        wav_frame_flags = self._is_valid_raw_wav_frames()
        return pd.Series(
            [
                bool(transcript_ok and frames_ok)
                for transcript_ok, frames_ok in zip(transcript_flags, wav_frame_flags)
            ],
            dtype=bool,
        )

    def _is_valid_raw_transcripts(self):
        """Return a boolean mask over ``raw`` that is True for non-empty transcripts."""
        return pd.Series(
            [bool(transcript) for transcript in self.raw.transcript], dtype=bool
        )

    def _is_valid_raw_wav_frames(self):
        """Return a boolean mask over ``raw`` that is True for usable wav files.

        The sample count of every wav file is obtained from ``soxi -s``.
        """
        transcripts = [str(transcript) for transcript in self.raw.transcript]
        wav_filepaths = [
            path.join(self.target_dir, str(wav_filename))
            for wav_filename in self.raw.wav_filename
        ]
        wav_frames = [
            int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT))
            for wav_filepath in wav_filepaths
        ]
        return pd.Series(
            [
                self._is_wav_frame_valid(wav_frame, transcript)
                for wav_frame, transcript in zip(wav_frames, transcripts)
            ],
            dtype=bool,
        )

    def _is_wav_frame_valid(self, wav_frame, transcript):
        """Check that a clip is long enough for its transcript and not too long.

        Args:
          wav_frame (int): number of samples in the wav file
          transcript (str): transcript belonging to the wav file

        Returns:
          bool: False if the clip has fewer 20 ms steps than the transcript has
          characters, or lasts longer than ``MAX_SECS`` seconds; True otherwise
        """
        # Duration in ms divided by 10 and then by 2, i.e. the number of 20 ms
        # steps (presumably the feature steps available to the model -- TODO confirm).
        if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)):
            return False
        return not wav_frame/SAMPLE_RATE > MAX_SECS

    def _calculate_data_set_sizes(self):
        """Return ``(train_size, dev_size, test_size)`` for the rows in ``valid``.

        Train and dev sizes are the floored percentages of the total; the test
        set receives all remaining rows.
        """
        total_size = len(self.valid)
        dev_size = math.floor(total_size * DEV_PERCENTAGE)
        train_size = math.floor(total_size * TRAIN_PERCENTAGE)
        test_size = total_size - (train_size + dev_size)
        return (train_size, dev_size, test_size)

    def save(self):
        """Write ``train.csv``, ``dev.csv`` and ``test.csv`` into the target directory."""
        for dataset in ("train", "dev", "test"):
            self._save(dataset)

    def _save(self, dataset):
        """Write the data frame stored in attribute ``dataset`` to ``<target_dir>/<dataset>.csv``.

        Args:
          dataset (str): name of the attribute holding the data frame to save
        """
        dataset_path = os.path.join(self.target_dir, dataset + ".csv")
        dataframe = getattr(self, dataset)
        dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)
|
|
|
|
def main(args):
    """Run the whole import: parse options, download, convert, build the csv's.

    Args:
      args ([str]): command line parameter list
    """
    options = parse_args(args)
    setup_logging(options.loglevel)
    _logger.info("Starting GramVaani importer...")

    _logger.info("Starting loading GramVaani csv...")
    gram_vaani_csv = GramVaaniCSV(options.csv_filename)

    _logger.info("Starting downloading GramVaani mp3's...")
    mp3_directory = GramVaaniDownloader(gram_vaani_csv, options.target_dir).download()

    _logger.info("Starting converting GramVaani mp3's to wav's...")
    wav_directory = GramVaaniConverter(options.target_dir, mp3_directory).convert()

    # Build the train/dev/test splits from the converted audio and save them.
    data_sets = GramVaaniDataSets(options.target_dir, wav_directory, gram_vaani_csv)
    data_sets.create()
    data_sets.save()

    _logger.info("Finished GramVaani importer...")
|
|
|
|
# Run the importer only when this file is executed as a script, so that
# importing the module (e.g. to reuse its classes) does not start an import.
if __name__ == "__main__":
    main(sys.argv[1:])
|