зеркало из https://github.com/mozilla/TTS.git
notebook update
This commit is contained in:
Родитель
7f117e3fd6
Коммит
a510baa79c
|
@ -32,16 +32,17 @@
|
|||
"import numpy as np\n",
|
||||
"from tqdm import tqdm as tqdm\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from TTS.models.tacotron2 import Tacotron2\n",
|
||||
"from TTS.datasets.TTSDataset import MyDataset\n",
|
||||
"from TTS.layers.losses import L1LossMasked\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
"from TTS.utils.visual import plot_spectrogram\n",
|
||||
"from TTS.utils.generic_utils import load_config, setup_model\n",
|
||||
"from TTS.datasets.preprocess import ljspeech\n",
|
||||
"from TTS.utils.generic_utils import load_config, setup_model, sequence_mask\n",
|
||||
"from TTS.utils.text.symbols import symbols, phonemes\n",
|
||||
"\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.environ['CUDA_VISIBLE_DEVICES']='1'"
|
||||
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -68,22 +69,23 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"OUT_PATH = \"/home/erogol/Data/Mozilla/wavernn/4841/\"\n",
|
||||
"DATA_PATH = \"/home/erogol/Data/Mozilla/\"\n",
|
||||
"DATASET = \"mozilla\"\n",
|
||||
"OUT_PATH = \"/data/rw/pit/data/turkish-vocoder/\"\n",
|
||||
"DATA_PATH = \"/data/rw/home/Turkish\"\n",
|
||||
"DATASET = \"ljspeech\"\n",
|
||||
"METADATA_FILE = \"metadata.txt\"\n",
|
||||
"CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/mozilla_models/4841/config.json\"\n",
|
||||
"MODEL_FILE = \"/media/erogol/data_ssd/Data/models/mozilla_models/4841/best_model.pth.tar\"\n",
|
||||
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
||||
"CONFIG_PATH = \"/data/rw/pit/keep/turkish-January-08-2020_01+56AM-ca5e133/config.json\"\n",
|
||||
"MODEL_FILE = \"/data/rw/pit/keep/turkish-January-08-2020_01+56AM-ca5e133/checkpoint_255000.pth.tar\"\n",
|
||||
"BATCH_SIZE = 32\n",
|
||||
"\n",
|
||||
"QUANTIZED_WAV = False\n",
|
||||
"QUANTIZE_BIT = 9\n",
|
||||
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
||||
"\n",
|
||||
"use_cuda = torch.cuda.is_available()\n",
|
||||
"print(\" > CUDA enabled: \", use_cuda)\n",
|
||||
"\n",
|
||||
"C = load_config(CONFIG_PATH)\n",
|
||||
"ap = AudioProcessor(bits=9, **C.audio)\n",
|
||||
"C.prenet_dropout = False\n",
|
||||
"C.separate_stopnet = True"
|
||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -92,35 +94,32 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"preprocessor = importlib.import_module('datasets.preprocess')\n",
|
||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
||||
"\n",
|
||||
"dataset = MyDataset(DATA_PATH, METADATA_FILE, C.r, C.text_cleaner, ap, preprocessor, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path)\n",
|
||||
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from TTS.utils.text.symbols import symbols, phonemes\n",
|
||||
"from TTS.utils.generic_utils import sequence_mask\n",
|
||||
"from TTS.layers.losses import L1LossMasked\n",
|
||||
"from TTS.utils.text.symbols import symbols, phonemes\n",
|
||||
"\n",
|
||||
"# load the model\n",
|
||||
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
|
||||
"model = setup_model(num_chars, C)\n",
|
||||
"# TODO: multiple speaker\n",
|
||||
"model = setup_model(num_chars, num_speakers=0, c=C)\n",
|
||||
"checkpoint = torch.load(MODEL_FILE)\n",
|
||||
"model.load_state_dict(checkpoint['model'])\n",
|
||||
"print(checkpoint['step'])\n",
|
||||
"model.eval()\n",
|
||||
"model.decoder.set_r(checkpoint['r'])\n",
|
||||
"if use_cuda:\n",
|
||||
" model = model.cuda()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"preprocessor = importlib.import_module('TTS.datasets.preprocess')\n",
|
||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
||||
"meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
|
||||
"dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
|
||||
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -137,73 +136,92 @@
|
|||
"import pickle\n",
|
||||
"\n",
|
||||
"file_idxs = []\n",
|
||||
"metadata = []\n",
|
||||
"losses = []\n",
|
||||
"postnet_losses = []\n",
|
||||
"criterion = L1LossMasked()\n",
|
||||
"for data in tqdm(loader):\n",
|
||||
" # setup input data\n",
|
||||
" text_input = data[0]\n",
|
||||
" text_lengths = data[1]\n",
|
||||
" linear_input = data[2]\n",
|
||||
" mel_input = data[3]\n",
|
||||
" mel_lengths = data[4]\n",
|
||||
" stop_targets = data[5]\n",
|
||||
" item_idx = data[6]\n",
|
||||
" \n",
|
||||
" # dispatch data to GPU\n",
|
||||
" if use_cuda:\n",
|
||||
" text_input = text_input.cuda()\n",
|
||||
" text_lengths = text_lengths.cuda()\n",
|
||||
" mel_input = mel_input.cuda()\n",
|
||||
" mel_lengths = mel_lengths.cuda()\n",
|
||||
"# linear_input = linear_input.cuda()\n",
|
||||
" stop_targets = stop_targets.cuda()\n",
|
||||
" \n",
|
||||
" mask = sequence_mask(text_lengths)\n",
|
||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
||||
" \n",
|
||||
" # compute mel specs from linear spec if model is Tacotron\n",
|
||||
" mel_specs = []\n",
|
||||
" if C.model == \"Tacotron\":\n",
|
||||
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
|
||||
" for b in range(postnet_outputs.shape[0]):\n",
|
||||
" postnet_output = postnet_outputs[b]\n",
|
||||
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
|
||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||
" \n",
|
||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
" postnet_losses.append(loss_postnet.item())\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for data in tqdm(loader):\n",
|
||||
" # setup input data\n",
|
||||
" text_input = data[0]\n",
|
||||
" text_lengths = data[1]\n",
|
||||
" linear_input = data[3]\n",
|
||||
" mel_input = data[4]\n",
|
||||
" mel_lengths = data[5]\n",
|
||||
" stop_targets = data[6]\n",
|
||||
" item_idx = data[7]\n",
|
||||
"\n",
|
||||
" # dispatch data to GPU\n",
|
||||
" if use_cuda:\n",
|
||||
" text_input = text_input.cuda()\n",
|
||||
" text_lengths = text_lengths.cuda()\n",
|
||||
" mel_input = mel_input.cuda()\n",
|
||||
" mel_lengths = mel_lengths.cuda()\n",
|
||||
"\n",
|
||||
" mask = sequence_mask(text_lengths)\n",
|
||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
||||
" \n",
|
||||
" # compute loss\n",
|
||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
" postnet_losses.append(loss_postnet.item())\n",
|
||||
"\n",
|
||||
" # compute mel specs from linear spec if model is Tacotron\n",
|
||||
" if C.model == \"Tacotron\":\n",
|
||||
" mel_specs = []\n",
|
||||
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
|
||||
" for b in range(postnet_outputs.shape[0]):\n",
|
||||
" postnet_output = postnet_outputs[b]\n",
|
||||
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
|
||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||
" elif C.model == \"Tacotron2\":\n",
|
||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||
" alignments = alignments.detach().cpu().numpy()\n",
|
||||
"\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" for idx in range(text_input.shape[0]):\n",
|
||||
" wav_file_path = item_idx[idx]\n",
|
||||
" wav = ap.load_wav(wav_file_path)\n",
|
||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_idxs.append(file_name)\n",
|
||||
"\n",
|
||||
" # quantize and save wav\n",
|
||||
" if QUANTIZED_WAV:\n",
|
||||
" wavq = ap.quantize(wav)\n",
|
||||
" np.save(wavq_path, wavq)\n",
|
||||
"\n",
|
||||
" # save TTS mel\n",
|
||||
" mel = postnet_outputs[idx]\n",
|
||||
" mel_length = mel_lengths[idx]\n",
|
||||
" mel = mel[:mel_length, :].T\n",
|
||||
" np.save(mel_path, mel)\n",
|
||||
"\n",
|
||||
" metadata.append([wav_file_path, mel_path])\n",
|
||||
"\n",
|
||||
" # for wavernn\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" for idx in range(text_input.shape[0]):\n",
|
||||
" wav_file_path = item_idx[idx]\n",
|
||||
" wav = ap.load_wav(wav_file_path)\n",
|
||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_idxs.append(file_name)\n",
|
||||
"\n",
|
||||
"# # quantize and save wav\n",
|
||||
"# wavq = ap.quantize(wav)\n",
|
||||
"# np.save(wavq_path, wavq)\n",
|
||||
"\n",
|
||||
" # save TTS mel\n",
|
||||
" mel = postnet_outputs[idx]\n",
|
||||
" mel = mel.data.cpu().numpy()\n",
|
||||
" mel_length = mel_lengths[idx]\n",
|
||||
" mel = mel[:mel_length, :].T\n",
|
||||
" np.save(mel_path, mel)\n",
|
||||
"\n",
|
||||
" # save GL voice\n",
|
||||
" # wav_gen = ap.inv_mel_spectrogram(mel.T) # mel to wav\n",
|
||||
" # wav_gen = ap.quantize(wav_gen)\n",
|
||||
" # np.save(wav_path, wav_gen)\n",
|
||||
"\n",
|
||||
"if not DRY_RUN:\n",
|
||||
" pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
|
||||
" pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
|
||||
" \n",
|
||||
" # for pwgan\n",
|
||||
" with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
||||
"\n",
|
||||
"print(np.mean(losses))\n",
|
||||
"print(np.mean(postnet_losses))"
|
||||
" print(np.mean(losses))\n",
|
||||
" print(np.mean(postnet_losses))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# for pwgan\n",
|
||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -219,8 +237,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# plot posnet output\n",
|
||||
"idx = 1\n",
|
||||
"mel_example = postnet_outputs[idx].data.cpu().numpy()\n",
|
||||
"mel_example = postnet_outputs[idx]\n",
|
||||
"plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);\n",
|
||||
"print(mel_example[:mel_lengths[1], :].shape)"
|
||||
]
|
||||
|
@ -231,6 +250,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# plot decoder output\n",
|
||||
"mel_example = mel_outputs[idx].data.cpu().numpy()\n",
|
||||
"plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);\n",
|
||||
"print(mel_example[:mel_lengths[1], :].shape)"
|
||||
|
@ -242,6 +262,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# plot GT specgrogram\n",
|
||||
"wav = ap.load_wav(item_idx[idx])\n",
|
||||
"melt = ap.melspectrogram(wav)\n",
|
||||
"print(melt.shape)\n",
|
||||
|
@ -278,13 +299,6 @@
|
|||
"plt.colorbar()\n",
|
||||
"plt.tight_layout()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
Загрузка…
Ссылка в новой задаче