зеркало из https://github.com/mozilla/DeepSpeech.git
Addressed review comments
This commit is contained in:
Родитель
441ac5869f
Коммит
0bc132cabe
|
@ -63,18 +63,18 @@ def parse_args(args):
|
|||
const=logging.DEBUG,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--file",
|
||||
"-c",
|
||||
"--csv_filename",
|
||||
required=True,
|
||||
help="Path to the GramVaani csv",
|
||||
dest="csv_filename",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--directory",
|
||||
"-t",
|
||||
"--target_dir",
|
||||
required=True,
|
||||
help="Directory in which to save the importer GramVaani data",
|
||||
dest="directory",
|
||||
dest="target_dir",
|
||||
)
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
@ -119,13 +119,13 @@ class GramVaaniDownloader:
|
|||
"""GramVaaniDownloader downloads a GramVaani dataset.
|
||||
Args:
|
||||
gram_vaani_csv (GramVaaniCSV): A GramVaaniCSV representing the data to download
|
||||
directory (str): The path to download the data from
|
||||
target_dir (str): The path to download the data to
|
||||
Attributes:
|
||||
data (:class:`pandas.DataFrame`): `pandas.DataFrame` Containing the GramVaani csv data
|
||||
"""
|
||||
|
||||
def __init__(self, gram_vaani_csv, directory):
|
||||
self.directory = directory
|
||||
def __init__(self, gram_vaani_csv, target_dir):
|
||||
self.target_dir = target_dir
|
||||
self.data = gram_vaani_csv.data
|
||||
|
||||
def download(self):
|
||||
|
@ -138,10 +138,10 @@ class GramVaaniDownloader:
|
|||
return mp3_directory
|
||||
|
||||
def _pre_download(self):
|
||||
mp3_directory = path.join(self.directory, "mp3")
|
||||
if not path.exists(self.directory):
|
||||
_logger.info("Creating directory...%s", self.directory)
|
||||
os.mkdir(self.directory)
|
||||
mp3_directory = path.join(self.target_dir, "mp3")
|
||||
if not path.exists(self.target_dir):
|
||||
_logger.info("Creating directory...%s", self.target_dir)
|
||||
os.mkdir(self.target_dir)
|
||||
if not path.exists(mp3_directory):
|
||||
_logger.info("Creating directory...%s", mp3_directory)
|
||||
os.mkdir(mp3_directory)
|
||||
|
@ -160,15 +160,15 @@ class GramVaaniDownloader:
|
|||
class GramVaaniConverter:
|
||||
"""GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
|
||||
Args:
|
||||
directory (str): The path to download the data from
|
||||
target_dir (str): The directory under which the wav directory is created
|
||||
mp3_directory (os.path): The path containing the GramVaani mp3's
|
||||
Attributes:
|
||||
directory (str): The target directory passed as a command line argument
|
||||
target_dir (str): The target directory passed as a command line argument
|
||||
mp3_directory (os.path): The path containing the GramVaani mp3's
|
||||
"""
|
||||
|
||||
def __init__(self, directory, mp3_directory):
|
||||
self.directory = directory
|
||||
def __init__(self, target_dir, mp3_directory):
|
||||
self.target_dir = target_dir
|
||||
self.mp3_directory = Path(mp3_directory)
|
||||
|
||||
def convert(self):
|
||||
|
@ -189,18 +189,18 @@ class GramVaaniConverter:
|
|||
return wav_directory
|
||||
|
||||
def _pre_convert(self):
|
||||
wav_directory = path.join(self.directory, "wav")
|
||||
if not path.exists(self.directory):
|
||||
_logger.info("Creating directory...%s", self.directory)
|
||||
os.mkdir(self.directory)
|
||||
wav_directory = path.join(self.target_dir, "wav")
|
||||
if not path.exists(self.target_dir):
|
||||
_logger.info("Creating directory...%s", self.target_dir)
|
||||
os.mkdir(self.target_dir)
|
||||
if not path.exists(wav_directory):
|
||||
_logger.info("Creating directory...%s", wav_directory)
|
||||
os.mkdir(wav_directory)
|
||||
return wav_directory
|
||||
|
||||
class GramVaaniDataSets:
|
||||
def __init__(self, directory, wav_directory, gram_vaani_csv):
|
||||
self.directory = directory
|
||||
def __init__(self, target_dir, wav_directory, gram_vaani_csv):
|
||||
self.target_dir = target_dir
|
||||
self.wav_directory = wav_directory
|
||||
self.csv_data = gram_vaani_csv.data
|
||||
self.raw = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
|
||||
|
@ -230,7 +230,7 @@ class GramVaaniDataSets:
|
|||
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
|
||||
mp3_filename = os.path.basename(audio_url)
|
||||
wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||
wav_filesize = path.getsize(path.join(self.directory, wav_relative_filename))
|
||||
wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
|
||||
transcript = validate_label(transcript)
|
||||
if None == transcript:
|
||||
transcript = ""
|
||||
|
@ -248,7 +248,7 @@ class GramVaaniDataSets:
|
|||
|
||||
def _is_valid_raw_wav_frames(self):
|
||||
transcripts = [str(transcript) for transcript in self.raw.transcript]
|
||||
wav_filepaths = [path.join(self.directory, str(wav_filename)) for wav_filename in self.raw.wav_filename]
|
||||
wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
|
||||
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
|
||||
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
|
||||
return pd.Series(is_valid_raw_wav_frames)
|
||||
|
@ -274,7 +274,7 @@ class GramVaaniDataSets:
|
|||
self._save(dataset)
|
||||
|
||||
def _save(self, dataset):
|
||||
dataset_path = os.path.join(self.directory, dataset + ".csv")
|
||||
dataset_path = os.path.join(self.target_dir, dataset + ".csv")
|
||||
dataframe = getattr(self, dataset)
|
||||
dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)
|
||||
|
||||
|
@ -289,12 +289,12 @@ def main(args):
|
|||
_logger.info("Starting loading GramVaani csv...")
|
||||
csv = GramVaaniCSV(args.csv_filename)
|
||||
_logger.info("Starting downloading GramVaani mp3's...")
|
||||
downloader = GramVaaniDownloader(csv, args.directory)
|
||||
downloader = GramVaaniDownloader(csv, args.target_dir)
|
||||
mp3_directory = downloader.download()
|
||||
_logger.info("Starting converting GramVaani mp3's to wav's...")
|
||||
converter = GramVaaniConverter(args.directory, mp3_directory)
|
||||
converter = GramVaaniConverter(args.target_dir, mp3_directory)
|
||||
wav_directory = converter.convert()
|
||||
datasets = GramVaaniDataSets(args.directory, wav_directory, csv)
|
||||
datasets = GramVaaniDataSets(args.target_dir, wav_directory, csv)
|
||||
datasets.create()
|
||||
datasets.save()
|
||||
_logger.info("Finished GramVaani importer...")
|
||||
|
|
Загрузка…
Ссылка в новой задаче