This commit is contained in:
erogol 2020-05-20 16:12:10 +02:00
Родитель 4a6949632b
Коммит ddd7de6439
1 изменённых файлов: 22 добавлений и 8 удалений

Просмотреть файл

@ -2,15 +2,22 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"Collapsed": "false"
},
"source": [
"This notebook is to test attention performance on hard sentences taken from DeepVoice paper."
"This notebook tests the attention performance of a TTS model on a list of sentences taken from the DeepVoice paper.\n",
"### Features of this notebook\n",
"- You can see visually how your model performs on each sentence and try to discern common problems.\n",
"- At the end, the final attention score is printed, showing the overall performance of your model. You can use this value to perform model selection.\n",
"- You can change the list of sentences by providing a different sentence file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false",
"scrolled": true
},
"outputs": [],
@ -31,7 +38,8 @@
"\n",
"from TTS.layers import *\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.generic_utils import load_config, setup_model\n",
"from TTS.utils.generic_utils import setup_model\n",
"from TTS.utils.io import load_config\n",
"from TTS.utils.text import text_to_sequence\n",
"from TTS.utils.synthesis import synthesis\n",
"from TTS.utils.visual import plot_alignment\n",
@ -45,7 +53,7 @@
"def tts(model, text, CONFIG, use_cuda, ap):\n",
" t_1 = time.time()\n",
" # run the model\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens, inputs = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" # plotting\n",
@ -62,7 +70,7 @@
" return attn_score\n",
"\n",
"# Set constants\n",
"ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/'\n",
"ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-May-20-2020_12+29PM-1835628/'\n",
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = './hard_sentences/'\n",
@ -82,7 +90,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"# LOAD TTS MODEL\n",
@ -130,7 +140,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"model.decoder.max_decoder_steps=3000\n",
@ -144,7 +156,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"np.mean(attn_scores)"