# DeepSpeech/bin/import_gram_vaani.py
#
# 303 lines
# 12 KiB
# Python

import argparse
import csv
import logging
import math
import os
import subprocess
import sys
import urllib
import urllib.request
from os import path
from pathlib import Path

import pandas as pd
import swifter
from sox import Transformer

from util.text import validate_label
__version__ = "0.1.0"
_logger = logging.getLogger(__name__)

# Output audio format handed to sox when converting the mp3's to wav's;
# SAMPLE_RATE also turns sample counts into seconds when filtering clips.
SAMPLE_RATE = 16000
N_CHANNELS = 1
BITDEPTH = 16

# Clips longer than this many seconds are excluded from the data sets.
MAX_SECS = 10

# Fractions of the valid rows placed in train and dev; test gets the remainder.
TRAIN_PERCENTAGE = 0.80
DEV_PERCENTAGE = 0.10


def parse_args(args):
    """Parse command line parameters.

    Args:
        args ([str]): Command line parameters as list of strings

    Returns:
        :obj:`argparse.Namespace`: command line parameters namespace with the
        attributes ``loglevel`` (``None`` unless ``-v`` / ``-vv`` is given),
        ``csv_filename`` and ``target_dir``
    """
    parser = argparse.ArgumentParser(
        description="Imports GramVaani data for Deep Speech"
    )
    parser.add_argument(
        "--version",
        action="version",
        version="GramVaaniImporter {ver}".format(ver=__version__),
    )
    # -v and -vv share dest="loglevel"; with no default, loglevel stays None
    # when neither flag is passed.
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_const",
        required=False,
        help="set loglevel to INFO",
        dest="loglevel",
        const=logging.INFO,
    )
    parser.add_argument(
        "-vv",
        "--very-verbose",
        action="store_const",
        required=False,
        help="set loglevel to DEBUG",
        dest="loglevel",
        const=logging.DEBUG,
    )
    parser.add_argument(
        "-c",
        "--csv_filename",
        required=True,
        help="Path to the GramVaani csv",
        dest="csv_filename",
    )
    parser.add_argument(
        "-t",
        "--target_dir",
        required=True,
        help="Directory in which to save the imported GramVaani data",
        dest="target_dir",
    )
    return parser.parse_args(args)


def setup_logging(level):
    """Configure basic logging of timestamped records to stdout.

    Args:
        level (int): minimum log level for emitting messages; ``None`` leaves
            the root logger's level unchanged. As with any
            ``logging.basicConfig`` call, nothing happens if the root logger
            already has handlers.
    """
    # Named log_format so the builtin ``format`` is not shadowed.
    log_format = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
    logging.basicConfig(
        level=level, stream=sys.stdout, format=log_format, datefmt="%Y-%m-%d %H:%M:%S"
    )


class GramVaaniCSV:
    """GramVaaniCSV representing a GramVaani dataset.

    Args:
        csv_filename (str): Path to the GramVaani csv

    Attributes:
        data (:class:`pandas.DataFrame`): `pandas.DataFrame` Containing the
            ``audio_url``, ``transcript`` and ``audio_length`` columns of the
            GramVaani csv, with incomplete rows dropped
    """

    def __init__(self, csv_filename):
        self.data = self._parse_csv(csv_filename)

    def _parse_csv(self, csv_filename):
        """Read the GramVaani csv and keep only the columns the importer uses."""
        csv_path = os.path.abspath(csv_filename)
        _logger.info("Parsing csv file...%s", csv_path)
        data = pd.read_csv(
            csv_path,
            names=["piece_id","audio_url","transcript_labelled","transcript","labels","content_filename","audio_length","user_id"],
            usecols=["audio_url","transcript","audio_length"],
            # The file's own header row is skipped; ``names`` supplies the header instead.
            skiprows=[0],
            engine="python",
            encoding="utf-8",
            quotechar='"',
            quoting=csv.QUOTE_ALL,
        )
        # Drop rows with a missing value in any of the three kept columns.
        data.dropna(inplace=True)
        # Lazy %-args: the message is only formatted if INFO is enabled.
        _logger.info("Parsed %d lines csv file.", len(data))
        return data


class GramVaaniDownloader:
    """GramVaaniDownloader downloads a GramVaani dataset.

    Args:
        gram_vaani_csv (GramVaaniCSV): A GramVaaniCSV representing the data to download
        target_dir (str): The path to download the data to

    Attributes:
        target_dir (str): The path to download the data to
        data (:class:`pandas.DataFrame`): `pandas.DataFrame` Containing the GramVaani csv data
    """

    def __init__(self, gram_vaani_csv, target_dir):
        self.target_dir = target_dir
        self.data = gram_vaani_csv.data

    def download(self):
        """Downloads the data associated with this instance.

        Return:
            mp3_directory (os.path): The directory into which the associated mp3's were downloaded
        """
        mp3_directory = self._pre_download()
        # raw=True hands each row to the lambda as an array that is unpacked
        # positionally into _download -- presumably in (audio_url, transcript,
        # audio_length) column order; confirm against GramVaaniCSV's columns.
        self.data.swifter.apply(func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True)
        return mp3_directory

    def _pre_download(self):
        """Create ``target_dir`` and its ``mp3`` sub-directory if needed; return the latter."""
        mp3_directory = path.join(self.target_dir, "mp3")
        for directory in (self.target_dir, mp3_directory):
            if not path.exists(directory):
                _logger.info("Creating directory...%s", directory)
                # makedirs also creates missing parents; exist_ok tolerates a
                # directory that appeared after the exists() check.
                os.makedirs(directory, exist_ok=True)
        return mp3_directory

    def _download(self, audio_url, transcript, audio_length, mp3_directory):
        """Download one mp3 into ``mp3_directory`` unless it is already there.

        ``transcript`` and ``audio_length`` are accepted only so a whole csv
        row can be unpacked into this call; they are not used.
        """
        # Guard against the csv header row leaking through as data.
        if audio_url == "audio_url":
            return
        mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
        if path.exists(mp3_filename):
            _logger.debug("Already downloaded mp3 file...%s", audio_url)
            return
        _logger.debug("Downloading mp3 file...%s", audio_url)
        # Download to a side file and rename it into place, so an interrupted
        # download never leaves a truncated mp3 that a later run would treat
        # as complete. The ".part" suffix does not match the converter's
        # "*.mp3" glob.
        partial_filename = mp3_filename + ".part"
        urllib.request.urlretrieve(audio_url, partial_filename)
        os.replace(partial_filename, mp3_filename)


class GramVaaniConverter:
    """GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.

    Args:
        target_dir (str): The path to download the data from
        mp3_directory (os.path): The path containing the GramVaani mp3's

    Attributes:
        target_dir (str): The target directory passed as a command line argument
        mp3_directory (pathlib.Path): The path containing the GramVaani mp3's
    """

    def __init__(self, target_dir, mp3_directory):
        self.target_dir = target_dir
        self.mp3_directory = Path(mp3_directory)

    def convert(self):
        """Converts the mp3's associated with this instance to wav's.

        Every ``*.mp3`` under ``mp3_directory`` (recursively) is written to
        ``<target_dir>/wav/<stem>.wav`` using SAMPLE_RATE, N_CHANNELS and
        BITDEPTH; wav's that already exist are left untouched.

        Return:
            wav_directory (os.path): The directory into which the associated wav's were written
        """
        wav_directory = self._pre_convert()
        for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
            wav_filename = path.join(wav_directory, mp3_filename.stem + ".wav")
            if path.exists(wav_filename):
                _logger.debug("Already converted mp3 file %s to wav file %s", mp3_filename, wav_filename)
                continue
            _logger.debug("Converting mp3 file %s to wav file %s", mp3_filename, wav_filename)
            transformer = Transformer()
            transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
            transformer.build(str(mp3_filename), str(wav_filename))
        return wav_directory

    def _pre_convert(self):
        """Create ``target_dir`` and its ``wav`` sub-directory if needed; return the latter."""
        wav_directory = path.join(self.target_dir, "wav")
        for directory in (self.target_dir, wav_directory):
            if not path.exists(directory):
                _logger.info("Creating directory...%s", directory)
                # makedirs also creates missing parents; exist_ok tolerates races.
                os.makedirs(directory, exist_ok=True)
        return wav_directory


class GramVaaniDataSets:
    """Builds and saves the train/dev/test csv files for an imported GramVaani dataset.

    Args:
        target_dir (str): Directory containing the ``wav`` sub-directory; the
            ``train.csv``, ``dev.csv`` and ``test.csv`` files are written here
        wav_directory (os.path): The directory containing the converted wav's
        gram_vaani_csv (GramVaaniCSV): The parsed GramVaani csv

    Attributes:
        raw, valid, train, dev, test (:class:`pandas.DataFrame`): Frames with
            the columns ``wav_filename``, ``wav_filesize`` and ``transcript``
    """

    _COLUMNS = ("wav_filename", "wav_filesize", "transcript")

    def __init__(self, target_dir, wav_directory, gram_vaani_csv):
        self.target_dir = target_dir
        # NOTE(review): stored but not read by this class; wav paths are rebuilt
        # from target_dir and the literal "wav" directory name.
        self.wav_directory = wav_directory
        self.csv_data = gram_vaani_csv.data
        self.raw = self._empty_frame()
        self.valid = self._empty_frame()
        self.train = self._empty_frame()
        self.dev = self._empty_frame()
        self.test = self._empty_frame()

    @classmethod
    def _empty_frame(cls):
        """Return an empty DataFrame with the output csv columns."""
        return pd.DataFrame(columns=list(cls._COLUMNS))

    def create(self, random_state=None):
        """Populate ``raw``, ``valid``, ``train``, ``dev`` and ``test``.

        Args:
            random_state (int, optional): Seed forwarded to
                ``DataFrame.sample`` for a reproducible shuffle; ``None``
                (the default) gives an unseeded shuffle as before.
        """
        self._convert_csv_data_to_raw_data()
        # A 0..n-1 index lets the boolean mask from _is_valid_raw_rows line up row for row.
        self.raw.index = range(len(self.raw.index))
        self.valid = self.raw[self._is_valid_raw_rows()]
        self.valid = self.valid.sample(frac=1, random_state=random_state).reset_index(drop=True)
        train_size, dev_size, test_size = self._calculate_data_set_sizes()
        self.train, self.dev, self.test = self._split_data_sets(self.valid, train_size, dev_size, test_size)

    @staticmethod
    def _split_data_sets(valid, train_size, dev_size, test_size):
        """Split ``valid`` positionally into disjoint (train, dev, test) frames.

        ``iloc`` slices exclude their end position. Label-based ``loc`` slices
        include the end label, which would put each boundary row in two sets.
        """
        dev_end = train_size + dev_size
        return (
            valid.iloc[:train_size],
            valid.iloc[train_size:dev_end],
            valid.iloc[dev_end:dev_end + test_size],
        )

    def _convert_csv_data_to_raw_data(self):
        """Fill ``raw`` with one (wav path, wav size, cleaned transcript) row per csv row."""
        self.raw[list(self._COLUMNS)] = self.csv_data[
            ["audio_url","transcript","audio_length"]
        ].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True)
        # reset_index returns a new frame; it must be assigned back to take effect.
        self.raw = self.raw.reset_index(drop=True)

    def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
        """Map one csv row to a Series of (wav_relative_filename, wav_filesize, transcript).

        The wav path is relative to ``target_dir``. A transcript rejected by
        ``validate_label`` becomes ``""``, which ``_is_valid_raw_transcripts``
        later filters out.
        """
        # Guard against the csv header row leaking through as data.
        if audio_url == "audio_url":
            return pd.Series(["wav_filename", "wav_filesize", "transcript"])
        mp3_filename = os.path.basename(audio_url)
        wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
        wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
        transcript = validate_label(transcript)
        if transcript is None:
            transcript = ""
        return pd.Series([wav_relative_filename, wav_filesize, transcript])

    def _is_valid_raw_rows(self):
        """Return a boolean mask over ``raw``: non-empty transcript AND acceptable audio length."""
        transcripts_ok = self._is_valid_raw_transcripts()
        frames_ok = self._is_valid_raw_wav_frames()
        mask = [bool(transcript_ok) and bool(frame_ok) for transcript_ok, frame_ok in zip(transcripts_ok, frames_ok)]
        # An explicit bool dtype and raw's own index keep this usable as a row
        # filter even when raw is empty; an empty untyped Series may not be
        # treated as a boolean mask.
        return pd.Series(mask, index=self.raw.index, dtype=bool)

    def _is_valid_raw_transcripts(self):
        """Return a Series that is True where ``raw.transcript`` is truthy (non-empty)."""
        return pd.Series([bool(transcript) for transcript in self.raw.transcript])

    def _is_valid_raw_wav_frames(self):
        """Return a Series telling, per ``raw`` row, whether its wav length suits its transcript."""
        transcripts = [str(transcript) for transcript in self.raw.transcript]
        wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
        # `soxi -s` reports the sample count. stderr is no longer merged into
        # the captured output, so diagnostics written there cannot break the
        # int() parse.
        wav_frames = [int(subprocess.check_output(["soxi", "-s", wav_filepath])) for wav_filepath in wav_filepaths]
        return pd.Series([self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)])

    def _is_wav_frame_valid(self, wav_frame, transcript):
        """Return whether a clip of ``wav_frame`` samples is usable for ``transcript``.

        The clip is rejected when it has fewer 20 ms slots
        (samples / SAMPLE_RATE * 1000 / 10 / 2) than transcript characters, or
        when it lasts longer than MAX_SECS seconds. Presumably the first rule
        keeps CTC alignment feasible -- TODO confirm.
        """
        if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)):
            return False
        if wav_frame/SAMPLE_RATE > MAX_SECS:
            return False
        return True

    def _calculate_data_set_sizes(self):
        """Return (train_size, dev_size, test_size) for ``valid``.

        Train and dev sizes are floored fractions; test takes the remainder,
        so the three always sum to ``len(self.valid)``.
        """
        total_size = len(self.valid)
        dev_size = math.floor(total_size * DEV_PERCENTAGE)
        train_size = math.floor(total_size * TRAIN_PERCENTAGE)
        test_size = total_size - (train_size + dev_size)
        return (train_size, dev_size, test_size)

    def save(self):
        """Write ``train.csv``, ``dev.csv`` and ``test.csv`` into ``target_dir``."""
        for dataset in ("train", "dev", "test"):
            self._save(dataset)

    def _save(self, dataset):
        """Write the frame held in attribute ``dataset`` to ``<target_dir>/<dataset>.csv``."""
        dataset_path = os.path.join(self.target_dir, dataset + ".csv")
        dataframe = getattr(self, dataset)
        dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)


def main(args):
    """Main entry point allowing external calls.

    Parses the GramVaani csv, downloads the mp3's, converts them to wav's and
    writes the train/dev/test csv files into the target directory.

    Args:
        args ([str]): command line parameter list
    """
    options = parse_args(args)
    setup_logging(options.loglevel)
    _logger.info("Starting GramVaani importer...")
    _logger.info("Starting loading GramVaani csv...")
    # Named gram_vaani_csv so the imported ``csv`` module is not shadowed.
    gram_vaani_csv = GramVaaniCSV(options.csv_filename)
    _logger.info("Starting downloading GramVaani mp3's...")
    downloader = GramVaaniDownloader(gram_vaani_csv, options.target_dir)
    mp3_directory = downloader.download()
    _logger.info("Starting converting GramVaani mp3's to wav's...")
    converter = GramVaaniConverter(options.target_dir, mp3_directory)
    wav_directory = converter.convert()
    datasets = GramVaaniDataSets(options.target_dir, wav_directory, gram_vaani_csv)
    datasets.create()
    datasets.save()
    _logger.info("Finished GramVaani importer...")


# Run the importer only when executed as a script, so importing this module
# (e.g. to reuse its classes) does not start a download with the caller's argv.
if __name__ == "__main__":
    main(sys.argv[1:])