From 822afbabc60c29625c2f05417341029484bd36dc Mon Sep 17 00:00:00 2001
From: czy97
Date: Mon, 22 Nov 2021 16:25:24 +0800
Subject: [PATCH] add diarization demo for unispeech_sat pre-training model

---
 UniSpeech-SAT/speaker_diarization/README.md | 34 ++
 .../config/infer_est_nspk1.yaml | 31 ++
 .../config/unispeech_sat.th | Bin 0 -> 20719 bytes
 .../speaker_diarization/diarization.py | 321 ++++++++++++
 .../speaker_diarization/models/models.py | 391 ++++++++++++++
 .../speaker_diarization/models/transformer.py | 147 ++++++
 .../speaker_diarization/models/utils.py | 78 +++
 .../speaker_diarization/requirements.txt | 11 +
 .../speaker_diarization/tmp/mix_0000496.wav | 1 +
 .../speaker_diarization/utils/dataset.py | 484 ++++++++++++++++++
 .../speaker_diarization/utils/kaldi_data.py | 162 ++++++
 .../utils/parse_options.sh | 97 ++++
 .../speaker_diarization/utils/utils.py | 189 +++++++
 13 files changed, 1946 insertions(+)
 create mode 100644 UniSpeech-SAT/speaker_diarization/README.md
 create mode 100644 UniSpeech-SAT/speaker_diarization/config/infer_est_nspk1.yaml
 create mode 100644 UniSpeech-SAT/speaker_diarization/config/unispeech_sat.th
 create mode 100644 UniSpeech-SAT/speaker_diarization/diarization.py
 create mode 100644 UniSpeech-SAT/speaker_diarization/models/models.py
 create mode 100644 UniSpeech-SAT/speaker_diarization/models/transformer.py
 create mode 100644 UniSpeech-SAT/speaker_diarization/models/utils.py
 create mode 100644 UniSpeech-SAT/speaker_diarization/requirements.txt
 create mode 120000 UniSpeech-SAT/speaker_diarization/tmp/mix_0000496.wav
 create mode 100644 UniSpeech-SAT/speaker_diarization/utils/dataset.py
 create mode 100644 UniSpeech-SAT/speaker_diarization/utils/kaldi_data.py
 create mode 100644 UniSpeech-SAT/speaker_diarization/utils/parse_options.sh
 create mode 100644 UniSpeech-SAT/speaker_diarization/utils/utils.py

diff --git a/UniSpeech-SAT/speaker_diarization/README.md b/UniSpeech-SAT/speaker_diarization/README.md
new file mode 100644
index 0000000..8bfeb6b
--- /dev/null
+++ b/UniSpeech-SAT/speaker_diarization/README.md
@@ -0,0 +1,34 @@
+## Pre-training Representations for Speaker Diarization
+
+### Downstream Model
+
+[EEND-vector-clustering](https://arxiv.org/abs/2105.09040)
+
+### Pre-trained models
+
+- Note that the diarization system is trained on 8 kHz audio data.
+
+| Model | 2 spk DER | 3 spk DER | 4 spk DER | 5 spk DER | 6 spk DER | ALL spk DER |
+| ------------------------------------------------------------ | --------- | --------- | --------- | --------- | --------- | ----------- |
+| EEND-vector-clustering | 7.96 | 11.93 | 16.38 | 21.21 | 23.1 | 12.49 |
+| [**UniSpeech-SAT large**](https://drive.google.com/file/d/16OwIyOk2uYm0aWtSPaS0S12xE8RxF7k_/view?usp=sharing) | 5.93 | 10.66 | 12.90 | 16.48 | 23.25 | 10.92 |
+
+### How to use?
+
+#### Environment Setup
+
+1. `pip install -r requirements.txt`
+2. Install the fairseq code
+   - For UniSpeech-SAT large, install the fairseq code from the [UniSpeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) repository.
+
+#### Example
+
+1. First, download a pre-trained model from the table above to `checkpoint_path`.
+2. Then, run the command below.
+   - The wav file is multi-talker speech simulated from the LibriSpeech corpus.
+3. The output will be written to `out.rttm` by default.
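+
+For reference, each line of `out.rttm` is an RTTM-style `SPEAKER` record. With the default `--session Anonymous`, an output line looks roughly as follows, shown here in the standard RTTM field layout; the start time and duration (both in seconds) are illustrative values only:
+
+```
+SPEAKER Anonymous 1   12.34    3.21 <NA> <NA> Anonymous_0 <NA>
+```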
+ +```bash +python diarization.py --wav_path tmp/mix_0000496.wav --model_init $checkpoint_path +``` + diff --git a/UniSpeech-SAT/speaker_diarization/config/infer_est_nspk1.yaml b/UniSpeech-SAT/speaker_diarization/config/infer_est_nspk1.yaml new file mode 100644 index 0000000..60daa1e --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/config/infer_est_nspk1.yaml @@ -0,0 +1,31 @@ +# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). +# All rights reserved + +# inference options +est_nspk: 1 +sil_spk_th: 0.05 +ahc_dis_th: 1.0 +clink_dis: 1.0e+4 +model: + n_speakers: 3 + all_n_speakers: 0 + feat_dim: 1024 + n_units: 256 + n_heads: 8 + n_layers: 6 + dropout_rate: 0.1 + spk_emb_dim: 256 + sr: 8000 + frame_shift: 320 + frame_size: 200 + context_size: 0 + subsampling: 1 + feat_type: "config/unispeech_sat.th" + feature_selection: "hidden_states" + interpolate_mode: "linear" +dataset: + chunk_size: 750 + frame_shift: 320 + sampling_rate: 8000 + subsampling: 1 + num_speakers: 3 diff --git a/UniSpeech-SAT/speaker_diarization/config/unispeech_sat.th b/UniSpeech-SAT/speaker_diarization/config/unispeech_sat.th new file mode 100644 index 0000000000000000000000000000000000000000..2c498cf72b5d4a330f6ffaa6a3ff7c1c2f081eb5 GIT binary patch literal 20719 zcmch9b(~v8_x?_uVx_njFHqTdg0uw+g$}b6U!;U=HoKd%v3ruV-NHigF7EE`?(XjH z?(Qz%^UO`|w7l>8`TYL+g%5eoGc$MY=$UiQG&@5N9@5e>VnoY-d`7j5Zpm=ACqFAU zu`^T4OsMo0w%UA&mfaWn%p272wTx{aERScqx;@qJ4I0}%R9L!{DdxPvZSBL5RW26G zC2vT-H*{=|q%A3FrEuKldNs$@iL=I6Go`+4&%~a3M~-W0n_r#SkxduVjid_aYMjZJ(uE9n=h9tVE~cwl z&eXC!>1rVqPN!OX)7f&VmXE|*)3eJEIa94+(B$RbV(p8|n`dWAogL^c<+3$zalikW zxpmGJ3ut3^zSQk4F<^UE&E-11B~5CmKcVBNZRWCmEs)bPBVC0|O$IRC&6&=sx7>i8DVv!u z ztF+mH$`mW!sspx#BFfsrWMLljxvICC9%QRW&H`OpUAi<%-iqXfdNiiZE;1~bG?vmD zZMIylSgG};vzcrU7DdkU*3>PwVVOc9-JPlRNby3xDkGMkoAcH(EmPW;MyId>)0GTM zF%fL$+FIEG@#+~_MsqVY83YVVSH4?H)_S~k+S-T9v$A~bGGBMVt|{q%t`;nw^#*LS zs`ZKtqzs5G(Q1!$cdS2y()@I#kf|0k=~?;QY;XO6_C-;ppJd;%flZlMD^@1@$=-%_ zo8<4^bPoG*T)#KoZeZ*|``CmwyK5_D>8rFbD=+sZwpk6})$_TU*Xp;x?|E6~n)98~ zuu%ITsbZggFKnNeM%N3qs^8&ATiuRjrwkzHJL{f3%^_CQqG%dvpeeYXiZjjC0*?Bw$v-y5$Vd$@%K@#ny!`8 z-IaRP+iIX4Q_v+lKV2`$3hT5BTZTb&t2b@F-DX{)3)9tFM!xz=S#N89TlVBKTK74oH=e=Y>O)|=t)wII7AQ?Lqdd%rzFx<8dwZwFuB2kF_Kda1YS?I@E8MHr+@ zqvZ801MSzhQ`>*fZ`Ql>zw$eC3qQZh|HyAXziXSF+u;48%d)+dvK+o%yWPO0oNQNb zH@hz8=2#Y5j~uGlguC}l?OA4i&$0tOaw4GaWC5h*w4N=@+hf3AK$S>X zN_5s`ZHs=DGlsoAeImBHH-VS9U-@*ApzPqQEo>iWc?#_2- zJhFuwo%IX%Hq}WwTgYWfUcqWa+4&t<(hC-?dd0T(m8E6*?oyd^&CR@Uk&-_%Hc=+T znP_D_)OP$oa?rcG`?p>`+j_j|@38urcB;$rX6Y41msYJS&V zG2*MW36+|6RO3)9=P&k*Er^lw*KE~0&Cf#%Y+i?qaoUY) z)jQoU03#LFh_u~1W58N)LhZs|sQ8)l?eI&UeanJ1oi$)j-bO8HH&Xs7aJFx9{Hp>j+d-@7c>aN@~kd(*O^FL&wyo|B&*m`PVmoBXAu z%Ey|X-Qxd>J&6~k<@M}csl(RplHP8UQa`RHdjsDeB$gH?LlTl9Ypgu6WY7`rZ;2VulNS zVosDR-fg~%DT!#AJy(iF@Ad&3tB{&1QrR7DP!YiplMShNx<FZZ97+`$e?{}#xzACubE|p&I0e|fx+OS|o#e2}HbF^3-OL3|sqEV}Q54Bn4UlV5* z(XA(UmLxD^mOJqD1rnA(bjeoNBkrz*a=aJ()!E4EIu>HoGXavh%_K!2@Qwb}<|D7e?cc=GC!<>dwpDq-=SNpx! 
z>~2L^N8TouGAwHQL`T(o-8cRg%z#3!dT;nmMbt~W*o*?bDJ#}4#Bx^+181!;@2$f} zz{iu%{(fM=9#xqvVrAaj{oXsa0a&BH-{-IXccm_94Qsu7@3kwD$9wPBZK8i>f6(uJ zD9_qg)=T+nC6|L!2c~u`F*cv-Rz~Pa5U=tP3ax)~B|V ztmt|c+ag`-tK_`T{M{lGSjF(lB!2E!*6`a38F{|yeW6#d`n2iJ<%%6L^qu*l_oZnW zVo&O3R##Uk?XtfzEu-x5*UYw|I2#TUhgrgLU+Yn}lXlGfI8j&0`^Fv*-MK>D`_}i7 z8Y|{Ix8gCauD0fq`rZ{l;BTS))5y56unopUgOqS1CG7oSwYjaZBcG{y ze-8MIq}(lWP^e7ruW6HSI}M-HZB7jl;##^}=HYLDzcd?yV5;>4y?>l~#c*|q)zF!d z=xs4m?d29(+y2=AiW5z=DJw}8wd9|bk}Si}!Gk1m+T_m~Z@1x6JW|aSB)}}L6AzXG z_B6vj&1ZQCIIMSDbs6zV9*RmN$$pw4c|Tv&dLdKcVN#+IjYai1fF;(H&GB%_5M5_O zn#jraWJ)F3WQ{hDK=w$x3v7MOY}@ul7J4m0NDGy5eHbYP$u=CTbD=_h2NqqZCMbVV zbPJbjbYw9p>FdvhvR#K3mjcouoU?YyC7#JkAZ?xQv^?R@ybW8F%1Gy>Z7Ljd z&dFOuO{S$;d8-ht)<7tRmkFMQm+FV_s=TbEnkjz30V)^~Ljex*SNd zv!t`j+iGz5<>pl6x+3k=1(c6f#=L?Q@=t#(M;!PW8NzN}(LU@uhhhr(g;$b9QQkI$ znd^ebAX&V!pNiFw_Lg!IESH{Dc@;kcF_a+D_skH5>YFVqHpi>l7h%n0*~a~01Bq=K z)Uq0C>Gb_R38r;OjB|n9)%|xs88xEqqT21Yt@0Sj^CPH@8J(5kDz71_7)P|Xk?xy8 zye8@z=j-|a{K|sx!1_FyJTI- z36_q(bNxqU)ObBqsO-vYUQpq57xRe6N>zRMt?vR+xV$_LY+Wg7@BJ3KFcZh8)IHm#PZSkyZEqfr^5}plY8g;DgQdy|lGre)asF0$ zq9n}ZR+L5PzqiOCkuz4a`A4+euW%UoTmL`5dtz@<1{vpjqgSw`!v46YCicc!TR9>% z_Tb=03$1M|*0{_VM^!$+}IK=8Q2eq&dW}pBW?t5)yzP+d%=up$bbrmX0h9MYox#fwd#2r zVL5*kgUA9;m&AV7@{Rq0vYF0Ik+;Q@&|P3+-OFBL(@Y!NIw3Vq{!gyYZAimAZMj^f z%G*iGv`N=K{L8sFeRY)WuRn@#N-}*s1KE=gp8MXktC!iwX2ZM2O{RVC_L7aLde6Z# zc?Vmx8L0E4j0?t|ct5ac+J?({*Hn59n7}1z$@k1qTqHz`s8n=D*4!U)XEY%OJg4V z7xL~Xve^n`i=cCR*k@elhD%OWTxIBvJteQ7XUcbTPT)}UUXoyMcx7X7T8QXC{KErr z^LTI2L6+i}$$$>7TizjSr$e~!INC@*b}eUZ8s@l~mngk8^0^atZso%QImGh=Im`or9O1(Q zIm$-_a*U4*J@K#uY`fgIy=13Av;1#*JV59B0Y z5XdRMuqnq{`JzA$@x_50=1T%O!j}eelrIb97+)U9alRsu6MSVLC;6&CPVv=EIo`_G z1agS44dgIi7swI5K9Hk)Lm1A!dk2Ln0G4+V0B9}eUwKN83> zel(Ed{8%6-`0+qa@)Loa;wPJOs+FG#MRK#uYAfgI--0y)7i z26B>L3gi^O3>oSzB=*`Xj)eGCN5cG?BN2Yxkto06NQ~cfB+hR+lHj);N%A|6r1;%{ zi0XRJkr2P{NSHrxB*Gs$66KE^iSfsd#Q76P68xznN&d`{6n`ENQDR>>65=l%3G-Ku zMEGk*qWp~`G5*$(IDh9zg1>hp$v-%f;vWMd>g*>+Lj1EMVgALD2>b z9ZB(sfQX7)#E}q>bR^7+IuhZ<9EtMcj>LEgN8-GsBMBbmNRpRwB*mixBFb)QM?$=e zBVk_Fkq9s6NR*d%B*rT^66X~iN$^UJBza{=QoKq)MD4BWNQhT+B+RQj65%n9M0pKI zV!Wm!abC-j1h4H#lGkx0#p?z{6ySP}gm|nYVP4;n2yfs>ls9xF#^W4`^LR%RJi(D9 zPjn>3tpO2L7;+@UVMoFoaU{Y~N1_~aB*t+^;+$|K!AVDwoN^?^8#PH-bmGR2gm@E2 z!n~;?5uW5olqWk9<0+2Bc{4{6ytyMu-olX-PYsCZ#4Q~O@m7w6d72{;-rA8UZ{tXe zr#lj7?MQ;RbtK8gkrcNDM0DbIj)ZuIBVpd&kqGbLNR)SUB*r^A66c*AN$@U?Bzady zQVeU>>YcFY#N8YT@$Qaqv?V0TG>8bR@(jN5WioB*GO(qU<>mV|FCY zRYwwBb0o=iM^Zd1AfgjzI}+kKj)b|-kqFOqB+C0Z661XxiF3ar3Et0e2gP8KGu;qALmGdk9Q==CpePg69XbT@gzq=e6k~9KE;s;pXx}IPje*3r#lko zGaO0qnT{m+EJsp&c9TRzC!XU-h|hH-%;z~0;qx7d@&%5>_(DhGe32swzSxl@U*brL zFAa$3#LFBB@#T($`3grOe5E5%zRHmpU+qYouW=;7*E*8q>l{h(^#Ku`c!MJ$zR{5| z-{eSyZ+0Zgw>T2xTOEn>ZH^@Pc1MzYha)MzGa#Z9?{Xx>cRLd1dmM@Iy^cirK1X7F zzaw#ez>x$$=tz+>t0h;Yf_1bR^DCIg;R~9ZB*t zj->e6fQU|f&XEv5??{+ma3sPnIuhlV9EtJEj>P#DM-u$1BT0VEkrclk5YdTmI1=JF z9SQSWjzsuvN22_WBQbv0kvPBSNP^#YB*`B*lHv~oB0BLSM?(CuBVqo;kqCe4NR&Tw zB*vdR66Y@*N${7BB>5{xQv7v5L??dZNQl36B+TDA65;P1iSiGQ#P~-?;{1~%3I5rU zB>&<_ihm7==)~U~3Gwfag!vCgBK)T#QU1%382{}^od0no!7bt)TRS9@O0se!#e@}iEUc(H(pPF&oP5HI0Kn3r@U z!lN9C@=}h(c(fyNUfPiaFXKp(mvtn?%LPPq;_{A!cm+qoyrLr!UdfRtuk1*SS8*iH zt2&b4)f`Fk>W-v%Oh7~@uHi_C*K{PzYdI3(wH=A_I*!D6T}R@)o+Ak!>qwHNdtkvNAONpRSaBu5-caWo*J6Jw5qIPOT86OKeU z=}44Qj>LE)N8-G(BMIKbktA>GNQx%~M0Db0M?yTskuY!ONQ5_cB+6Sj662|k#Cc0c z61&Tog)dJ;YgCVcO=C-1VnV= zj*f(QCr84(vm+7S#gQoQ>PU>CbFAKpCgZ%DBMIK!ktFZoNQ(Cii0H(bj)ZtGN5Y(T zB*J?;66K5|G460A&RIti-04V?bB?6AOGwUMk=TEi%iY3fa*uqE)vEEITpsB`daxd% zhw5Q^xE`Sw(IfSudNIAYUP3RaN9m>XXuY&vMlY+E)644>^on{Vy|P|Kuc}wmtLrg( 
z4ZWsbORufh(d+8<^jN*V-av1t$LaBUf}W^bbx4PGL`QW@$8|y{bxLofH`bfzP4y%_ zSx?cM>CN>PdaB-1Z>6W{t@So~y4HGIZFHO7PS4QW>mBrtdMCZJ-bL@K+x2dGcfE(+ zQ_s|U>9pQkXLN_o>Q0^0UAkNM=)5M~s|&iQOS-Hp+S9D7x~A)TmY%KW=srDH@1ytC z{dzyWzn-TL&O-meS$twpQKOLr|47l zY5H`1hCWlDrO(#q=yUaX`h0zXzEEGJFV>gnOZ8>?a(#uqQeUO7*4OB3^>zAseS^MH z-=uHWx9D5-ZTfb7hrUzarSI1F=zH~j`hNX@eo#N8AJ&iPNA+X+as7mTQa`1i*3al? z^>g}p{epf`zocK*ujp6xYx;HlhJI7OrQg=?=y&yd`hER@{!o9UKh~e zQh%kt*5BxF^>_Mv{e%8d|D=D`zvy4}Z~AxrhyGLlrT^Cdm=QeP-YQjMk=$YGK(p*xH3y9v!pVklvzrd(aJ2X%reR>tITrBEU(N8%B-l&O3JLP z%qq&Ps?2K2tgg%$W!6w;O=Z?nW^HBGQD$9b)>CGzGV3d|fifE^GftWD%1lsZqB5<@ zgp>&@6Hz9rOiY=$G6`jp%A}OpNSTe5*+iL5m6@c>WF`M7#cZa`=E`iL%v5EzRAwt> zrYWDYLsWdnmJ~ zGBcIgOPRDXdn=PsrbC&mGM&ogl<88YTbUkZ^2(4hy~-4nDJoM^rmRdw8BZBjrm9R$ znYuEwl$ou}9A)~HnXAk`%IvF5zcTwNv%fO)lsQ0|1C=>QnS+%%M43aCIZT=P$_yxT zxH3m5bEGmyDRZ zboH`?>s-AY;U-rvPq@U@D-iB)^@@b6TfGwD)>f}fxUki$5bkRAs)TD=y&B<$RFg5xT4fM5N;>+ zj)aRzy%XV1QtwQ-j?}vlZX)%rgiA=>PPl*6yAiG)_3ngQN4*E(!cp%@xNFoi3D=Bz zFTxF@PSZ1hz3Ev%hMog-(DQ&Sy#VN>7XdkX3D89^1G?!IKo7kN$kS^8qSpbv^ah|n zZvu++7NA6L1IqLcphCF)(;ng8Pcz}hPge8kV5d8o+n0^Et zLO%fxrJn(Z(Jz4c^ebS1eghm%zXOh-KLAJ4pMaz2FTm0CH{clh2XHL4h%2d&BLz5~ z1_4f>!GIHK2;d|d3OJdD0ZyUefKzD%;51qUa5{|yoI#5M&ZNZvXVKz-vuO#yIkY6; zTp9&9kCpk@J%780r6~I-r zD&T5b4R8&u4!D-a0Is7o0N2x+fE#Enz>Tyv;3irJa5Jq7xP{gO+)85sx6%55+i3&9 z9ke0fP8tWei^c=)rU`(1Xd>WVY6aX!A;A3<20TC!z=IS8JVY_T!xRTRLJ7d5lmt9R zDZt~j5#R~h81N)*0(gox1w2iY0MF25z_T<3@EmOhc%C)~yg*w3UZkmjmuO4C%d{2X z6`BTkm9_@FM%w^hr|EznynpRKT!_wGj#!ep>Dvh)C2g9@_^ro0Dn*~;7=+5{-Pq_Zz=))p|XHjepCS} zIRSx#6ftnHqAGBRq8f0hqB?MxqFKP{xD4GLYL{T4bq@uaNMHTG>Tujlvz{M5y z1D8;=A8<)U`vXTQng?7;(E-5GiVg%Wt>_@&GKvlcF01Ge;Btx%1un1XFyIP`<^xw$ zGyq&l(c!?A6&(RwMbVMKRTUitTusr@z||EU10191Sl}9pjsvc#=y>2-icSEot>{GH zI*LvLuB+%|;ChNq0ghF4DsX*8rvW!mbUJWDMP~rVDLNB4UeQ^=35w1JPE>RbuvO8y zz>uQzfMG@F10#wq07eyE2#hJZ2pCs%F)*R%5@1r%rNES;%YYjxx*WK%qAP%#D7q53 zsiLcZlN4PIoUG^?;1osI0yk519dL6+*8{gubOUgzq8ou*D!K`{m7<%0(-hqT+*;AC zz-<)W2Ark^e}K&MUMd66+H^vP0?e(-4#6!+(Xe5z&#Z`37o0uDd1j;o(85BJp{avzu%PHo zU{TRqz>=c3fn`PS04s{#1$v6!1G1v`fmKBx0Bedq1lARO1e~SlW8iE>p8)45`V`ow z=riD4MV|xrQS=3HUqxR6`xSi!+)vTh!2K0{1DvPmTi^kTz5^bp=zHKnihckdtmsGJ zA&Pzi9;)bP;9-h>0nS(SD{w&3Z@|M9{SG`r(I3Dg75xc3O3`1yqZR!PJVwz!z+)95 zoGb1OA(AUT4IzjtjtwDxE8YzubSo|nA!;jr4k2JG&JH0~D;^IaOe<~=Awny@4D&7(y)G96$ zA<8O#6CuDV&J!WFDjpOetSW94A)+e26d{-@4izDeDqa;KgetBTA$ltQ6(MjcP8K0% zDxMZ0Tq^DsAyO(n7a>S0ju#<5D&7|%G%79_Au1|<7$G1k&KMyUDjpdj3@UCJAp$DC z86o&74jLiuDP9^O|pr%qVUn zA;KuWBO$mb4kRJ2C|)EXq$sW=A(|-uBq5L}P9-6RD4r!Dd?@ZEA#x}_CLw4jjwT^q zDBdO^R46VdAxbEICm}#6&L<%@C>|&wBq(kuAtETgC?Oaq4k;lHC|)Tc1Sqa4A^Io& zDIxGDPAVbhC!Q)H+$ZiTA<`#4DO)CqFtTg!ar&iV)Q^KPy5&&-^04gvYW4nE2-9{waNJ5kxecjnZ@$8$@me4O)et2(oDurgKcsN$sUsNqaD;Oj&p7f zmnB_(zJx?n;U`v*=>KpX75IrB5(n9mXoH`~NE~ciP?BGyX60Yxa}{Yr{3`uFl~rc0 zNn%^QBU7zUn4yP{IDPe-CT_c3OAG#=@jw4I-;ynf;BWKcC)o0HkgYb4{OeEiE%bN# z{uhai#9N;`yL)F*O)dC;hE{;!ZvrmZBK*W);m^PSLg2rjUa)5T-rIf7lpZ#pZuDV> z9yV-Tqr_^1#|LF?dxSM{pGMB;8NMLT;pYP_M>V+2rgUQ{ezz~r3~E7nzwG}3yD_ws literal 0 HcmV?d00001 diff --git a/UniSpeech-SAT/speaker_diarization/diarization.py b/UniSpeech-SAT/speaker_diarization/diarization.py new file mode 100644 index 0000000..df433b1 --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/diarization.py @@ -0,0 +1,321 @@ +import sys +import h5py +import soundfile as sf +import fire +import math +import yamlargparse +import numpy as np +from torch.utils.data import DataLoader +import torch +from utils.utils import parse_config_or_kwargs +from utils.dataset import DiarizationDataset +from models.models import TransformerDiarization +from scipy.signal 
import medfilt +from sklearn.cluster import AgglomerativeClustering +from scipy.spatial import distance +from utils.kaldi_data import KaldiData + + +def get_cl_sil(args, acti, cls_num): + n_chunks = len(acti) + mean_acti = np.array([np.mean(acti[i], axis=0) + for i in range(n_chunks)]).flatten() + n = args.num_speakers + sil_spk_th = args.sil_spk_th + + cl_lst = [] + sil_lst = [] + for chunk_idx in range(n_chunks): + if cls_num is not None: + if args.num_speakers > cls_num: + mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx] + for s_loc_idx in range(n)]) + min_idx = np.argmin(mean_acti_bi) + mean_acti[n * chunk_idx + min_idx] = 0.0 + + for s_loc_idx in range(n): + a = n * chunk_idx + (s_loc_idx + 0) % n + b = n * chunk_idx + (s_loc_idx + 1) % n + if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th: + cl_lst.append((a, b)) + else: + if mean_acti[a] <= sil_spk_th: + sil_lst.append(a) + + return cl_lst, sil_lst + + +def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst): + org_svec_len = len(svec) + svec = np.delete(svec, sil_lst, 0) + + # update cl_lst idx + _tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)] + cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst] + + distMat = distance.cdist(svec, svec, metric='euclidean') + for cl in cl_lst: + distMat[cl[0], cl[1]] = args.clink_dis + distMat[cl[1], cl[0]] = args.clink_dis + + clusterer = AgglomerativeClustering( + n_clusters=cls_num, + affinity='precomputed', + linkage='average', + distance_threshold=ahc_dis_th) + clusterer.fit(distMat) + + if cls_num is not None: + print("oracle n_clusters is known") + else: + print("oracle n_clusters is unknown") + print("estimated n_clusters by constraind AHC: {}" + .format(len(np.unique(clusterer.labels_)))) + cls_num = len(np.unique(clusterer.labels_)) + + sil_lab = cls_num + insert_sil_lab = [sil_lab for i in range(len(sil_lst))] + insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))] + print("insert_sil_lab : {}".format(insert_sil_lab)) + print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx)) + clslab = np.insert(clusterer.labels_, + insert_sil_lab_idx, + insert_sil_lab).reshape(-1, args.num_speakers) + print("clslab : {}".format(clslab)) + + return clslab, cls_num + + +def merge_act_max(act, i, j): + for k in range(len(act)): + act[k, i] = max(act[k, i], act[k, j]) + act[k, j] = 0.0 + return act + + +def merge_acti_clslab(args, acti, clslab, cls_num): + sil_lab = cls_num + for i in range(len(clslab)): + _lab = clslab[i].reshape(-1, 1) + distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64) + for j in range(len(distM)): + distM[j][:j] = -1 + idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1) + merge_done = [] + for j in idx_lst[0]: + for k in (np.where(distM[j] == 0))[0]: + if j != k and clslab[i, j] != sil_lab and k not in merge_done: + print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k)) + acti[i] = merge_act_max(acti[i], j, k) + clslab[i, k] = sil_lab + merge_done.append(j) + + return acti, clslab + + +def stitching(args, acti, clslab, cls_num): + n_chunks = len(acti) + s_loc = args.num_speakers + sil_lab = cls_num + s_tot = max(cls_num, s_loc-1) + + # Extend the max value of s_loc_idx to s_tot+1 + add_acti = [] + for chunk_idx in range(n_chunks): + zeros = np.zeros((len(acti[chunk_idx]), s_tot+1)) + if s_tot+1 > s_loc: + zeros[:, :-(s_tot+1-s_loc)] = acti[chunk_idx] + else: + zeros = acti[chunk_idx] + add_acti.append(zeros) + acti = np.array(add_acti) + + out_chunks = [] + for chunk_idx in 
range(n_chunks): + # Make sloci2lab_dct. + # key: s_loc_idx + # value: estimated label by clustering or sil_lab + cls_set = set() + for s_loc_idx in range(s_tot+1): + cls_set.add(s_loc_idx) + + sloci2lab_dct = {} + for s_loc_idx in range(s_tot+1): + if s_loc_idx < s_loc: + sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx] + if clslab[chunk_idx][s_loc_idx] in cls_set: + cls_set.remove(clslab[chunk_idx][s_loc_idx]) + else: + if clslab[chunk_idx][s_loc_idx] != sil_lab: + raise ValueError + else: + sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx-s_loc] + + # Sort by label value + sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1]) + + # Select sil_lab_idx + sil_lab_idx = None + for idx_lab in sloci2lab_lst: + if idx_lab[1] == sil_lab: + sil_lab_idx = idx_lab[0] + break + if sil_lab_idx is None: + raise ValueError + + # Get swap_idx + # [idx of label(0), idx of label(1), ..., idx of label(s_tot)] + swap_idx = [sil_lab_idx for j in range(s_tot+1)] + for lab in range(s_tot+1): + for idx_lab in sloci2lab_lst: + if lab == idx_lab[1]: + swap_idx[lab] = idx_lab[0] + + print("swap_idx {}".format(swap_idx)) + swap_acti = acti[chunk_idx][:, swap_idx] + swap_acti = np.delete(swap_acti, sil_lab, 1) + out_chunks.append(swap_acti) + + return out_chunks + + +def prediction(num_speakers, net, wav_list, chunk_len_list): + acti_lst = [] + svec_lst = [] + len_list = [] + + with torch.no_grad(): + for wav, chunk_len in zip(wav_list, chunk_len_list): + wav = wav.to('cuda') + outputs = net.batch_estimate(torch.unsqueeze(wav, 0)) + ys = outputs[0] + + for i in range(num_speakers): + spkivecs = outputs[i+1] + svec_lst.append(spkivecs[0].cpu().detach().numpy()) + + acti = ys[0][-chunk_len:].cpu().detach().numpy() + acti_lst.append(acti) + len_list.append(chunk_len) + + acti_arr = np.concatenate(acti_lst, axis=0) # totol_len x num_speakers + svec_arr = np.stack(svec_lst) # (chunk_num x num_speakers) x emb_dim + len_arr = np.array(len_list) # chunk_num + + return acti_arr, svec_arr, len_arr + +def cluster(args, conf, acti_arr, svec_arr, len_arr): + + acti_list = [] + n_chunks = len_arr.shape[0] + start = 0 + for i in range(n_chunks): + chunk_len = len_arr[i] + acti_list.append(acti_arr[start: start+chunk_len]) + start += chunk_len + acti = np.array(acti_list) + svec = svec_arr + + # initialize clustering setting + cls_num = None + ahc_dis_th = args.ahc_dis_th + # Get cannot-link index list and silence index list + cl_lst, sil_lst = get_cl_sil(args, acti, cls_num) + + n_samples = n_chunks * args.num_speakers - len(sil_lst) + min_n_samples = 2 + if cls_num is not None: + min_n_samples = cls_num + + if n_samples >= min_n_samples: + # clustering (if cls_num is None, update cls_num) + clslab, cls_num =\ + clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst) + # merge + acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num) + # stitching + out_chunks = stitching(args, acti, clslab, cls_num) + else: + out_chunks = acti + + outdata = np.vstack(out_chunks) + # Saving the resuts + return outdata + +def make_rttm(args, conf, cluster_data): + args.frame_shift = conf['model']['frame_shift'] + args.subsampling = conf['model']['subsampling'] + args.sampling_rate = conf['dataset']['sampling_rate'] + + with open(args.out_rttm_file, 'w') as wf: + a = np.where(cluster_data > args.threshold, 1, 0) + if args.median > 1: + a = medfilt(a, (args.median, 1)) + for spkid, frames in enumerate(a.T): + frames = np.pad(frames, (1, 1), 'constant') + changes, = np.where(np.diff(frames, axis=0) != 0) + fmt = 
"SPEAKER {:s} 1 {:7.2f} {:7.2f} {:s} " + for s, e in zip(changes[::2], changes[1::2]): + print(fmt.format( + args.session, + s * args.frame_shift * args.subsampling / args.sampling_rate, + (e - s) * args.frame_shift * args.subsampling / args.sampling_rate, + args.session + "_" + str(spkid)), file=wf) + +def main(args): + conf = parse_config_or_kwargs(args.config_path) + num_speakers = conf['dataset']['num_speakers'] + args.num_speakers = num_speakers + + # Prepare model + model_parameter_dict = torch.load(args.model_init)['model'] + model_all_n_speakers = model_parameter_dict["embed.weight"].shape[0] + conf['model']['all_n_speakers'] = model_all_n_speakers + net = TransformerDiarization(**conf['model']) + net.load_state_dict(model_parameter_dict, strict=False) + net.eval() + net = net.to("cuda") + + audio, sr = sf.read(args.wav_path, dtype="float32") + audio_len = audio.shape[0] + chunk_size, frame_shift, subsampling = conf['dataset']['chunk_size'], conf['model']['frame_shift'], conf['model']['subsampling'] + scale_ratio = int(frame_shift * subsampling) + chunk_audio_size = chunk_size * scale_ratio + wav_list, chunk_len_list = [], [] + for i in range(0, math.ceil(1.0 * audio_len / chunk_audio_size)): + start, end = i*chunk_audio_size, (i+1)*chunk_audio_size + if end > audio_len: + chunk_len_list.append(int((audio_len-start) / scale_ratio)) + end = audio_len + start = max(0, audio_len - chunk_audio_size) + else: + chunk_len_list.append(chunk_size) + wav_list.append(audio[start:end]) + wav_list = [torch.from_numpy(wav).float() for wav in wav_list] + + acti_arr, svec_arr, len_arr = prediction(num_speakers, net, wav_list, chunk_len_list) + cluster_data = cluster(args, conf, acti_arr, svec_arr, len_arr) + make_rttm(args, conf, cluster_data) + +if __name__ == '__main__': + parser = yamlargparse.ArgumentParser(description='decoding') + parser.add_argument('--wav_path', + help='the input wav path', + default="tmp/mix_0000496.wav") + parser.add_argument('--config_path', + help='config file path', + default="config/infer_est_nspk1.yaml") + parser.add_argument('--model_init', + help='model initialize path', + default="") + parser.add_argument('--sil_spk_th', default=0.05, type=float) + parser.add_argument('--ahc_dis_th', default=1.0, type=float) + parser.add_argument('--clink_dis', default=1.0e+4, type=float) + parser.add_argument('--session', default='Anonymous', help='the name of the output speaker') + parser.add_argument('--out_rttm_file', default='out.rttm', help='the output rttm file') + parser.add_argument('--threshold', default=0.4, type=float) + parser.add_argument('--median', default=25, type=int) + + + args = parser.parse_args() + main(args) diff --git a/UniSpeech-SAT/speaker_diarization/models/models.py b/UniSpeech-SAT/speaker_diarization/models/models.py new file mode 100644 index 0000000..bdad03c --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/models/models.py @@ -0,0 +1,391 @@ +# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 
+# All rights reserved + +import sys +import numpy as np +import torch +import torch.nn.functional as F +import torch.nn as nn +import torchaudio.transforms as trans +from collections import OrderedDict +from itertools import permutations +from models.transformer import TransformerEncoder +from .utils import UpstreamExpert + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + + +""" +P: number of permutation +T: number of frames +C: number of speakers (classes) +B: mini-batch size +""" + + +def batch_pit_loss_parallel(outputs, labels, ilens=None): + """ calculate the batch pit loss parallelly + Args: + outputs (torch.Tensor): B x T x C + labels (torch.Tensor): B x T x C + ilens (torch.Tensor): B + Returns: + perm (torch.Tensor): permutation for outputs (Batch, num_spk) + loss + """ + + if ilens is None: + mask, scale = 1.0, outputs.shape[1] + else: + scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device) + mask = outputs.new_zeros(outputs.size()[:-1]) + for i, chunk_len in enumerate(ilens): + mask[i, :chunk_len] += 1.0 + mask /= scale + + def loss_func(output, label): + # return torch.mean(F.binary_cross_entropy_with_logits(output, label, reduction='none'), dim=tuple(range(1, output.dim()))) + return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1) + + def pair_loss(outputs, labels, permutation): + return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation) + + device = outputs.device + num_spk = outputs.shape[-1] + all_permutations = list(permutations(range(num_spk))) + losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1) + loss, perm = torch.min(losses, dim=1) + perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm) + return torch.mean(loss), perm + + +def fix_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('module.'): + # remove 'module.' of DataParallel + k = k[7:] + if k.startswith('net.'): + # remove 'net.' 
of PadertorchModel + k = k[4:] + new_state_dict[k] = v + return new_state_dict + + +class TransformerDiarization(nn.Module): + def __init__(self, + n_speakers, + all_n_speakers, + feat_dim, + n_units, + n_heads, + n_layers, + dropout_rate, + spk_emb_dim, + sr=8000, + frame_shift=256, + frame_size=1024, + context_size=0, + subsampling=1, + feat_type='fbank', + feature_selection='default', + interpolate_mode='linear', + update_extract=False, + feature_grad_mult=1.0 + ): + super(TransformerDiarization, self).__init__() + self.context_size = context_size + self.subsampling = subsampling + self.feat_type = feat_type + self.feature_selection = feature_selection + self.sr = sr + self.frame_shift = frame_shift + self.interpolate_mode = interpolate_mode + self.update_extract = update_extract + self.feature_grad_mult = feature_grad_mult + + if feat_type == 'fbank': + self.feature_extract = trans.MelSpectrogram(sample_rate=sr, + n_fft=frame_size, + win_length=frame_size, + hop_length=frame_shift, + f_min=0.0, + f_max=sr // 2, + pad=0, + n_mels=feat_dim) + else: + self.feature_extract = UpstreamExpert(feat_type) + # self.feature_extract = torch.hub.load('s3prl/s3prl', 'hubert_local', ckpt=feat_type) + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"): + self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"): + self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False + self.feat_num = self.get_feat_num() + self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) + # for param in self.feature_extract.parameters(): + # param.requires_grad = False + self.resample = trans.Resample(orig_freq=sr, new_freq=16000) + + if feat_type != 'fbank' and feat_type != 'mfcc': + freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer', 'spk_proj', 'layer_norm_for_extract'] + for name, param in self.feature_extract.named_parameters(): + for freeze_val in freeze_list: + if freeze_val in name: + param.requires_grad = False + break + if not self.update_extract: + for param in self.feature_extract.parameters(): + param.requires_grad = False + + self.instance_norm = nn.InstanceNorm1d(feat_dim) + + feat_dim = feat_dim * (self.context_size*2 + 1) + self.enc = TransformerEncoder( + feat_dim, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate) + self.linear = nn.Linear(n_units, n_speakers) + + for i in range(n_speakers): + setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, spk_emb_dim)) + + self.n_speakers = n_speakers + self.embed = nn.Embedding(all_n_speakers, spk_emb_dim) + self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0]) + self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0]) + + def get_feat_num(self): + self.feature_extract.eval() + wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] + with torch.no_grad(): + features = self.feature_extract(wav) + select_feature = features[self.feature_selection] + if isinstance(select_feature, (list, tuple)): + return len(select_feature) + else: + return 1 + + def fix_except_embedding(self, requires_grad=False): + for name, param in self.named_parameters(): + if 'embed' not in name: + param.requires_grad = requires_grad + + def modfy_emb(self, weight): + self.embed = 
nn.Embedding.from_pretrained(weight) + + def splice(self, data, context_size): + # data: B x feat_dim x time_len + data = torch.unsqueeze(data, -1) + kernel_size = context_size*2 + 1 + splice_data = F.unfold(data, kernel_size=(kernel_size, 1), padding=(context_size, 0)) + return splice_data + + def get_feat(self, xs): + wav_len = xs.shape[-1] + chunk_size = int(wav_len / self.frame_shift) + chunk_size = int(chunk_size / self.subsampling) + + self.feature_extract.eval() + if self.update_extract: + xs = self.resample(xs) + feature = self.feature_extract([sample for sample in xs]) + else: + with torch.no_grad(): + if self.feat_type == 'fbank': + feature = self.feature_extract(xs) + 1e-6 # B x feat_dim x time_len + feature = feature.log() + else: + xs = self.resample(xs) + feature = self.feature_extract([sample for sample in xs]) + + if self.feat_type != "fbank" and self.feat_type != "mfcc": + feature = feature[self.feature_selection] + if isinstance(feature, (list, tuple)): + feature = torch.stack(feature, dim=0) + else: + feature = feature.unsqueeze(0) + + norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + feature = (norm_weights * feature).sum(dim=0) + feature = torch.transpose(feature, 1, 2) + 1e-6 + + feature = self.instance_norm(feature) + feature = self.splice(feature, self.context_size) + feature = feature[:, :, ::self.subsampling] + feature = F.interpolate(feature, chunk_size, mode=self.interpolate_mode) + feature = torch.transpose(feature, 1, 2) + + if self.feature_grad_mult != 1.0: + feature = GradMultiply.apply(feature, self.feature_grad_mult) + + return feature + + def forward(self, inputs): + if isinstance(inputs, list): + xs = inputs[0] + else: + xs = inputs + feature = self.get_feat(xs) + + pad_shape = feature.shape + emb = self.enc(feature) + ys = self.linear(emb) + ys = ys.reshape(pad_shape[0], pad_shape[1], -1) + + spksvecs = [] + for i in range(self.n_speakers): + spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb) + spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1) + spksvecs.append(spkivecs) + + return ys, spksvecs + + def get_loss(self, inputs, ys, spksvecs, cal_spk_loss=True): + ts = inputs[1] + ss = inputs[2] + ns = inputs[3] + ilens = inputs[4] + ilens = [ilen.item() for ilen in ilens] + + pit_loss, sigmas = batch_pit_loss_parallel(ys, ts, ilens) + if cal_spk_loss: + spk_loss = self.spk_loss_parallel(spksvecs, ys, ts, ss, sigmas, ns, ilens) + else: + spk_loss = torch.tensor(0.0).to(pit_loss.device) + + alpha = torch.clamp(self.alpha, min=sys.float_info.epsilon) + + return {'spk_loss':spk_loss, + 'pit_loss': pit_loss} + + + def batch_estimate(self, xs): + out = self(xs) + ys = out[0] + spksvecs = out[1] + spksvecs = list(zip(*spksvecs)) + outputs = [ + self.estimate(spksvec, y) + for (spksvec, y) in zip(spksvecs, ys)] + outputs = list(zip(*outputs)) + + return outputs + + def batch_estimate_with_perm(self, xs, ts, ilens=None): + out = self(xs) + ys = out[0] + if ts[0].shape[1] > ys[0].shape[1]: + # e.g. 
the case of training 3-spk model with 4-spk data + add_dim = ts[0].shape[1] - ys[0].shape[1] + y_device = ys[0].device + zeros = [torch.zeros(ts[0].shape).to(y_device) + for i in range(len(ts))] + _ys = [] + for zero, y in zip(zeros, ys): + _zero = zero + _zero[:, :-add_dim] = y + _ys.append(_zero) + _, sigmas = batch_pit_loss_parallel(_ys, ts, ilens) + else: + _, sigmas = batch_pit_loss_parallel(ys, ts, ilens) + spksvecs = out[1] + spksvecs = list(zip(*spksvecs)) + outputs = [self.estimate(spksvec, y) + for (spksvec, y) in zip(spksvecs, ys)] + outputs = list(zip(*outputs)) + zs = outputs[0] + + if ts[0].shape[1] > ys[0].shape[1]: + # e.g. the case of training 3-spk model with 4-spk data + add_dim = ts[0].shape[1] - ys[0].shape[1] + z_device = zs[0].device + zeros = [torch.zeros(ts[0].shape).to(z_device) + for i in range(len(ts))] + _zs = [] + for zero, z in zip(zeros, zs): + _zero = zero + _zero[:, :-add_dim] = z + _zs.append(_zero) + zs = _zs + outputs[0] = zs + outputs.append(sigmas) + + # outputs: [zs, nmz_wavg_spk0vecs, nmz_wavg_spk1vecs, ..., sigmas] + return outputs + + def estimate(self, spksvec, y): + outputs = [] + z = torch.sigmoid(y.transpose(1, 0)) + + outputs.append(z.transpose(1, 0)) + for spkid, spkvec in enumerate(spksvec): + norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1) + # Normalize speaker vectors before weighted average + spkvec = torch.mul( + spkvec.transpose(1, 0), norm_spkvec_inv + ).transpose(1, 0) + wavg_spkvec = torch.mul( + spkvec.transpose(1, 0), z[spkid] + ).transpose(1, 0) + sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0) + nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec) + outputs.append(nmz_wavg_spkvec) + + # outputs: [z, nmz_wavg_spk0vec, nmz_wavg_spk1vec, ...] + return outputs + + def spk_loss_parallel(self, spksvecs, ys, ts, ss, sigmas, ns, ilens): + ''' + spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...] 
+ ys (torch.Tensor): B x T x 3 + ts (torch.Tensor): B x T x 3 + ss (torch.Tensor): B x 3 + sigmas (torch.Tensor): B x 3 + ns (torch.Tensor): B x total_spk_num x 1 + ilens (List): B + ''' + chunk_spk_num = len(spksvecs) # 3 + + len_mask = ys.new_zeros((ys.size()[:-1])) # B x T + for i, len_val in enumerate(ilens): + len_mask[i,:len_val] += 1.0 + ts = ts * len_mask.unsqueeze(-1) + len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T + + spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim + # Normalize speaker vectors before weighted average + spk_vecs = F.normalize(spk_vecs, dim=-1) + + ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T + ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1 + + weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim + weight_spk_vec *= len_mask.unsqueeze(-1) + sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim + norm_spk_vec = F.normalize(sum_spk_vec, dim=1) + + embeds = F.normalize(self.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim + dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num + logits = -1.0 * torch.add(torch.clamp(self.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.beta) + label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3 + label[label==-1] = 0 + valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B + valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3 + + valid_spk_loss_num = torch.sum(valid_spk_mask).item() + if valid_spk_loss_num > 0: + loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num + # uncomment the line below, the loss result is same as batch_spk_loss + # loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0] + return torch.sum(loss) + else: + return torch.tensor(0.0).to(ys.device) diff --git a/UniSpeech-SAT/speaker_diarization/models/transformer.py b/UniSpeech-SAT/speaker_diarization/models/transformer.py new file mode 100644 index 0000000..d2deb9b --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/models/transformer.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT). 
+# All rights reserved + +import numpy as np +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.optim.lr_scheduler import _LRScheduler + + +class NoamScheduler(_LRScheduler): + """ learning rate scheduler used in the transformer + See https://arxiv.org/pdf/1706.03762.pdf + lrate = d_model**(-0.5) * \ + min(step_num**(-0.5), step_num*warmup_steps**(-1.5)) + Scaling factor is implemented as in + http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer + """ + + def __init__( + self, optimizer, d_model, warmup_steps, tot_step, scale, + last_epoch=-1 + ): + self.d_model = d_model + self.warmup_steps = warmup_steps + self.tot_step = tot_step + self.scale = scale + super(NoamScheduler, self).__init__(optimizer, last_epoch) + + def get_lr(self): + self.last_epoch = max(1, self.last_epoch) + step_num = self.last_epoch + val = self.scale * self.d_model ** (-0.5) * \ + min(step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5)) + + return [base_lr / base_lr * val for base_lr in self.base_lrs] + + +class MultiHeadSelfAttention(nn.Module): + """ Multi head "self" attention layer + """ + + def __init__(self, n_units, h=8, dropout_rate=0.1): + super(MultiHeadSelfAttention, self).__init__() + self.linearQ = nn.Linear(n_units, n_units) + self.linearK = nn.Linear(n_units, n_units) + self.linearV = nn.Linear(n_units, n_units) + self.linearO = nn.Linear(n_units, n_units) + self.d_k = n_units // h + self.h = h + self.dropout = nn.Dropout(p=dropout_rate) + # attention for plot + self.att = None + + def forward(self, x, batch_size): + # x: (BT, F) + q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k) + k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k) + v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k) + + scores = torch.matmul( + q.transpose(1, 2), k.permute(0, 2, 3, 1)) / np.sqrt(self.d_k) + # scores: (B, h, T, T) = (B, h, T, d_k) x (B, h, d_k, T) + self.att = F.softmax(scores, dim=3) + p_att = self.dropout(self.att) + x = torch.matmul(p_att, v.transpose(1, 2)) + x = x.transpose(1, 2).reshape(-1, self.h * self.d_k) + + return self.linearO(x) + + +class PositionwiseFeedForward(nn.Module): + """ Positionwise feed-forward layer + """ + + def __init__(self, n_units, d_units, dropout_rate): + super(PositionwiseFeedForward, self).__init__() + self.linear1 = nn.Linear(n_units, d_units) + self.linear2 = nn.Linear(d_units, n_units) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x): + return self.linear2(self.dropout(F.relu(self.linear1(x)))) + + +class PositionalEncoding(nn.Module): + """ Positional encoding function + """ + + def __init__(self, n_units, dropout_rate, max_len): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout_rate) + positions = np.arange(0, max_len, dtype='f')[:, None] + dens = np.exp( + np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) 
/ n_units)) + self.enc = np.zeros((max_len, n_units), dtype='f') + self.enc[:, ::2] = np.sin(positions * dens) + self.enc[:, 1::2] = np.cos(positions * dens) + self.scale = np.sqrt(n_units) + + def forward(self, x): + x = x * self.scale + self.xp.array(self.enc[:, :x.shape[1]]) + return self.dropout(x) + + +class TransformerEncoder(nn.Module): + def __init__(self, idim, n_layers, n_units, + e_units=2048, h=8, dropout_rate=0.1): + super(TransformerEncoder, self).__init__() + self.linear_in = nn.Linear(idim, n_units) + # self.lnorm_in = nn.LayerNorm(n_units) + self.pos_enc = PositionalEncoding(n_units, dropout_rate, 5000) + self.n_layers = n_layers + self.dropout = nn.Dropout(p=dropout_rate) + for i in range(n_layers): + setattr(self, '{}{:d}'.format("lnorm1_", i), + nn.LayerNorm(n_units)) + setattr(self, '{}{:d}'.format("self_att_", i), + MultiHeadSelfAttention(n_units, h, dropout_rate)) + setattr(self, '{}{:d}'.format("lnorm2_", i), + nn.LayerNorm(n_units)) + setattr(self, '{}{:d}'.format("ff_", i), + PositionwiseFeedForward(n_units, e_units, dropout_rate)) + self.lnorm_out = nn.LayerNorm(n_units) + + def forward(self, x): + # x: (B, T, F) ... batch, time, (mel)freq + BT_size = x.shape[0] * x.shape[1] + # e: (BT, F) + e = self.linear_in(x.reshape(BT_size, -1)) + # Encoder stack + for i in range(self.n_layers): + # layer normalization + e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e) + # self-attention + s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0]) + # residual + e = e + self.dropout(s) + # layer normalization + e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e) + # positionwise feed-forward + s = getattr(self, '{}{:d}'.format("ff_", i))(e) + # residual + e = e + self.dropout(s) + # final layer normalization + # output: (BT, F) + return self.lnorm_out(e) diff --git a/UniSpeech-SAT/speaker_diarization/models/utils.py b/UniSpeech-SAT/speaker_diarization/models/utils.py new file mode 100644 index 0000000..041f9fd --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/models/utils.py @@ -0,0 +1,78 @@ +import torch +import fairseq +from packaging import version +import torch.nn.functional as F +from fairseq import tasks +from fairseq.checkpoint_utils import load_checkpoint_to_cpu +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from omegaconf import OmegaConf +from s3prl.upstream.interfaces import UpstreamBase +from torch.nn.utils.rnn import pad_sequence + +def load_model(filepath): + state = torch.load(filepath, map_location=lambda storage, loc: storage) + # state = load_checkpoint_to_cpu(filepath) + state["cfg"] = OmegaConf.create(state["cfg"]) + + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) + + task = tasks.setup_task(cfg.task) + if "task_state" in state: + task.load_state_dict(state["task_state"]) + + model = task.build_model(cfg.model) + + return model, cfg, task + + +################### +# UPSTREAM EXPERT # +################### +class UpstreamExpert(UpstreamBase): + def __init__(self, ckpt, **kwargs): + super().__init__(**kwargs) + assert version.parse(fairseq.__version__) > version.parse( + "0.10.2" + ), "Please install the fairseq master branch." 
+ + model, cfg, task = load_model(ckpt) + self.model = model + self.task = task + + if len(self.hooks) == 0: + module_name = "self.model.encoder.layers" + for module_id in range(len(eval(module_name))): + self.add_hook( + f"{module_name}[{module_id}]", + lambda input, output: input[0].transpose(0, 1), + ) + self.add_hook("self.model.encoder", lambda input, output: output[0]) + + def forward(self, wavs): + if self.task.cfg.normalize: + wavs = [F.layer_norm(wav, wav.shape) for wav in wavs] + + device = wavs[0].device + wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device) + wav_padding_mask = ~torch.lt( + torch.arange(max(wav_lengths)).unsqueeze(0).to(device), + wav_lengths.unsqueeze(1), + ) + padded_wav = pad_sequence(wavs, batch_first=True) + + features, feat_padding_mask = self.model.extract_features( + padded_wav, + padding_mask=wav_padding_mask, + mask=None, + ) + return { + "default": features, + } + diff --git a/UniSpeech-SAT/speaker_diarization/requirements.txt b/UniSpeech-SAT/speaker_diarization/requirements.txt new file mode 100644 index 0000000..07b7632 --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/requirements.txt @@ -0,0 +1,11 @@ +soundfile +fire +sentencepiece +tqdm +pyyaml +h5py +yamlargparse +sklearn +matplotlib +torchaudio +s3rpl diff --git a/UniSpeech-SAT/speaker_diarization/tmp/mix_0000496.wav b/UniSpeech-SAT/speaker_diarization/tmp/mix_0000496.wav new file mode 120000 index 0000000..86792f6 --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/tmp/mix_0000496.wav @@ -0,0 +1 @@ +/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/wav/dev_clean_2_ns3_beta2_500/100/mix_0000496.wav \ No newline at end of file diff --git a/UniSpeech-SAT/speaker_diarization/utils/dataset.py b/UniSpeech-SAT/speaker_diarization/utils/dataset.py new file mode 100644 index 0000000..6c48c48 --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/utils/dataset.py @@ -0,0 +1,484 @@ +# -*- coding: utf-8 -*- # +"""*********************************************************************************************""" +# FileName [ dataset.py ] +# Synopsis [ the speaker diarization dataset ] +# Source [ Refactored from https://github.com/hitachi-speech/EEND ] +# Author [ Jiatong Shi ] +# Copyright [ Copyright(c), Johns Hopkins University ] +"""*********************************************************************************************""" + + +############### +# IMPORTATION # +############### +import io +import os +import subprocess +import sys + +# -------------# +import numpy as np +import soundfile as sf +import torch +from torch.nn.utils.rnn import pad_sequence + +# -------------# +from torch.utils.data.dataset import Dataset +# -------------# + + +def _count_frames(data_len, size, step): + # no padding at edges, last remaining samples are ignored + return int((data_len - size + step) / step) + + +def _gen_frame_indices(data_length, size=2000, step=2000): + i = -1 + for i in range(_count_frames(data_length, size, step)): + yield i * step, i * step + size + + if i * step + size < data_length: + if data_length - (i + 1) * step > 0: + if i == -1: + yield (i + 1) * step, data_length + else: + yield data_length - size, data_length + + +def _gen_chunk_indices(data_len, chunk_size): + step = chunk_size + start = 0 + while start < data_len: + end = min(data_len, start + chunk_size) + yield start, end + start += step + + +####################### +# Diarization Dataset # +####################### +class DiarizationDataset(Dataset): 
+ def __init__( + self, + mode, + data_dir, + chunk_size=2000, + frame_shift=256, + sampling_rate=16000, + subsampling=1, + use_last_samples=True, + num_speakers=3, + filter_spk=False + ): + super(DiarizationDataset, self).__init__() + + self.mode = mode + self.data_dir = data_dir + self.chunk_size = chunk_size + self.frame_shift = frame_shift + self.subsampling = subsampling + self.n_speakers = num_speakers + self.chunk_indices = [] if mode != "test" else {} + + self.data = KaldiData(self.data_dir) + self.all_speakers = sorted(self.data.spk2utt.keys()) + self.all_n_speakers = len(self.all_speakers) + + # make chunk indices: filepath, start_frame, end_frame + for rec in self.data.wavs: + data_len = int(self.data.reco2dur[rec] * sampling_rate / frame_shift) + data_len = int(data_len / self.subsampling) + if mode == "test": + self.chunk_indices[rec] = [] + if mode != "test": + for st, ed in _gen_frame_indices(data_len, chunk_size, chunk_size): + self.chunk_indices.append( + (rec, st * self.subsampling, ed * self.subsampling) + ) + else: + for st, ed in _gen_chunk_indices(data_len, chunk_size): + self.chunk_indices[rec].append( + (rec, st, ed) + ) + + if mode != "test": + if filter_spk: + self.filter_spk() + print(len(self.chunk_indices), " chunks") + else: + self.rec_list = list(self.chunk_indices.keys()) + print(len(self.rec_list), " recordings") + + def __len__(self): + return ( + len(self.rec_list) + if type(self.chunk_indices) == dict + else len(self.chunk_indices) + ) + + def filter_spk(self): + # filter the spk in spk2utt but will not be used in training + # i.e. the chunks contains more spk than self.n_speakers + occur_spk_set = set() + + new_chunk_indices = [] # filter the chunk that more than self.num_speakers + for idx in range(self.__len__()): + rec, st, ed = self.chunk_indices[idx] + + filtered_segments = self.data.segments[rec] + # all the speakers in this recording not the chunk + speakers = np.unique( + [self.data.utt2spk[seg['utt']] for seg in filtered_segments] + ).tolist() + n_speakers = self.n_speakers + # we assume that in each chunk the speaker number is less or equal than self.n_speakers + # but the speaker number in the whole recording may exceed self.n_speakers + if self.n_speakers < len(speakers): + n_speakers = len(speakers) + + # Y: (length,), T: (frame_num, n_speakers) + Y, T = self._get_labeled_speech(rec, st, ed, n_speakers) + # the spk index exist in this chunk data + exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index + chunk_spk_num = np.sum(exist_spk_idx) + if chunk_spk_num <= self.n_speakers: + spk_arr = np.array(speakers) + valid_spk_arr = spk_arr[exist_spk_idx[:spk_arr.shape[0]]] + for spk in valid_spk_arr: + occur_spk_set.add(spk) + + new_chunk_indices.append((rec, st, ed)) + self.chunk_indices = new_chunk_indices + self.all_speakers = sorted(list(occur_spk_set)) + self.all_n_speakers = len(self.all_speakers) + + def __getitem__(self, i): + if self.mode != "test": + rec, st, ed = self.chunk_indices[i] + + filtered_segments = self.data.segments[rec] + # all the speakers in this recording not the chunk + speakers = np.unique( + [self.data.utt2spk[seg['utt']] for seg in filtered_segments] + ).tolist() + n_speakers = self.n_speakers + # we assume that in each chunk the speaker number is less or equal than self.n_speakers + # but the speaker number in the whole recording may exceed self.n_speakers + if self.n_speakers < len(speakers): + n_speakers = len(speakers) + + # Y: (length,), T: (frame_num, n_speakers) + Y, T = self._get_labeled_speech(rec, st, ed, 
n_speakers) + # the spk index exist in this chunk data + exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index + chunk_spk_num = np.sum(exist_spk_idx) + if chunk_spk_num > self.n_speakers: + # the speaker number in a chunk exceed our pre-set value + return None, None, None + + # the map from within recording speaker index to global speaker index + S_arr = -1 * np.ones(n_speakers).astype(np.int64) + for seg in filtered_segments: + speaker_index = speakers.index(self.data.utt2spk[seg['utt']]) + try: + all_speaker_index = self.all_speakers.index( + self.data.utt2spk[seg['utt']]) + except: + # we have pre-filter some spk in self.filter_spk + all_speaker_index = -1 + S_arr[speaker_index] = all_speaker_index + # If T[:, n_speakers - 1] == 0.0, then S_arr[n_speakers - 1] == -1, + # so S_arr[n_speakers - 1] is not used for training, + # e.g., in the case of training 3-spk model with 2-spk data + + # filter the speaker not exist in this chunk and ensure there are self.num_speakers outputs + T_exist = T[:,exist_spk_idx] + T = np.zeros((T_exist.shape[0], self.n_speakers), dtype=np.int32) + T[:,:T_exist.shape[1]] = T_exist + # subsampling for Y will be done in the model forward function + T = T[::self.subsampling] + + S_arr_exist = S_arr[exist_spk_idx] + S_arr = -1 * np.ones(self.n_speakers).astype(np.int64) + S_arr[:S_arr_exist.shape[0]] = S_arr_exist + + n = np.arange(self.all_n_speakers, dtype=np.int64).reshape(self.all_n_speakers, 1) + return Y, T, S_arr, n, T.shape[0] + else: + len_ratio = self.frame_shift * self.subsampling + chunks = self.chunk_indices[self.rec_list[i]] + Ys = [] + chunk_len_list = [] + for (rec, st, ed) in chunks: + chunk_len = ed - st + if chunk_len != self.chunk_size: + st = max(0, ed - self.chunk_size) + Y, _ = self.data.load_wav(rec, st * len_ratio, ed * len_ratio) + Ys.append(Y) + chunk_len_list.append(chunk_len) + return Ys, self.rec_list[i], chunk_len_list + + def get_allnspk(self): + return self.all_n_speakers + + def _get_labeled_speech( + self, rec, start, end, n_speakers=None, use_speaker_id=False + ): + """Extracts speech chunks and corresponding labels + + Extracts speech chunks and corresponding diarization labels for + given recording id and start/end times + + Args: + rec (str): recording id + start (int): start frame index + end (int): end frame index + n_speakers (int): number of speakers + if None, the value is given from data + Returns: + data: speech chunk + (n_samples) + T: label + (n_frmaes, n_speakers)-shaped np.int32 array. 
+ """ + data, rate = self.data.load_wav( + rec, start * self.frame_shift, end * self.frame_shift + ) + frame_num = end - start + filtered_segments = self.data.segments[rec] + # filtered_segments = self.data.segments[self.data.segments['rec'] == rec] + speakers = np.unique( + [self.data.utt2spk[seg["utt"]] for seg in filtered_segments] + ).tolist() + if n_speakers is None: + n_speakers = len(speakers) + T = np.zeros((frame_num, n_speakers), dtype=np.int32) + + if use_speaker_id: + all_speakers = sorted(self.data.spk2utt.keys()) + S = np.zeros((frame_num, len(all_speakers)), dtype=np.int32) + + for seg in filtered_segments: + speaker_index = speakers.index(self.data.utt2spk[seg["utt"]]) + if use_speaker_id: + all_speaker_index = all_speakers.index(self.data.utt2spk[seg["utt"]]) + start_frame = np.rint(seg["st"] * rate / self.frame_shift).astype(int) + end_frame = np.rint(seg["et"] * rate / self.frame_shift).astype(int) + rel_start = rel_end = None + if start <= start_frame and start_frame < end: + rel_start = start_frame - start + if start < end_frame and end_frame <= end: + rel_end = end_frame - start + if rel_start is not None or rel_end is not None: + T[rel_start:rel_end, speaker_index] = 1 + if use_speaker_id: + S[rel_start:rel_end, all_speaker_index] = 1 + + if use_speaker_id: + return data, T, S + else: + return data, T + + def collate_fn(self, batch): + valid_samples = [sample for sample in batch if sample[0] is not None] + + wav_list, binary_label_list, spk_label_list= [], [], [] + all_spk_idx_list, len_list = [], [] + for sample in valid_samples: + wav_list.append(torch.from_numpy(sample[0]).float()) + binary_label_list.append(torch.from_numpy(sample[1]).long()) + spk_label_list.append(torch.from_numpy(sample[2]).long()) + all_spk_idx_list.append(torch.from_numpy(sample[3]).long()) + len_list.append(sample[4]) + wav_batch = pad_sequence(wav_list, batch_first=True, padding_value=0.0) + binary_label_batch = pad_sequence(binary_label_list, batch_first=True, padding_value=1).long() + spk_label_batch = torch.stack(spk_label_list) + all_spk_idx_batch = torch.stack(all_spk_idx_list) + len_batch = torch.LongTensor(len_list) + + return wav_batch, binary_label_batch.float(), spk_label_batch, all_spk_idx_batch, len_batch + + def collate_fn_infer(self, batch): + assert len(batch) == 1 # each batch should contain one recording + Ys, rec, chunk_len_list = batch[0] + wav_list = [torch.from_numpy(Y).float() for Y in Ys] + + return wav_list, rec, chunk_len_list + + +####################### +# Kaldi-style Dataset # +####################### +class KaldiData: + """This class holds data in kaldi-style directory.""" + + def __init__(self, data_dir): + """Load kaldi data directory.""" + self.data_dir = data_dir + self.segments = self._load_segments_rechash( + os.path.join(self.data_dir, "segments") + ) + self.utt2spk = self._load_utt2spk(os.path.join(self.data_dir, "utt2spk")) + self.wavs = self._load_wav_scp(os.path.join(self.data_dir, "wav.scp")) + self.reco2dur = self._load_reco2dur(os.path.join(self.data_dir, "reco2dur")) + self.spk2utt = self._load_spk2utt(os.path.join(self.data_dir, "spk2utt")) + + def load_wav(self, recid, start=0, end=None): + """Load wavfile given recid, start time and end time.""" + data, rate = self._load_wav(self.wavs[recid], start, end) + return data, rate + + def _load_segments(self, segments_file): + """Load segments file as array.""" + if not os.path.exists(segments_file): + return None + return np.loadtxt( + segments_file, + dtype=[("utt", "object"), ("rec", "object"), 
("st", "f"), ("et", "f")], + ndmin=1, + ) + + def _load_segments_hash(self, segments_file): + """Load segments file as dict with uttid index.""" + ret = {} + if not os.path.exists(segments_file): + return None + for line in open(segments_file): + utt, rec, st, et = line.strip().split() + ret[utt] = (rec, float(st), float(et)) + return ret + + def _load_segments_rechash(self, segments_file): + """Load segments file as dict with recid index.""" + ret = {} + if not os.path.exists(segments_file): + return None + for line in open(segments_file): + utt, rec, st, et = line.strip().split() + if rec not in ret: + ret[rec] = [] + ret[rec].append({"utt": utt, "st": float(st), "et": float(et)}) + return ret + + def _load_wav_scp(self, wav_scp_file): + """Return dictionary { rec: wav_rxfilename }.""" + if os.path.exists(wav_scp_file): + lines = [line.strip().split(None, 1) for line in open(wav_scp_file)] + return {x[0]: x[1] for x in lines} + else: + wav_dir = os.path.join(self.data_dir, "wav") + return { + os.path.splitext(filename)[0]: os.path.join(wav_dir, filename) + for filename in sorted(os.listdir(wav_dir)) + } + + def _load_wav(self, wav_rxfilename, start=0, end=None): + """This function reads audio file and return data in numpy.float32 array. + "lru_cache" holds recently loaded audio so that can be called + many times on the same audio file. + OPTIMIZE: controls lru_cache size for random access, + considering memory size + """ + if wav_rxfilename.endswith("|"): + # input piped command + p = subprocess.Popen( + wav_rxfilename[:-1], + shell=True, + stdout=subprocess.PIPE, + ) + data, samplerate = sf.read( + io.BytesIO(p.stdout.read()), + dtype="float32", + ) + # cannot seek + data = data[start:end] + elif wav_rxfilename == "-": + # stdin + data, samplerate = sf.read(sys.stdin, dtype="float32") + # cannot seek + data = data[start:end] + else: + # normal wav file + data, samplerate = sf.read(wav_rxfilename, start=start, stop=end) + return data, samplerate + + def _load_utt2spk(self, utt2spk_file): + """Returns dictionary { uttid: spkid }.""" + lines = [line.strip().split(None, 1) for line in open(utt2spk_file)] + return {x[0]: x[1] for x in lines} + + def _load_spk2utt(self, spk2utt_file): + """Returns dictionary { spkid: list of uttids }.""" + if not os.path.exists(spk2utt_file): + return None + lines = [line.strip().split() for line in open(spk2utt_file)] + return {x[0]: x[1:] for x in lines} + + def _load_reco2dur(self, reco2dur_file): + """Returns dictionary { recid: duration }.""" + if not os.path.exists(reco2dur_file): + return None + lines = [line.strip().split(None, 1) for line in open(reco2dur_file)] + return {x[0]: float(x[1]) for x in lines} + + def _process_wav(self, wav_rxfilename, process): + """This function returns preprocessed wav_rxfilename. + Args: + wav_rxfilename: + input + process: + command which can be connected via pipe, use stdin and stdout + Returns: + wav_rxfilename: output piped command + """ + if wav_rxfilename.endswith("|"): + # input piped command + return wav_rxfilename + process + "|" + # stdin "-" or normal file + return "cat {0} | {1} |".format(wav_rxfilename, process) + + def _extract_segments(self, wavs, segments=None): + """This function returns generator of segmented audio. + Yields (utterance id, numpy.float32 array). + TODO?: sampling rate is not converted. 
+ """ + if segments is not None: + # segments should be sorted by rec-id + for seg in segments: + wav = wavs[seg["rec"]] + data, samplerate = self.load_wav(wav) + st_sample = np.rint(seg["st"] * samplerate).astype(int) + et_sample = np.rint(seg["et"] * samplerate).astype(int) + yield seg["utt"], data[st_sample:et_sample] + else: + # segments file not found, + # wav.scp is used as segmented audio list + for rec in wavs: + data, samplerate = self.load_wav(wavs[rec]) + yield rec, data + +if __name__ == "__main__": + args = { + 'mode': 'train', + 'data_dir': "/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/data/train_clean_5_ns3_beta2_500", + 'chunk_size': 2001, + 'frame_shift': 256, + 'sampling_rate': 8000, + 'num_speakers':3 + } + + torch.manual_seed(6) + dataset = DiarizationDataset(**args) + + from torch.utils.data import DataLoader + dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn) + data_iter = iter(dataloader) + # wav_batch, binary_label_batch, spk_label_batch, all_spk_idx_batch, len_batch = next(data_iter) + data = next(data_iter) + for val in data: + print(val.shape) + + # from torch.utils.data import DataLoader + # dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn_infer) + # data_iter = iter(dataloader) + # wav_list, binary_label_list, rec = next(data_iter) diff --git a/UniSpeech-SAT/speaker_diarization/utils/kaldi_data.py b/UniSpeech-SAT/speaker_diarization/utils/kaldi_data.py new file mode 100644 index 0000000..42f6d5e --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/utils/kaldi_data.py @@ -0,0 +1,162 @@ +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. +# +# This library provides utilities for kaldi-style data directory. + + +from __future__ import print_function +import os +import sys +import numpy as np +import subprocess +import soundfile as sf +import io +from functools import lru_cache + + +def load_segments(segments_file): + """ load segments file as array """ + if not os.path.exists(segments_file): + return None + return np.loadtxt( + segments_file, + dtype=[('utt', 'object'), + ('rec', 'object'), + ('st', 'f'), + ('et', 'f')], + ndmin=1) + + +def load_segments_hash(segments_file): + ret = {} + if not os.path.exists(segments_file): + return None + for line in open(segments_file): + utt, rec, st, et = line.strip().split() + ret[utt] = (rec, float(st), float(et)) + return ret + + +def load_segments_rechash(segments_file): + ret = {} + if not os.path.exists(segments_file): + return None + for line in open(segments_file): + utt, rec, st, et = line.strip().split() + if rec not in ret: + ret[rec] = [] + ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)}) + return ret + + +def load_wav_scp(wav_scp_file): + """ return dictionary { rec: wav_rxfilename } """ + lines = [line.strip().split(None, 1) for line in open(wav_scp_file)] + return {x[0]: x[1] for x in lines} + + +@lru_cache(maxsize=1) +def load_wav(wav_rxfilename, start=0, end=None): + """ This function reads audio file and return data in numpy.float32 array. + "lru_cache" holds recently loaded audio so that can be called + many times on the same audio file. 
+ OPTIMIZE: controls lru_cache size for random access, + considering memory size + """ + if wav_rxfilename.endswith('|'): + # input piped command + p = subprocess.Popen(wav_rxfilename[:-1], shell=True, + stdout=subprocess.PIPE) + data, samplerate = sf.read(io.BytesIO(p.stdout.read()), + dtype='float32') + # cannot seek + data = data[start:end] + elif wav_rxfilename == '-': + # stdin + data, samplerate = sf.read(sys.stdin, dtype='float32') + # cannot seek + data = data[start:end] + else: + # normal wav file + data, samplerate = sf.read(wav_rxfilename, start=start, stop=end) + return data, samplerate + + +def load_utt2spk(utt2spk_file): + """ returns dictionary { uttid: spkid } """ + lines = [line.strip().split(None, 1) for line in open(utt2spk_file)] + return {x[0]: x[1] for x in lines} + + +def load_spk2utt(spk2utt_file): + """ returns dictionary { spkid: list of uttids } """ + if not os.path.exists(spk2utt_file): + return None + lines = [line.strip().split() for line in open(spk2utt_file)] + return {x[0]: x[1:] for x in lines} + + +def load_reco2dur(reco2dur_file): + """ returns dictionary { recid: duration } """ + if not os.path.exists(reco2dur_file): + return None + lines = [line.strip().split(None, 1) for line in open(reco2dur_file)] + return {x[0]: float(x[1]) for x in lines} + + +def process_wav(wav_rxfilename, process): + """ This function returns preprocessed wav_rxfilename + Args: + wav_rxfilename: input + process: command which can be connected via pipe, + use stdin and stdout + Returns: + wav_rxfilename: output piped command + """ + if wav_rxfilename.endswith('|'): + # input piped command + return wav_rxfilename + process + "|" + else: + # stdin "-" or normal file + return "cat {} | {} |".format(wav_rxfilename, process) + + +def extract_segments(wavs, segments=None): + """ This function returns generator of segmented audio as + (utterance id, numpy.float32 array) + TODO?: sampling rate is not converted. + """ + if segments is not None: + # segments should be sorted by rec-id + for seg in segments: + wav = wavs[seg['rec']] + data, samplerate = load_wav(wav) + st_sample = np.rint(seg['st'] * samplerate).astype(int) + et_sample = np.rint(seg['et'] * samplerate).astype(int) + yield seg['utt'], data[st_sample:et_sample] + else: + # segments file not found, + # wav.scp is used as segmented audio list + for rec in wavs: + data, samplerate = load_wav(wavs[rec]) + yield rec, data + + +class KaldiData: + def __init__(self, data_dir): + self.data_dir = data_dir + self.segments = load_segments_rechash( + os.path.join(self.data_dir, 'segments')) + self.utt2spk = load_utt2spk( + os.path.join(self.data_dir, 'utt2spk')) + self.wavs = load_wav_scp( + os.path.join(self.data_dir, 'wav.scp')) + self.reco2dur = load_reco2dur( + os.path.join(self.data_dir, 'reco2dur')) + self.spk2utt = load_spk2utt( + os.path.join(self.data_dir, 'spk2utt')) + + def load_wav(self, recid, start=0, end=None): + data, rate = load_wav( + self.wavs[recid], start, end) + return data, rate diff --git a/UniSpeech-SAT/speaker_diarization/utils/parse_options.sh b/UniSpeech-SAT/speaker_diarization/utils/parse_options.sh new file mode 100644 index 0000000..71fb9e5 --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. 
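For orientation before the last utility module, here is a minimal, hypothetical usage sketch of the Kaldi-style I/O helpers added in `utils/kaldi_data.py` above. The data directory `data/eval`, its contents, and the import path are assumptions for illustration only; they are not part of this patch.

```python
# Hypothetical sketch: iterate a Kaldi-style data directory with KaldiData.
# Assumes it is run from UniSpeech-SAT/speaker_diarization and that data/eval
# contains wav.scp, utt2spk, spk2utt and (optionally) segments / reco2dur.
from utils.kaldi_data import KaldiData

kd = KaldiData("data/eval")
for rec in kd.wavs:
    data, rate = kd.load_wav(rec)              # whole recording as float32 samples
    print(rec, len(data) / rate, "seconds")
    # Per-utterance segments, if a segments file exists (kd.segments is None otherwise).
    for seg in (kd.segments or {}).get(rec, []):
        st, et = int(seg["st"] * rate), int(seg["et"] * rate)
        print("  ", seg["utt"], kd.utt2spk[seg["utt"]], data[st:et].shape)
```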
diff --git a/UniSpeech-SAT/speaker_diarization/utils/utils.py b/UniSpeech-SAT/speaker_diarization/utils/utils.py new file mode 100644 index 0000000..3eb0fa1 --- /dev/null +++ b/UniSpeech-SAT/speaker_diarization/utils/utils.py @@ -0,0 +1,189 @@ +import os +import struct +import logging +import torch +import math +import numpy as np +import random +import yaml +import torch.distributed as dist +import torch.nn.functional as F + + +# ------------------------------ Logger ------------------------------ +# log to console or a file +def get_logger( + name, + format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s", + date_format="%Y-%m-%d %H:%M:%S", + file=False): + """ + Get python logger instance + """ + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + # file or console + handler = logging.StreamHandler() if not file else logging.FileHandler( + name) + handler.setLevel(logging.INFO) + formatter = logging.Formatter(fmt=format_str, datefmt=date_format) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + + +# log to concole and file at the same time +def get_logger_2( + name, + format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s", + date_format="%Y-%m-%d %H:%M:%S"): + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + + # Create handlers + c_handler = logging.StreamHandler() + f_handler = logging.FileHandler(name) + c_handler.setLevel(logging.INFO) + f_handler.setLevel(logging.INFO) + + # Create formatters and add it to handlers + c_format = logging.Formatter(fmt=format_str, datefmt=date_format) + f_format = logging.Formatter(fmt=format_str, datefmt=date_format) + c_handler.setFormatter(c_format) + f_handler.setFormatter(f_format) + + # Add handlers to the logger + logger.addHandler(c_handler) + logger.addHandler(f_handler) + + return logger + + +# ------------------------------ Logger ------------------------------ + +# ------------------------------ Pytorch Distributed Training ------------------------------ +def getoneNode(): + nodelist = os.environ['SLURM_JOB_NODELIST'] + nodelist = nodelist.strip().split(',')[0] + import re + text = re.split('[-\[\]]', nodelist) + if ('' in text): + text.remove('') + return text[0] + '-' + text[1] + '-' + text[2] + + +def dist_init(host_addr, rank, local_rank, world_size, port=23456): + host_addr_full = 'tcp://' + host_addr + ':' + str(port) + dist.init_process_group("nccl", init_method=host_addr_full, + rank=rank, world_size=world_size) + num_gpus = torch.cuda.device_count() + # torch.cuda.set_device(local_rank) + assert dist.is_initialized() + + +def cleanup(): + dist.destroy_process_group() + + +def average_gradients(model, world_size): + size = float(world_size) + for param in model.parameters(): + if (param.requires_grad and param.grad is not None): + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data /= size + + +def data_reduce(data): + dist.all_reduce(data, op=dist.ReduceOp.SUM) + return data / torch.distributed.get_world_size() + + +# ------------------------------ Pytorch Distributed Training ------------------------------ + + +# ------------------------------ Hyper-parameter Dynamic Change ------------------------------ +def reduce_lr(optimizer, initial_lr, final_lr, current_iter, max_iter, coeff=1.0): + current_lr = coeff * math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr + for param_group in optimizer.param_groups: + param_group['lr'] = current_lr + + +def 
get_reduce_lr(initial_lr, final_lr, current_iter, max_iter):
+    current_lr = math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
+    return current_lr
+
+
+def set_lr(optimizer, lr):
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
+
+# ---------------------- About Configuration --------------------
+def parse_config_or_kwargs(config_file, **kwargs):
+    with open(config_file) as con_read:
+        yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
+    # passed kwargs will override yaml config
+    return dict(yaml_config, **kwargs)
+
+
+def store_yaml(config_file, store_path, **kwargs):
+    with open(config_file, 'r') as f:
+        config_lines = f.readlines()
+
+    keys_list = list(kwargs.keys())
+    with open(store_path, 'w') as f:
+        for line in config_lines:
+            if ':' in line and line.split(':')[0] in keys_list:
+                key = line.split(':')[0]
+                line = '{}: {}\n'.format(key, kwargs[key])
+            f.write(line)
+
+
+# ---------------------- About Configuration --------------------
+
+
+def check_dir(dir):
+    if not os.path.exists(dir):
+        os.mkdir(dir)
+
+
+def set_seed(seed=66):
+    np.random.seed(seed)
+    random.seed(seed)
+
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
+
+
+# When a checkpoint was saved from a DataParallel/DDP-wrapped model, its keys
+# start with "module."; strip that prefix so the bare model can load it.
+def correct_key(state_dict):
+    keys = list(state_dict.keys())
+    if 'module' not in keys[0]:
+        return state_dict
+    else:
+        new_state_dict = {}
+        for key in keys:
+            new_key = '.'.join(key.split('.')[1:])
+            new_state_dict[new_key] = state_dict[key]
+        return new_state_dict
+
+
+def validate_path(dir_name):
+    """
+    :param dir_name: file path; its parent directory is created if it does not exist
+    :return: None
+    """
+    dir_name = os.path.dirname(dir_name)  # parent directory of the given path
+    if not os.path.exists(dir_name) and (dir_name != ''):
+        os.makedirs(dir_name)
+
+
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
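As a quick orientation to these helpers, here is a small, hypothetical sketch of how they might be wired into a training script. The stand-in model, optimizer, iteration counts, the `batchsize` override, and the import path are placeholders for illustration; they are not part of this patch.

```python
# Hypothetical sketch of the utils/utils.py helpers above (placeholders marked).
import torch
from utils.utils import (set_seed, parse_config_or_kwargs, get_reduce_lr,
                         set_lr, correct_key)

set_seed(66)
# kwargs passed here override values read from the yaml file; "batchsize" is a made-up key.
conf = parse_config_or_kwargs("config/infer_est_nspk1.yaml", batchsize=8)

model = torch.nn.Linear(10, 2)                                   # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Exponential LR interpolation: lr(it) = initial_lr * (final_lr / initial_lr) ** (it / max_iter)
initial_lr, final_lr, max_iter = 1e-3, 1e-5, 10000
for it in (0, 5000, 10000):
    set_lr(optimizer, get_reduce_lr(initial_lr, final_lr, it, max_iter))

# A state_dict saved from a DataParallel/DDP wrapper carries a "module." prefix;
# correct_key strips it so the bare model can load the weights.
wrapped = torch.nn.DataParallel(model)
model.load_state_dict(correct_key(wrapped.state_dict()))
```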