From 5750090fcd4a6718608c7357e3890f5b2a545ae0 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 22 Mar 2018 12:34:16 -0700 Subject: [PATCH] Stop token prediction - does train yet --- config.json | 6 +++--- datasets/LJSpeech.py | 23 +++++++++++++++++----- layers/.tacotron.py.swo | Bin 28672 -> 0 bytes layers/tacotron.py | 40 ++++++++++++++++++++++++-------------- models/.tacotron.py.swo | Bin 12288 -> 0 bytes models/tacotron.py | 7 +++++-- tests/layers_tests.py | 11 ++++++++--- tests/loader_tests.py | 10 +++++++--- train.py | 42 +++++++++++++++++++++++++++++----------- utils/data.py | 23 ++++++++++++++++++++++ 10 files changed, 121 insertions(+), 41 deletions(-) delete mode 100644 layers/.tacotron.py.swo delete mode 100644 models/.tacotron.py.swo diff --git a/config.json b/config.json index ffea446..285e1d8 100644 --- a/config.json +++ b/config.json @@ -12,16 +12,16 @@ "text_cleaner": "english_cleaners", "epochs": 2000, - "lr": 0.001, + "lr": 0.0003, "warmup_steps": 4000, "batch_size": 32, - "eval_batch_size": 32, + "eval_batch_size":32, "r": 5, "griffin_lim_iters": 60, "power": 1.5, - "num_loader_workers": 12, + "num_loader_workers": 8, "checkpoint": false, "save_step": 69, diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index 5334e1c..fb6c930 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -7,7 +7,8 @@ from torch.utils.data import Dataset from TTS.utils.text import text_to_sequence from TTS.utils.audio import AudioProcessor -from TTS.utils.data import prepare_data, pad_data, pad_per_step +from TTS.utils.data import (prepare_data, pad_data, pad_per_step, + prepare_tensor, prepare_stop_target) class LJSpeechDataset(Dataset): @@ -93,15 +94,26 @@ class LJSpeechDataset(Dataset): text_lenghts = np.array([len(x) for x in text]) max_text_len = np.max(text_lenghts) + linear = [self.ap.spectrogram(w).astype('float32') for w in wav] + mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] + mel_lengths = [m.shape[1] for m in mel] + + # compute 'stop token' targets + stop_targets = [np.array([0.]*mel_len) for mel_len in mel_lengths] + # PAD sequences with largest length of the batch text = prepare_data(text).astype(np.int32) wav = prepare_data(wav) - linear = np.array([self.ap.spectrogram(w).astype('float32') for w in wav]) - mel = np.array([self.ap.melspectrogram(w).astype('float32') for w in wav]) + # PAD features with largest length of the batch + linear = prepare_tensor(linear) + mel = prepare_tensor(mel) assert mel.shape[2] == linear.shape[2] timesteps = mel.shape[2] + # PAD stop targets + stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) + # PAD with zeros that can be divided by outputs per step if (timesteps + 1) % self.outputs_per_step != 0: pad_len = self.outputs_per_step - \ @@ -112,7 +124,7 @@ class LJSpeechDataset(Dataset): linear = pad_per_step(linear, pad_len) mel = pad_per_step(mel, pad_len) - # reshape jombo + # reshape mojo linear = linear.transpose(0, 2, 1) mel = mel.transpose(0, 2, 1) @@ -121,7 +133,8 @@ class LJSpeechDataset(Dataset): text = torch.LongTensor(text) linear = torch.FloatTensor(linear) mel = torch.FloatTensor(mel) - return text, text_lenghts, linear, mel, item_idxs[0] + stop_targets = torch.FloatTensor(stop_targets) + return text, text_lenghts, linear, mel, stop_targets, item_idxs[0] raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ found {}" diff --git a/layers/.tacotron.py.swo b/layers/.tacotron.py.swo deleted file mode 100644 index c637f4479218ebeffd233f9b79def064bac02fbd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 28672 zcmeI4dyr&TS-|g-Y?KIS<$+iNb#^9Yx|5#XncW1E4IQ()F*|X?8a9&+Fsxm7y6^N{ zr{{Ka@9o{$P?#9aA6i&Q5Uap9AxH!bhCie#qQnQVBE;}06!i~Z1tP_xtQ3h*=%yYD^sJic?zcfRkOuXkns)U7X9FKu7r@%kap`^&EduYUXNi@mr0 zwC4r!XfPVejh$Z`!`{NqDa_rUH5u+7$D^MKdTFwM$?uKQcogj)`sac;*`M0)?eV$Y z>q*E@J%VP2k?o!O@@z2weaK3pfOB~VMCmOw3mS^~8MY6;X5 zs3lNK;Cn#=>3FYq4-L7{Hs<;EeNW~6BK!T0ZPHy}uLt%$wC^sTXW8#RXx~3vxxd$b zztO&bx^n*;|6^XAZSLF4^;<20S^~8MY6;X5s3lNKpq4-_fm#B!1ZoM?5~w9GMFM`` z^Cr*rypMsL|2zNu@A>d|@D8{Su7j7tkHA-+=XsC818^FeZ~=VW{gq+R$Wlq_)lsavDvVU@M}wd-45CI7 z3|HH($T*6UX7hCE8Fz2j)o~AyS$8e$_k$>*Ld(-@91|7G8n1V*xv{eHK*p?$JCG0 zFdC?IEg<*TPJ*b%DLmbjy_UMJ`WgT_MB6aW9e7=jPV0X_b|JHP)k{QO^o zx4r7ScExf!A0;%_62?$ehcn}SHtx{II?#MeC9cx z_a1mR{3^T|9)L}FCCtM!Kj3*EfIop>gg#sY2jKv`0KR&$=beXlz?sDOtf4^-}OHoECmvGfRt0(&ODjX_3GOLcoqsdq_$;M=uhT~!2Gz_^# zK`%%WKR%Z|lNF)DL`9=ioft*Ip9nVA{S6=vSE)k_NmxwT6POv7BQ z52>FpxpJey!@>_b+ROXPn z+1Mj)T&l1gv_ z+Qr?`>0~Ztmn7HXNN%Z^nbI4wk~^K%Kx$?X${fmyii30#M@HHAR@Md$qYPR`DY&_p z%irn>yF8fDW}nIHRi5KhbDTb#r3Q%P=#j-^4c)27I}1!vJ1Vhjkh8u%Mb4%!c{hcT z$aKf<(PFj>%kIT&GS1{{pWj)&)`HFDuI$U+XsH#k*1D@yc(%h)vO@jnH7kr{_=xncI#AK4l5DdwxV}}}H@7yqw@wXkr~eOa)-@sLwQ#U@ z)<0KlEyLGcrX6N(*kr7_@}Rsn)zuv%-Ne%u_=7zdBw>Hz4@HAwZ;D1Kce=Dji<9xn zp1_vE>9S2@J&D1)&n7nFEG&MK&1O4g^CB6K65L=EVBTM|sL|fKN=MqE9>w7xjHpwx zhJ8?|Y_+>8oh4<^j57=Us&vfNOgQFiw=hjBHPsE?V5uIDMnj$5@}ZWJe@poei*-9} z>kAHtX*>>0wn<|PT=Apz#Q6XJgn#`C{BZIAUHt!7@bUi@WFO%7;bDm2Fx&)J!c+M6 zABTtFUN{9m49}3)C*VnV6y69YLCSa;G+-}0LAj5>J#Ydphx7RHZ-J9A2Q9b+zJ?F~ z1iS|x0P*jC1A7M#bI{|Nj6tib@T zffvADxDY;tPyc>+0LCzaSHfLzJzNGKbpC!}Et2tuebHNMV~#YIv!R!d6>WJv`Ha6z%QnadajsHP1&Cly}4_!V>hW>_%Eq(WxJ_3}KnPIS4q)g0g5 zQZ3KA+PZi2Pjv<3O4tu$qXhk7WmR)a3e`3CT%#=5jvd|3vPVny|%vd&;T>E~K9 zDD5Y|*A&%ycR0>h@={sui>!e-&K-oGm($^K&!v#}09(M$Q*(FNs8jjf=+dAnJ8iVekm${LF6Y-L#%p4K?)(jDfU z08t+ok7G+`79}p9U4HAb#YIWEJUW>y%0*46p-3znR|Yl`K*`H+{Q&1SLDsRNG3Zd?=BgdOd9Yc#ms0+HgJI#SdjYQV`wU!haX zIggs)66NOCoU%%Tt-kc^C|*zKs(7nKR(xV%uCEp9XoyXXO4V}HZKWR4NVTC|_e07r zlp6(;7=y|dLk1gSh4s~F)wx~ThuP8RT4Y2oQ9+}0Wm+t>=V|d$J}hMA(S4DGgN;#W zc1_w#&f&@vF1my=#mqt0VmZ}z%6`j^iY>~Fi#u$&c}cq%I7@9&mFBc=H5?ARjpEj> z&OjQ5aIDH@bcpWEbzr5q#X^Sm$S)s_a%d;b-(7Zk?AlTCZ4KW_1zk70Rcgnc)=gfC zidM7?Zdt1Jtz0)0lVChl?v~v*$wFRPHby!ytnCl2Xkqch*W{`!_;1Fo|W%pJ@-a8}(i8V*NiWn-M(sALqc^6g65Z(a4- zqeoFXS28Qs6<>X~+SQ~F)cT>(aL6M_EJr*Nu*ddtD@L5x>RGy z=<4#v3@5YaxXmaM5mxp9O_sJ9#Z5`4uuO*K-6}a-QtfQFRE4Z&8~{!|d-AqPyk+0E zwPm}NQ+O#zyLhgQC@NfR`;%0TP&ivKHE@9vp8c*L^=4Ub#qK?%^nVDZ&;K9Chy66Z zt@!`$_xFE`fB#2tKl}oW;4U}?i!cMvgQxNLAAz?)4_4q7_$vPW!>|l;-v1)_cl`Tb zhqpll*TGZx_V0%`!6F=kD_|DBh)@5g@B#QuI1A$M_h1FY@4pQG8{hss+z;#UDv)#i z7s8Li=ke=52j}4pun5QC3Ydi%_&V*^GC&OVtCm15fm#B!1ZoNVzm$NOLH2M`ybU50 zv_q+0))6P8+4eb+Z+c<@g$Yp-{qAU$7#WExw8@P#GT%Nh?Yz_Eh`Hwv!$Bl2OkoSv z-o!#PyV1l;_SS8#l$;;B%(vJGYU6gL;b1bF6wku)d@I@HH@4*DNt@-BAK&ObRi#e_ z^DQoWP-4T5#L6ak)mB=EX=SmJ`7IYC`h6!Q+$D7Bat2Ll|b*k1o0r~&*b*J%Q+aDN*U)tHuYR6 zgbwv6SI|8)y*f6d1s&0_&4^*3xJ=}oM~sLI4vAbRvZAB-&!XkCd*!_d?Pe%O;xy`i$f8N|C`V_P}7i{;9iqAA5m zQ#)SeG*X+=?LzWA+BjJb`wS7Ss5a;>UIgSNY{uu9IWZXR-3#M(&L(!A8!#GW`lDI z(c3MxC4jxD10&1rbyd>+<{D_X#LEFgSNJ`CNsa zT?>hRl~9uWTD7QByna?OiFKOK+ngygeUOtE>$DlltfIzDd=t5t57iYjD~L+lq^0m% z5sJKK)e`5MX7y8zY+Y(8qD=A-Ds%pS5$Dbk=hNc<-(d}&FW}q%2Yenr1oy&en1gTN z-~R=ig`?oXKj7bg7Cs0M!RtZ%{+r;(;YPRup26pT5*~#6UWxB)JP$MN}p30?&= zun#`NUdVglES!O5cs^VX|H9e+hv0qiDBK6Da1*>3UIZ7xUicjH{BsbwE{bgHS1p0> z1qm4SVopx-6?u7^%dCke1j@#t*kp3-Qh2zUipZuqTBVLw5xcVRGs8MsrH)pqqg9Ls z)jqCKo83BErH)q7ze2L-$GbipuA^0^o&ndEa=%m2Dt1R~ khfoApbTNLB00kXlng#K+?5~xwZaZqM?j7T*Ou&fuzbgW+1^@s6 diff --git a/layers/tacotron.py b/layers/tacotron.py index c0828d0..e9a40b2 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -5,6 +5,7 @@ from torch import nn from .attention import AttentionRNN from .attention import get_mask_from_lengths +from .custom_layers import StopProjection class Prenet(nn.Module): r""" Prenet as explained at https://arxiv.org/abs/1703.10135. @@ -214,8 +215,9 @@ class Decoder(nn.Module): r (int): number of outputs per time step. eps (float): threshold for detecting the end of a sentence. """ - def __init__(self, in_features, memory_dim, r, eps=0.05): + def __init__(self, in_features, memory_dim, r, eps=0.05, mode='train'): super(Decoder, self).__init__() + self.mode = mode self.max_decoder_steps = 200 self.memory_dim = memory_dim self.eps = eps @@ -231,6 +233,8 @@ class Decoder(nn.Module): [nn.GRUCell(256, 256) for _ in range(2)]) # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) + # RNN_state | attention_context -> |Linear| -> stop_token + self.stop_token = StopProjection(256 + in_features, r) def forward(self, inputs, memory=None): """ @@ -252,10 +256,9 @@ class Decoder(nn.Module): B = inputs.size(0) # Run greedy decoding if memory is None - greedy = memory is None + greedy = ~self.training if memory is not None: - # Grouping multiple frames if necessary if memory.size(-1) == self.memory_dim: memory = memory.view(B, memory.size(1) // self.r, -1) @@ -283,6 +286,7 @@ class Decoder(nn.Module): outputs = [] alignments = [] + stop_outputs = [] t = 0 memory_input = initial_memory @@ -292,11 +296,12 @@ class Decoder(nn.Module): memory_input = outputs[-1] else: # combine prev. model output and prev. real target - memory_input = torch.div(outputs[-1] + memory[t-1], 2.0) + # memory_input = torch.div(outputs[-1] + memory[t-1], 2.0) # add a random noise - noise = torch.autograd.Variable( - memory_input.data.new(memory_input.size()).normal_(0.0, 0.5)) - memory_input = memory_input + noise + # noise = torch.autograd.Variable( + # memory_input.data.new(memory_input.size()).normal_(0.0, 0.5)) + # memory_input = memory_input + noise + memory_input = memory[t-1] # Prenet processed_memory = self.prenet(memory_input) @@ -316,35 +321,42 @@ class Decoder(nn.Module): decoder_input, decoder_rnn_hiddens[idx]) # Residual connectinon decoder_input = decoder_rnn_hiddens[idx] + decoder_input - + output = decoder_input + stop_token_input = decoder_input + + # stop token prediction + stop_token_input = torch.cat((output, current_context_vec), -1) + stop_output = self.stop_token(stop_token_input) # predict mel vectors from decoder vectors output = self.proj_to_mel(output) outputs += [output] alignments += [alignment] + stop_outputs += [stop_output] t += 1 - if greedy: + if (not greedy and self.training) or (greedy and memory is not None): + if t >= T_decoder: + break + else: if t > 1 and is_end_of_frames(output, self.eps): break elif t > self.max_decoder_steps: print(" !! Decoder stopped with 'max_decoder_steps'. \ Something is probably wrong.") break - else: - if t >= T_decoder: - break - + assert greedy or len(outputs) == T_decoder # Back to batch first alignments = torch.stack(alignments).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() + stop_outputs = torch.stack(stop_outputs).transpose(0, 1).contiguous() - return outputs, alignments + return outputs, alignments, stop_outputs def is_end_of_frames(output, eps=0.2): #0.2 diff --git a/models/.tacotron.py.swo b/models/.tacotron.py.swo deleted file mode 100644 index b4cfd7c58104b01c73bb3f9d66bfc10d8e7c0613..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI2O>Epm6vro&uL5lJo zRpkIjZasknhf08i;070j#Dxz52e@$KP^q9wz#&%-Nc?Tj*xprX5(y!J#?r65<9Rdl z-hZC`ys1~(ryn^^?{j7ejy;6@c!fQCxUrACd4>?i(?J?#9oIhM$nRX+qUb$eOxPOo z^abWi(OUNWRPr=w#c9AI(UPj|4A-|ZpRq9t7zMVdz)1EF&F>;J-A+3j(YWgXz2o*% zTO4f6i~>dhqkvJsC}0#Y3K#{90!D%Vg936mMP5Z9*QsE->T_!1Q#H+pQNSo*6fg=H z1&jhl0i%FXz$jo8FbWt2i~|2b1-yWey*CpwwG+58F)8J3|eHpwD-U9;M59UE06Fz@Hd>6s@;8Sn`JPuBP8Swi} z@DIEJd~g)(2OBpM@)1bD6JQ_MxPg!_!582>SOiA^0l!>N$j{(2a1jK+2aDiNuoGO_ zfq21N;1pN}ZO{UH!4B{>{uuKWxB#93PlE%1g8krMJ*YZ3lOvwcm_@yGB!?p@Y#N0L z^LW+vqHvJJERh1vs$<-Sx>T^J??j%Ey`sw+w`guMa+nI4Flp5sX;!P%&SI%r%Xv9u zYu2Jo4|VLqg{(6O<7&MI&H9(^eAqRPe7xLJ6G?h}z%un>7N>l@7qMg@SA{#?tlX*vVbe-stz^LR`K?km6~;Yf90z$d zS?1CBLYU`7ny_*>-os%FW$UW;VS??~ewU_UztD9gSWhhxd*a1~N{$IBbOh?Ol}`&@ z#-^`aui^@m@r^46mea*6EC{e0YQkD}EW)NRc0~9Jv)aY%0qfI#%GW#|SXp)|7vPH8 zD>qYU!m<2ZvliWCUSq8!aURPrkvvRD}@ww;;!b}jSu?3xhhju@vx24${fY4!KV(98F)7Y1kD+V^Vi z!g=j#A@q`NTDYJru^z9ak!aeq^g4^kCVJI;iiuRkL=$Jm&DD)6+ecU5`_LFSQewkG zeXF`hL&h!r&RcO8YC64Mn1r&|v#Rj4kCWcG$2?q|$hp!$f6*D`cG}%J?6Zr1YRVP) zn-9x!y<9RAVL~A{&#fv_HKU5fop0O3$IV+%9fyS^Kb{7oD07VTc^a2pxXZ6IE*#V= z?dbNa$Tr7rcctyxoBFm{ealis44H2#oWM<>exLp4E1bXTanF~y$ZAy}t8!JxCY7!- zEGkOAbEe*9&!}bPF!j#VyRxWlq2B4&*l8FKQ!eRA&hR_I)(=^>elWB12fdbilnp_n!86=&eMIC@=D%q(Sj>0gq)1G=CwO F`~}aao>l+= diff --git a/models/tacotron.py b/models/tacotron.py index 7653f1c..a8b04fb 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -11,6 +11,7 @@ class Tacotron(nn.Module): freq_dim=1025, r=5, padding_idx=None): super(Tacotron, self).__init__() + self.r = r self.mel_dim = mel_dim self.linear_dim = linear_dim self.embedding = nn.Embedding(len(symbols), embedding_dim, @@ -26,6 +27,7 @@ class Tacotron(nn.Module): self.last_linear = nn.Linear(mel_dim * 2, freq_dim) def forward(self, characters, mel_specs=None): + B = characters.size(0) inputs = self.embedding(characters) @@ -33,7 +35,7 @@ class Tacotron(nn.Module): encoder_outputs = self.encoder(inputs) # (B, T', mel_dim*r) - mel_outputs, alignments = self.decoder( + mel_outputs, alignments, stop_outputs = self.decoder( encoder_outputs, mel_specs) # Post net processing below @@ -41,8 +43,9 @@ class Tacotron(nn.Module): # Reshape # (B, T, mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) + stop_outputs = stop_outputs.view(B, -1) linear_outputs = self.postnet(mel_outputs) linear_outputs = self.last_linear(linear_outputs) - return mel_outputs, linear_outputs, alignments + return mel_outputs, linear_outputs, alignments, stop_outputs diff --git a/tests/layers_tests.py b/tests/layers_tests.py index 3fbab02..e8ebba0 100644 --- a/tests/layers_tests.py +++ b/tests/layers_tests.py @@ -37,18 +37,23 @@ class DecoderTests(unittest.TestCase): dummy_memory = T.autograd.Variable(T.rand(4, 120, 32)) print(layer) - output, alignment = layer(dummy_input, dummy_memory) + output, alignment, stop_output = layer(dummy_input, dummy_memory) print(output.shape) + print(" > Stop ", stop_output.shape) + assert output.shape[0] == 4 assert output.shape[1] == 120 / 5 assert output.shape[2] == 32 * 5 - + assert stop_output.shape[0] == 4 + assert stop_output.shape[1] == 120 / 5 + assert stop_output.shape[2] == 5 + class EncoderTests(unittest.TestCase): def test_in_out(self): layer = Encoder(128) - dummy_input = T.autograd.Variable(T.rand(4, 8, 128)) + dummy_input = T.autograd.Variable(T.rand(4, 8, 128)) print(layer) output = layer(dummy_input) diff --git a/tests/loader_tests.py b/tests/loader_tests.py index fdecd6e..dc023b6 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -32,7 +32,7 @@ class TestDataset(unittest.TestCase): c.power ) - dataloader = DataLoader(dataset, batch_size=c.batch_size, + dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn, drop_last=True, num_workers=c.num_loader_workers) @@ -43,7 +43,8 @@ class TestDataset(unittest.TestCase): text_lengths = data[1] linear_input = data[2] mel_input = data[3] - item_idx = data[4] + stop_targets = data[4] + item_idx = data[5] neg_values = text_input[text_input < 0] check_count = len(neg_values) @@ -81,13 +82,16 @@ class TestDataset(unittest.TestCase): text_lengths = data[1] linear_input = data[2] mel_input = data[3] - item_idx = data[4] + stop_target = data[4] + item_idx = data[5] # check the last time step to be zero padded assert mel_input[0, -1].sum() == 0 assert mel_input[0, -2].sum() != 0 assert linear_input[0, -1].sum() == 0 assert linear_input[0, -2].sum() != 0 + assert stop_target[0, -1] == 1 + assert stop_target.sum() == 1 diff --git a/train.py b/train.py index 5302761..3531771 100644 --- a/train.py +++ b/train.py @@ -63,11 +63,12 @@ def signal_handler(signal, frame): sys.exit(1) -def train(model, criterion, data_loader, optimizer, epoch): +def train(model, criterion, critetion_stop, data_loader, optimizer, epoch): model = model.train() epoch_time = 0 avg_linear_loss = 0 avg_mel_loss = 0 + avg_stop_loss = 0 print(" | > Epoch {}/{}".format(epoch, c.epochs)) progbar = Progbar(len(data_loader.dataset) / c.batch_size) @@ -80,6 +81,7 @@ def train(model, criterion, data_loader, optimizer, epoch): text_lengths = data[1] linear_input = data[2] mel_input = data[3] + stop_targets = data[4] current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1 @@ -93,6 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch): # convert inputs to variables text_input_var = Variable(text_input) mel_spec_var = Variable(mel_input) + stop_targets_var = Variable(stop_targets) linear_spec_var = Variable(linear_input, volatile=True) # sort sequence by length for curriculum learning @@ -109,9 +112,10 @@ def train(model, criterion, data_loader, optimizer, epoch): text_input_var = text_input_var.cuda() mel_spec_var = mel_spec_var.cuda() linear_spec_var = linear_spec_var.cuda() + stop_targets_var = stop_targets_var.cuda() # forward pass - mel_output, linear_output, alignments =\ + mel_output, linear_output, alignments, stop_output =\ model.forward(text_input_var, mel_spec_var) # loss computation @@ -119,7 +123,8 @@ def train(model, criterion, data_loader, optimizer, epoch): linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_spec_var[: ,: ,:n_priority_freq]) - loss = mel_loss + linear_loss + stop_loss = critetion_stop(stop_output, stop_targets_var) + loss = mel_loss + linear_loss + 0.25*stop_loss # backpass and check the grad norm loss.backward() @@ -136,6 +141,7 @@ def train(model, criterion, data_loader, optimizer, epoch): # update progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), ('linear_loss', linear_loss.data[0]), + ('stop_loss', stop_loss.data[0]), ('mel_loss', mel_loss.data[0]), ('grad_norm', grad_norm)]) @@ -144,6 +150,7 @@ def train(model, criterion, data_loader, optimizer, epoch): tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0], current_step) tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step) + tb.add_scalar('TrainIterLoss/StopLoss', stop_loss.data[0], current_step) tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'], current_step) tb.add_scalar('Params/GradNorm', grad_norm, current_step) @@ -184,19 +191,21 @@ def train(model, criterion, data_loader, optimizer, epoch): avg_linear_loss /= (num_iter + 1) avg_mel_loss /= (num_iter + 1) - avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss /= (num_iter + 1) + avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss # Plot Training Epoch Stats tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step) tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0], current_step) tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step) + tb.add_scalar('TrainEpochLoss/StopLoss', stop_loss.data[0], current_step) tb.add_scalar('Time/EpochTime', epoch_time, epoch) epoch_time = 0 return avg_linear_loss, current_step -def evaluate(model, criterion, data_loader, current_step): +def evaluate(model, criterion, criterion_stop, data_loader, current_step): model = model.eval() epoch_time = 0 @@ -206,6 +215,7 @@ def evaluate(model, criterion, data_loader, current_step): avg_linear_loss = 0 avg_mel_loss = 0 + avg_stop_loss = 0 for num_iter, data in enumerate(data_loader): start_time = time.time() @@ -215,38 +225,44 @@ def evaluate(model, criterion, data_loader, current_step): text_lengths = data[1] linear_input = data[2] mel_input = data[3] + stop_targets = data[4] # convert inputs to variables text_input_var = Variable(text_input) mel_spec_var = Variable(mel_input) linear_spec_var = Variable(linear_input, volatile=True) + stop_targets_var = Variable(stop_targets) # dispatch data to GPU if use_cuda: text_input_var = text_input_var.cuda() mel_spec_var = mel_spec_var.cuda() linear_spec_var = linear_spec_var.cuda() + stop_targets_var = stop_targets_var.cuda() # forward pass - mel_output, linear_output, alignments = model.forward(text_input_var) + mel_output, linear_output, alignments, stop_output = model.forward(text_input_var, mel_spec_var) # loss computation mel_loss = criterion(mel_output, mel_spec_var) linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_spec_var[: ,: ,:n_priority_freq]) - loss = mel_loss + linear_loss + stop_loss = criterion_stop(stop_output, stop_targets_var) + loss = mel_loss + linear_loss + 0.25*stop_loss step_time = time.time() - start_time epoch_time += step_time # update progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), + ('stop_loss', stop_loss.data[0]), ('linear_loss', linear_loss.data[0]), ('mel_loss', mel_loss.data[0])]) avg_linear_loss += linear_loss.data[0] avg_mel_loss += mel_loss.data[0] + avg_stop_loss += stop_loss.data[0] # Diagnostic visualizations idx = np.random.randint(mel_input.shape[0]) @@ -278,12 +294,14 @@ def evaluate(model, criterion, data_loader, current_step): # compute average losses avg_linear_loss /= (num_iter + 1) avg_mel_loss /= (num_iter + 1) - avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss /= (num_iter + 1) + avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss # Plot Learning Stats tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step) tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step) tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step) + tb.add_scalar('ValEpochLoss/StopLoss', avg_stop_loss, current_step) return avg_linear_loss @@ -336,13 +354,15 @@ def main(args): c.num_mels, c.num_freq, c.r) - + optimizer = optim.Adam(model.parameters(), lr=c.lr) if use_cuda: criterion = nn.L1Loss().cuda() + criterion_stop = nn.BCELoss().cuda() else: criterion = nn.L1Loss() + criterion_stop = nn.BCELoss() if args.restore_path: checkpoint = torch.load(args.restore_path) @@ -370,8 +390,8 @@ def main(args): best_loss = float('inf') for epoch in range(0, c.epochs): - train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch) - val_loss = evaluate(model, criterion, val_loader, current_step) + train_loss, current_step = train(model, criterion, criterion_stop, train_loader, optimizer, epoch) + val_loss = evaluate(model, criterion, criterion_stop, val_loader, current_step) best_loss = save_best_model(model, optimizer, val_loss, best_loss, OUT_PATH, current_step, epoch) diff --git a/utils/data.py b/utils/data.py index a38092e..022fab1 100644 --- a/utils/data.py +++ b/utils/data.py @@ -14,6 +14,29 @@ def prepare_data(inputs): return np.stack([pad_data(x, max_len) for x in inputs]) +def pad_tensor(x, length): + _pad = 0 + assert x.ndim == 2 + return np.pad(x, [[0, 0], [0, length- x.shape[1]]], mode='constant', constant_values=_pad) + + +def prepare_tensor(inputs): + max_len = max((x.shape[1] for x in inputs)) + return np.stack([pad_tensor(x, max_len) for x in inputs]) + + +def pad_stop_target(x, length): + _pad = 1. + assert x.ndim == 1 + return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) + + +def prepare_stop_target(inputs, out_steps): + max_len = max((x.shape[0] for x in inputs)) + remainder = max_len % out_steps + return np.stack([pad_stop_target(x, max_len + out_steps - remainder) for x in inputs]) + + def pad_per_step(inputs, pad_len): timesteps = inputs.shape[-1] return np.pad(inputs, [[0, 0], [0, 0],