From 6786ea046e94827b31b67e0aefd13a6e6a687587 Mon Sep 17 00:00:00 2001 From: Harry Yang Date: Thu, 25 Apr 2024 10:03:00 -0400 Subject: [PATCH] implment the structure --- .gitignore | 3 +- Trainer.py | 57 +++ __pycache__/Trainer.cpython-310.pyc | Bin 0 -> 2150 bytes __pycache__/aug_helper.cpython-310.pyc | Bin 0 -> 7301 bytes __pycache__/dataset.cpython-310.pyc | Bin 0 -> 1651 bytes __pycache__/model.cpython-310.pyc | Bin 0 -> 7235 bytes __pycache__/utils.cpython-310.pyc | Bin 0 -> 16811 bytes aug_helper.py | 218 +++++++++++ dataset.py | 52 +++ main.py | 55 +++ model.py | 187 ++++++++++ submit.sh | 0 test.py | 0 utils.py | 493 +++++++++++++++++++++++++ 14 files changed, 1064 insertions(+), 1 deletion(-) create mode 100644 Trainer.py create mode 100644 __pycache__/Trainer.cpython-310.pyc create mode 100644 __pycache__/aug_helper.cpython-310.pyc create mode 100644 __pycache__/dataset.cpython-310.pyc create mode 100644 __pycache__/model.cpython-310.pyc create mode 100644 __pycache__/utils.cpython-310.pyc create mode 100644 aug_helper.py create mode 100644 dataset.py create mode 100644 main.py create mode 100644 model.py create mode 100644 submit.sh create mode 100644 test.py create mode 100644 utils.py diff --git a/.gitignore b/.gitignore index 354a328..1c8a54c 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -dataset/ \ No newline at end of file +dataset/ +.output/ \ No newline at end of file diff --git a/Trainer.py b/Trainer.py new file mode 100644 index 0000000..20e2264 --- /dev/null +++ b/Trainer.py @@ -0,0 +1,57 @@ +from tqdm import tqdm +import torch.optim as optim +from model import LabelSmoothCrossEntropyLoss +from dataset import get_dataset +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.optim.lr_scheduler import StepLR + + +def train_and_validate(model, criterion, device, train_loader, val_loader, optimizer, epoch): + model.train() + + for batch_idx, (data, target) in 
tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Training Epoch {epoch}"): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + model.eval() + val_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in tqdm(val_loader, total=len(val_loader), desc=f"Validating Epoch {epoch}"): + data, target = data.to(device), target.to(device) + output = model(data) + val_loss += criterion(output, target).item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + val_loss /= len(val_loader.dataset) + val_accuracy = 100. * correct / len(val_loader.dataset) + + print(f'Validation set: Average loss: {val_loss:.4f}, Accuracy: {correct}/{len(val_loader.dataset)} ({val_accuracy:.0f}%)') + return val_loss, val_accuracy + +def train_model(data_dir, model, device, lr=0.01, momentum=0.9): + #train_loader, val_loader,_ = load_data(BATCH_SIZE,) + train_loader, val_loader = get_dataset(data_dir) + optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay) + criterion = LabelSmoothingLoss(classes=NUM_CLASSES, smoothing=0.1) + scheduler = StepLR(optimizer,step_size=100,gamma=0.25) + + for epoch in range(1, EPOCHS + 1): + train_and_validate(model, criterion, device, train_loader, val_loader, optimizer, epoch) + scheduler.step() + if epoch % 10 == 0: + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': criterion + } + torch.save(checkpoint, f"result/model_checkpoint_epoch_{epoch}.pth") \ No newline at end of file diff --git a/__pycache__/Trainer.cpython-310.pyc b/__pycache__/Trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73adbab7aea14f9bcf6fae50a0c0533e2d050707 GIT binary patch literal 2150 
zcmaJ?O>7%Q6rP#=U$5=dLZ{NOo`{w<8?>Q@%eFW>r z==|Dw524@m%kgEv%Qe{fE)a?+PEd+Mgf&kRoRW|fypb5G8JY!eCRS>P_JNKQI)$#A zdZG89??~dO9szr4DGZ8JD=DXyumZeIom~`GX+V8i+Qs1vB~8>Q?}I0(ft4vXXu1l_ zsg>wHOPXnxiOv$w@_e}`c-G%p0Tzx+Z6;bY5>d_syxL8{`YU&0N1HiVGw{@TSc07o zFhUZI(FD8_SP*H9urM^GSlW-!V?xQEaRYt&>L+CAz;=h8bjPM76D+;ylT4-7*wR)c zlI{o(eTk(ny%8CfWNB;{C)U`}&y5<~=xUA{!vJKpzPzb@kS5%3!u{sI?qB@h zyASUS4sRWoCg?I62GV~JP)lcbj|?laEGy$O_+62e37+6_b%G8w%Nut-hB}+@8yOh& z+(ndsv5hf!iP1gSGSGfKm>JVcoknm^=I;>(v1ZSL)Zd7?0Ylo&Zp;)Cb1=n`?vMV*Pp? zZ}KSKxl(_M7Y3w$4o)tu4anI>qXv#>-#7u2%45CF6if)FjD+J{yqEPwH|>IR%3?aHp?5zL zai`U#4TA@aEwH33$pSEuUjpLdbL1p; z;pa~`F|i4J9cyU~SMVGrf1duk;@0p*;$nUQK6>)m?Dy~eI@rDX6YRe2x0BmY=;v)(p72o!&|MTmFvysa+B*4D~m_H zHdFTU?GKjTZx(3aHGSY8xzA^{z|5m9)(DiVArE13eg?$+MNQ^4SlqNc<)9Bo(Y7a&c?B@>O=GQiR zvF<+@KmIWu-waPl8@ucJtoQ#gsMs~hH6 zndO;Ti6mwVC6sFgtzx7J5Fp6IGDsfS76po;=u7^Az6E({G3ZkhAmC%smo`A{{=PH2 zD=sN7F=x*0J9Ex={hg1=++5YbuXz83_6tjf@hfT^{F!Ln#FPF93SkJ=HT?3;e5TvH zXL5rXJkT)C&iLDi}xvUPAubnT0AYz z<9$YaL0k|QF>^^gBQBxmtoWRG7VmT7Iq^K+%c3b>!24;z-Z7e2_HjaErSr;Pzxbcu z-Fo9ji<$?2#0vVQ%P6El$M^Y1Mw|d7F4HI z=t+N$A~SZ35!(g(a+Vtbj|`M%juqFrdbVQ-{uA`CnZmqopk<9)wr)+voa4Ao@pF;x{*?}Mr~Lsg%lnv9L9GV9x{IkysBeAmEf_YE9uABQ(0!-Fol z_o6uH$`~W}amalf&rle}$ZFE3k;0CHb~}si-ynMBVoVWJ^em9wq`E@|4MjJ9M^Bq} z=((8Ipc$mUdT!G$>~zqVs_=p!ilZzD3f@kOD!vr-LlyQ=e>dz7us0mEE?Q z#{zq-s02Z)8>VRx>>0nkd^K&UFl((|UEO-=c$272yK$zUpBU%X(fYrEJvA1m0r2t zOVqqfkn-9&?~B1+Im)#6*kx@yNAECaCklFn32)}PjU4OUNe?U?6mP{e*FLy z!8$nfqA}z-xb;DOXy!aO1z*CnCD`qkVC2`04Of`sc~)j*Ea$}6p`C$MJ519#xt$rH zb(X;Mb4G6Eu(6G2X#LF0(T8uinMF04|2`jj!wP0q=vzxh?&com|6cAwTGomp!R}M> z8$9|)X3_2@VWt{r7q$xHjx0>F!ipqWK+k2fTDWQ0>vv`9d$G0ynHIL@QDH@y?D-4# zWYk%O1f(hrrC=)siwdXgFU&}67UtdX2&XUf!gU#xv(liQsNhc64da#+M(L(R_Z zkwp$RGuG@iM{r?2Vnb)9ovU3k}wbn71WNPq9q`P|X*+}nkUoim1Yv`$cK zIG5M+xi*J#I4gzK8i?>VOkz&P;T?FGEb2vZM-l(`NyVFsg`Fj;wOTkr-i=xkBuRg( z@Pw3oU9%)f_;)e4u-~{VQxh zTZs}yE!~P+t15}3ZCTW1H|(cS32aH0nrfz*XwIvP`dd-lA7p+_$%zY2eRq(`;M&V~ 
z)R$?7qpdvf-82go{DW`HSm5QBXG{wg6eh_1N^78${*3SSWEcn722lyB9QG2YFe4o( ziFneMNXHek|shWyzt@vhU#eZweM2^4Jb} zYytQ2J-$J@!L;Adhu`DsX-rd>sdyY+v~R9n??!3XLL9uI*Q3Q3jqUTTq}zoP3EJ@i zr#zkdnfp<1fqsu^^(#NMa`dC>Yv{X8{a>eoTpNP5DWV>L!@;<7i_a&TSt!Zc)xFBfs}m6BwDNY(3sgV31;eLZlM8pt%dW5e_SU%Qp_82 zSf@E-ohJ`kCy5Hdm6zhnwLNa%sy6Vu4eyvX!bT=Hvqn4W%0|-GZ;JXDh(&`OC`}=j z&W!2>w(UXI{=%1G)*ciz5wPf^QG22T>8uWnjNP#gXeMC|N2QsH`pr@e3f4ZvJuQy_ zY|!%9zCb%pgDeSXb59aK;jE-j(MHKobfZ|3d_1C~6RP*I79tcabLdmW+|1&Ciis(C zsd=l)mOq_05g9_-2)VBoiD~kqNIkWsvuJ30f!ki^aP#aDM-(-OR!&%*uP9OioL4K8 z*$X!jF{s@J?&TF)jckOLMT9AP=6eO)gd2ub4h{-^jI)j$%;2|DVFD6|(^9I}V-LOJ^o zP-X<1+SRr+;sSA4n2W|k6UOx1RMj~7+# z1A>&|i;Iwhvt*@;X_!oorH@C0z88({<*{pw+A_>8~pivep{) zIt_8Tp=fJ>Rhxa7Z4jnSr*M)%27VR}!cs?Os)rfA1x;GlnpId}u_$*H(R$j^dXufJ zqAOiM;UR!5>hN+Fy1&R(iq08`N*54PDcVb#qk!_jx^WW+XFoKT;G;4&qQ@-Zq(HzR zMInXDAptT7!Phxd3-d^chEDDXh7nhoxr1Jysa+cqX~+xc zrV2%x$cr>W3$5uDE;x`TDpS9Y&cbO8dIZ+31Rg;>PhGWc(usiSgE$$ZRaf@YsGG#i znj%rD4VvbTZP2f0Nru@ak}uqn!wB;j_pALRos0m}U|zxdegCPcMh9UQM1em)+i}OGp-jCIoq}L$4#w+Seua_`=)sKu z_k4%}Z08WrxTPWE5{Za>z=!Y#?NU^fNO|10GaGjTdk6k4N3LMwjIeS(a&{4(Aq;!_ zF^>x|D>k=sJa~u-|z87F{Q9%NW9JzS6k1 z`Ns%e)VSQZ7iFthxAU26HrDk@kKM(}z5q*_u2u@wMO`5&V5pl^5OCBCEi8g+O~BO3MYkqN+0AE_^j#SpPs}kVpEA4A5iT>D*lR!LwI#@_K3K5 z0uzof@xVdP96soGG5(wXTK(0-fBv6e-_R@0L~R$tQ){M>&@qN49*CISEiQQH5vOf( zoog&z*zj=Ich<}uZSCnfTn+hMLtV%%D221jkSg(VQjkdhq3Q23Q} zJwi?n;H|LrLb(4Yfvy!6Er?qk<<1WCb|DmEe5o&I|2O7ns92uJE|Xa}Qq9 z`b7N@<41tCjZST3e}f9RsKu_ic$2~v;UQdYVO*KHX05@$rDvxtXNRlL6Yj}?U)Cx+ zFi*;(j@!7vYwkMJ`87~~W!kEP`WEP_!Wb~rj?m@s%Zk8bbFOg9j9jYz0PUg*M5HaW zB}>rAJeri3jiGg3~N!)q@QE8F}Lf01n%vdnZkr**oo<9y)`Dj}IBy z+7Q`yjbU~EJ!21bj*JkrPX8byw@arnVL?yGW0E1Cpan0X(0Fuy$mpW*(0B-zpa!!f z{Pf72+>(JVb*{k{*;ZdR?GpUya7Fz(EuyjL7=~(ux5w4r(%>EyniEs>`Ac+Z^m)nH zzR(E)xRIF9k@SfW)KWjCHFQw?8ET~rO1h?1en5Zf?`W`AuB1)7{O&&xwZ@loGavf} z`o|JDQ(~BZM9++^a&(Xcw5^i%2uOfD6=jU+?f}#>z6#^+Ql`e}K4GsEM;Fp!#2We^ zDc`~ZA5@6*lrs~?=}G>W$oX&CWzs*Sf{?E-6B88d4r!F4-0i>9inK?5|E#gyfTbWk 
z)aeh%S`M&tl3iB+gsF5lT}z@^{h4m~?s$H|jv$D+nFariK1RT95z!JV`!AYaeG#u) z$fnR;5BvQXFTK^~`BGK>5*@_}ongBY$5+~exJCa?Kt}CHG*Ug#@sK(wSC0QLXuN!^ zBqQTf(H-2-kU=}pP^1y+D^zS!K@!t>Y@I#T`6iv)D6h+bKnbBggDy`&Oue8r!XS@#9p$g`q4nRhnoa1$*v5w9Vti|gDPR8>@$D8!R(;*Ldc2Can6nucw@sEXneSu-+(y{YA7WTK`% z5uId)G;p`UoiwCpYG%m1t6LMOP82f}s{A4^Mmo7ymOkkOjCXW(Tv?qsI>)9(?yAf7UJfQq>m01M5+i zY9E4sU;)S-;JE`lw}6A;9pDMLly^W75~KTQ7Y)#dXatfVaUaab1qONDY~}%UN3cIZ zv838y4P3v#YkV`qF!OI_KFo;n=R^}uEPbY!M?j%Kd7D7v3k3ADnh#o}kvLT0)Z_8* zh$lQf$1UPy9fSQVWoa4e_{?vKimiTzfu6DUfvgxP>2DI9fR{|eE%&^6%8AN!a)y4x z_dK2P@QLU4q_Ah8Gd-wmU1g^~y`$U0Ns*VBXms!ppKr^`KG#213FXSs*9a8aH2>Y(? ziRtctX=0M+K%#F2+-$;EtXY=Q_)XbPh9z3zE~tM-U?&U9E1*Q_|sf!}T=YF7t{;Zh=>>PPXDx3qUq4 zl`AYL3+qwtff$hy@cb@Laf&|n_xu#4cn^Av#TT$VLC2GU36`pEbb)TNEo!MP1KoI;CLpzF0Yi)%TWk-$_+oVK^N*oGSNUCg4tc-@+vx{Bq z&VW6${6H2}j>$y1AojtRRKcl84;*j8se$9mDmc{% zrw&d{*1@SyI7{Fx$p$!$38xV+%jHc|eduG2w{&T)>|@QwdQ%ij;rHThXCsg0jWo+z zWG_DM!w>Mf($C|;N92*kb6sji8O(Sc3$M~@j25D}L#h&q#mjwjbgSf_) z6nhY78}*(otzC2MK#W3~r;QhkJu!AATPkBuntM3=5%_!7j`2OHxsAU^7LEV{ziedR zx{o>d_X8C2n$EfsSkFEQj>FEn+i~H__+GaaPmcS)uQdI_$y3$ZC@M)Bc2pz_Cmw9& z57lverq)0Lr`_%M!@$d0QJ%*tENmI&QQ;4h?Cvm*x8tUzR@nO(2~SEjW@H()VmT~N zIZTr{+(>gaXQx5VnN_hOY+*lZm?wq(S@oIito+t?w!iwhhjVbW1LxsjR4S4_j|Zqw z=9K9eDhpW~Jqr>va~qNt;Fvp5VOWtd6zV+Hx#x~6l=hl|7IS2+8IV4{tQk8tT77B3 ztS^u3vA0K#eGAeC*9P}ZaH~@;PU8}?bRhvD5} z)Gu4K5f=3135l(ba705%}9G`Cg^-3u4?RY$9$R%tm> zFc~H2VQZ#)TCDh*-vaBvGljL^5XNShTG5zFmjFyHT|r=L}bA{w!%qY8dQ#w}M{o z2wH~U3fHk`?m^o^uN1#|8;ZXYL2NFi^k4p&&#$iB#&N%bQJ}!YyYf8EBq^%5 zV!%L>ccXr>d^2wMV-5BnMDnsM+-qra55Kjy3DoPU8sNph5l5R3Z^k!1EGl;>-7ue4 z+J%>Gq$+O>^QNs;y|83=p#GRseuYGzzs47>1jKeOOKW!(db}b|3J>`1{?2w+%i3or zo$V8gCpm-n7UeR1Sf`p?L&?sA0H6&X3Vv~5d*T#1I55_3!5ysw`2H0A)Hg_82YF(a zYZea{G*+*ZO8q8E##zpb?X?5znVOfW&q%$@pbG;Cb7~`98b3oNH*!exkvTSL9&nY28>G2Y$!(~(If4ruxuXgU z!rAe_b4T9D0#sKj+k?Z@=Haeusdq|E|K*jy?!^6J(+Mh)U>5faC)tQ28B`-^_Px?PmZ6nlJC5b# z=+GG0+#3ou(v7FDp3*)3!_#Z?HSPVy9;;XYABjhXmRzC~%goVGAW??+I 
z6T>`jbdxwzh1qTuPPDbve+b~!ky#nMWlg_S=vu4a-3t40(#bapUw<4z0)>}MHmcs{ z-gE>&`w4uCBSd%>2R@cn_BzUVub_8EWNjcUv&EV?4Ff(Wpt@rF_*~DPZ?d!K>nR4w zVX(RAnU73$HMhn#qODgLZOynBMxy6vZ+{W};Oq4};s9|Vs-j{nMp;L_BkFo%G3sKU z$xI&R|Ni0eXYW4HS*I$q{#EqfnXLC1;y2Ord$n;LXm202GR%DM@XT73Co!Ko0a0k@ z{ORAW{rtWE-2T~>=2B7ZX59o1IBCUN=hR=Y7oD&=MM4-YY@R-}CDlFiP`$$@nzy=6 za)zA{E2=@%4`*;USU@8UL^sgwJ17~Es}49mDON?}sVN%IO!4Bgy1jyvCVv%g_OtEN zizmIPfCmtoQ3s1g=sqe4>wqG3i+dJb?hb7qP8R~Sr8t8;4+3PkheV04woM1>?3(Du zo|SL#3xNQSfjILSW9$>(85EAIblPX&vdb{937lXot(qj%Sal8rZt)HN)Pf2!o53)T ziz=j-cRRy$n5p;B7P#?)ttgSmel{X>qR0yi(X{$=wpO(X!zklWy$gn-;F@Kir@pCs z%#1fHL|WKeNU07>oQ~o9C>cZXRW0wTmM}xU?bQYOPMr<*L*r9dr(|8dh3b>ak&cHa z`hR#I!W)auv*>-i8vGAA&-XtU&^q%!-eQU?k=i1Worhz7QGe^LchdVw77e!gvG$Vs zFxC&mV{|yg1nqSG1C@hLhX+4RR(1O7C2GTqWWjZw}P#YU9-Uu;1qy7%#G`JjT z{JnX`Xu~1tSl=1pZ(Q3$zP5v!giB$MHu~(iP@=kYM#wHo?xf8+V~m%eehyt&8xCxW z22?SvSX#wO?ScdD4zj{V5|%d`r62hSyD9@PpKKk-Df<1wYi&dcV!Lb-T*DV6oml+< zZGqheh8Iq|pF;DPRs_rN;oXcm-JM~(jlA4d@m3t=&5A~S6>=Vqt3GCTb0aE5r2dAC zI?2H>OoBB-awgel#Y@n@74fRrKCyUmM`JldFP@z??yIQi%uO3KV`myRe__IK>@aCK zd6+cvW-UpYyen;(yo2Du8d(TbAWaG~sHie3EzOq#M&QP6!DJ&TdH29D2)jN{rUjv45xtf>)^ebQO#;)>G@S<_x)R zW{s#kbDGg_h{ZaztEnjuocTjJ@=-J`^|#wgl;Y8fi>Ox%P%AD?0%i}-1it}Lxn(GZ{qPbLulh3* zZ8me;A#(p6O2(683bS3;lTG!*oFCrv)uWAXhbSs0;Tx6Zz(Y{vlQC4AVXpR`yjJYs z7W7~FR%w?w%qvCa0p*3FE$edWv9XW`G>Y1-bePEXIKLU)-~S>;UElu^=sq{H&*@+) z34RU}x|x;_GJPlvzH2?~!OPC|kow2_7v#;*K;;;!HWxX_OL`%4PMJyU4or;-BXtmb)tTv{?(ax=KXyvaJl`WO_aQ9G7&8VX= z^S1g+u7SY|oF}2KFq3xKMXz(V#G=whdY7tfp2QSRnf{4xA=-rq3(-3YHa@|q8I2sU zVU|}8_V4O0uOiL?Na~1m%^%g#Mot}nXNCFf9Bw}88(n}5GM%*rOL{)a(4;iV{5vWO zNzI#+)Zv0o9T4j>ped6KxJ6H=J1-;s@DbDU-8w3%IGpYv$@tpgBx5pX4e83w!{g1~ z9jTKR!u&KTp`TG7kWh1UdeW?`zha%)ZwcP>OoT2@XBz4@YuY2+WbGD-2Kf(JBW#p} zx~Nb2-%!z4%S>>-WOJCidW4?V| p~50HzLE~fZn7biXev9)h7FHo+{%WaE!f-Jw`Znj zd%Aniy|c5s+}4oYn1F5+AEkiE45&B|5UqkTD1wwc)S#3WAQvoxh;g-4G*BRd_P-KbDWlzr7SUG+kRtM|xb+Tacb`wjjMl(vwI}I$M$6D(S6APdU>_PfL0V>21z-q_<0Y z8tEO*HAr70>1{|~>s*KQb&}qW^iJn`q_3Cs4y1QEyOG{4>1&YQz%f{*+=aw}+ 
zzt`Dc-s@=2t&bRYYUSPKy%+TSKBV>P4NEuXbGlZ}t=?4LyQW7mluc`^&v9;B)18BB zhUDJ7h8cA9;}UYWUnVkUb9#^})$4_&qNgfnkxCq1o_|5nt&~bE@4ieVW)R$cAZ0-g zP?K~F$GoTmH!}CYjm$Z4lQ^#BQbF?mV$G>9-T&hIpTTXoo<3A!ORbA>1K|k5vsn<| z@U?>G>!85qacy0nT{i-~5g2oB&I~NKT%A+Ixk`epYbuSuZ`^ReEvcecS~#$9?%?xo zyY0Yz_ZE)Sml|cS;#KOk1Iu2e>h5ox3lfDwrB?9@1>W|o1!C(d{rp(_G55DBlx!5X zuC0i~5%`$y`@lY)wyOKa5FXdjw_-@w%vHnF7R@!iX`{??4BuWensMKDOq82`d_tSj znh8Jd#~f=~o6!7(Z-ud~cwNR&o|JOET`n;xM-1gDDUY?wB_`#Fp?nB^$NZ$9@`u*U zSD4LobGVs7p6#cd*tq5olcIhm{KwOdpfv7liwS>Z0t~xd17z?<&8)O!`}j#qqs_7A zIPxst1Wo#*Yr6c8t)aHW#u>vqPxur5$abwcirkbxvZk*NIR+@lCygGr_}RXktUns& z;GSWB)X(}`=8QSxg1(wrqqxx?!6?re(^_*9ZI1YpXf3;jl+Sch(xYn{`b5s+7{}6y zO+p-jp+4(Rq6Om-x!a!vDC4u2F@f6UZ{t#wiAs`0;XHz-HI`>`qYPOgF&oOUn#2~u~c>!=2gW3E>9~jh^hK=%@KG8X7QvOBg$Ot~7EOU1K6OseO$sv*=wR;3i^wLn*a zp_T&ERV6%PEpH*>=1sNa2FBS+-Z;0IH_o35^h%&F209{6pnHK{4)nP|pAU32axTyp zgl$4HEX}KmBKx^Ff$Y%JdRE^lzX|KXixp^l}WAMP*?S}5_Q zS{EOiJSsf%+K^|oH4RwPPig889NjY(wN;bY0P+l=7owyILDHSy1dsHjV@|9mZF>eV zY56hVo+tOk{kUUKYdv)=E9hs{?ec$KQ7Nh2BsY*e zmxLm|3*r#{rDC;OD6odRA0z`@TYA>W0_|CS!k9ohE0q6Z264U*_)qBew-ZLMk_elc zUDri~j8krNG2(lrk-h1pt$BtI2URlY5 z&z2|wTNmSG;Ufqx=b&2~1G-wl0IO*Q)1iKbjfj?w*~=rSnTv%JPSuRrEETE^`CDiN z2~2UpD*_xbnaPg&R`;W}I!@9V3mFEE)KP?I*`BS3KPc+y{RT3%9jGxkYLa3H(fhv04l$$)98peFUS zF$raq)(`4CpqQYJ2Eilb)SMaOeRe&j?nNVt&8b-u;mS=J@))Y{*`K97NQO+u74tS_ zNA-g!1;dSmP<^1XR6f2?ne!rkBiFVr##{UrTEFD0kn6Rzbl-a|M^} z;f*LLLbfY2Hh_l;%qM@IpXz&@!HmNe+l=$Bex?0Mu!P}g%bJ>iVu2I@fDHf`a>1cl z2xZM0YX;0){8OcSu~i!?J7gVu3WhHX%ot2HDngpZG}7)MjA2S!rz~~Mi^CYD;z~O4 zO9q%KzL)|lVLWJX(~tv`=uQG$*=*YG7^l?MmO# z+caZ%kJ#c6o}bTX!&oo6Gu0dL0~+Oq3O&RIf}FD1ZY>!9sPY01wV?=5f_$X-WA2gXoo-V zE4rs3y{{lJ6a|t9h@27FHLT&?Qz8s{1VfXR1_*|W%ooA1Z9~U{MT3Pq`0_^(+(i(u zI4o8OgAosFMlha1OLIcIAVSPDS52|@Eoj)=p|N7L^m&hn1ruTcA^__}6JSm-AA-RG zxViTYyQ$dzGwAoA9?p(xuoK21}4EsXm;R=g;@n~ ztTB_Si7Tv{Q_ESQ>?GRqNy=5%D=IIas5oZ}tkFMYMAc&mE|o@)8@2yg2NDWH1@d<$cTd7FX#1z?##AbuLJ{4F8W- zGSNxvV%$JDg3yDeRIN~2DAsD_s=E0L%0V)LR|p> 
zaGAK3rWh7@1WpWKF(hr%E|s(xKp_*vyJkHlD~#f@Q>kN_v0N*0jRI3;WNEojEv^*G zwNjl67{!-q=Q`yux5Yq2hdMyg_qs&GFCvJMP9$WBgUw$%S}eezA_J)lBqvBZ?Mh5c z--l7~vn&))FjYe6xj1Gs2)&@dBDdccCzgNYgF6g3u=^56yLB<%f`B`T1V_4MfrJ)6 zt^9C=&A_bIR|0FHGQSYwB4VT0;70XY655AdZ`HQt-;8__iL~T94mUj0e-jmO96%2O z88p(hMLOGHGhjgtA)|*v!Si5z2s}qjcju0LM!mKAh!RWv;H43ab4UFn^b; zg>DqUn}gBo%$B?tQShq>5eg1-_M`y$o^ERD# z^(abVVx~eyI|{8DPA#Xhbur$WT3JSgQ`??SXnoEWR9eT3>^xn2=AJ;EXW0&BZM@Up z;Avw$Lh1;DIEs6id6dixJz+WeyzUs6FkuC-Hb82IW{axgJ}e5_OO0?jp^yvHr05k} z4=bG6_pqFf-(>Q=1R-t5FIawbi+lZA}%NlR*j#=t39| z`J{WfMRC?YGMuUpGXEz*&R;w5DF;s%0m~rMRYRm_y6%976CMdFK}R2lm4tP_s6XVv z!zVTsJ@_#$uIRnMk_Z;m5SCWkwQ=NO31z^-qPR@@@J=*i06@x*wE+<8@&>9o^#2S3 zuIJ9cD(_%|_LG_hfqX0%YigD084?ahkW#d)#M!L~pTH_QPGQ~dSXzPZhE+{>eo39< zoqcd8x;~5`W)^vP2aNJ;C|EZg(FkdgZ7%b;xvYHyTK@oh{uvUct_ZlubJhaf%Vsa{ z=DngziLchhcrVsOmqu74?yz=-E`1Z~Da!lx2vUNk%{1pHSX&?!IX61baTRE~h`W2x zH1j`2lb3+%1$H7_B?VV0yd-RQAZ*CDun-P0Hq?bNyB}lfEqpO9(sJoRAX9H;`6|f_ zi44*>VnLEX>)=fN8jF9OWB{5ziTo}!?Es>>Jw-i(*#Q#ezfADvKzci>{28kWkeDhV z^jrkQ41$Uz*F>*OI992hg5YZRwW;7zPq2|TY}7lM?0Z*ApQo8GP`PGMPlMr-!cA@A zV`GWX1D971gp0n2#T>RMD3)xW+b-A^`S%nvxz7uYJVr!He9N1i0Ohy8T zn>6sSh*!quG7-L(ILf)b8OMeN_L48@SmMaG-D-M`9v^IHz~jSO$<<*$4xdK`y;|4~ zv{tP(LNALLwgE+~X^UfcigEPLmBR$WSR@|dTXNuIfp|#*R<>YUrmYp%lM!I4 z8bepzos%PT+rzt1ZWy5D1bqJE1*`P(G#CFvV%;_O=p18|n@ z$Jypn(Gkp4x}n58!C0p^iOHhnNfdmIg#tLHt_ZkjUDE=={blF7inc6}5~6KKuUwXn zzDJ?UW=0wJcZ2%$dR{(|aqYf?bjaet3%TJ(ZmgHBr&xCN%WWS>OELz zKdQ~aM*@dlvU9lZQ8-ufG5Aa3t=+|=s(iY$ySTx%XQEvOmkVPLzUD=2h&lQtS^9H9 zA}&6Y%MvVsbGmHJuUpvRU2fon09BFjyc7HXR;`6<@zY$ zE+e<(X$8s_1NPInNRf?tE?J@$SXb=SA=s%g+>u5Y#$I1$5r^XZ7=kQraHs9ri&^nT zOW(XxTEX$Vb{xl#B#x7i*rl|gvF*iCShiigZFUci>B#mzrWi+bFnZW!!`6DCTCRx= zN(-<1`XILY>kSqMv04E;Y7gM7PAa-%-oqK2s2FkeQQRqq0D_Dvd&>$=eryp`%V)6; z{TR#sh=jFzHvDSHfa4ckq|ia3firPSZg_^{FmHH>gjQ*Yl#ZP%*6cD1xNK9WL9iX; zVcRcO;U65cLoB$51@fBMeF+R_E+3x@x46}Qmh+_p{fD|FwE&E!#S1Y+!Z#Ug@lC*d z-GMC`8%8XD27dOx+jf|;zh%W*8++XV4Eb0*wD48j%MR}&nI*}Pu%lo&!g#67BSc-# 
z@*!4{2l{=)4z&=UfD1AkQ_S1I@#M%E1mVi|!JVWWq2*E9@CdCH9ifem*ugkCUa6Hu zn_NDKyF&*A4Wib?xPc&y!<7fON#M-q!5!|EcWdnNC)e*duK-Tew=hwSBj6+1r?;Xb<$0UGQ$R+DWn+ABrbJGhj6Cd+pF~ zraQ&hIMb1Ogj4w)Z1~e8V!WSVY>?6QDHP-hog!;6RYK^wOr&NIWHr%doQ)^mxTVlO z@rIq4pwx5yhub<2JeXr>Ffa(mEFc0Db`G~41iIm89CSt}L@Uc>CzU62 z;0}(Ce}hwyi|rY_J2k1n?RR`dkfMJlLNXl1qv$Vmech1q(=EkEV?P0v8`a|Tm_%wQ|!T6#y zW)OUR43QpuokZ$8*k6#%k;T=L%{a}~V`%|f@b7K3m&96%(UXW4WTOJIl-5|vNmAC* z`htg#r>}rV*i>nrMGbAhw(T{j-BjDA8j2+Ghw)OOPwxwieT3wrAP-sd>c|T+OO>$qsR{aJli`# z(rHv;(()Hk@CJ^xoK0b>gwS)DNX;O$Ha(Juc?t)gmARoj<~D1O3g(0-5O|||6TmQ! zGuCg>AA{vu5hLwx?UnyZ;|wEhsA&&snWvwyMZhfvoDe&mM464UJs#GxIDH0T9OcH5 zdjLztFgG6MCXjol2SI}`u#z~1fRELpk`zi_?BUdRyCfBr458#f51SqBlA)+1jgpcV z?Lz0)z5HLl0iw+!*bLaakBCRnW$tGC69Yhy55j8=HnnOmbc?lV>GHU zff{f2(mge{NR5f8#w2PyfiJ1J6TDSQCZm!mlsxHW&@Vodl#;2aWEv$;c_ZQ2cJ;mO z627f*rlVThQR_2aw(E_yMx>|KFB0nGVcqTJCx%lLvR#}tx*q$8XRt{VF2NLo?K2bVXl7f5+ zCyw0EZiQCrdqAFd`lsjNu1}#yIa;($JHM|jCff3P|DDx3q>_6VBCT8gOAo2_ zJ*4_)Jdyerl4nWU52e0>BJ|Di5yj@#4z7sFz-j5pG_)g$Afsly;&pZzb+mO%DDhtm| zhxg+2t6`&+DSSA)aoT!50;F&!v%M6-CDx%>vA)M@oBlhV?&s-FSz$ozS`c2ZSyY3irDC6ueN3+{B!>F$m6-J!}VG}no_DvE2Q3pvc zLEdLL;MBdez^HwBT$qQ?p?iAJ3JxwAk#g`8~b$l;XCwURc?I6J@xT#UCd)3NGeDTd#U`PLEqY6f2mU;!N zTJA!9Wzd)hi95qj`0gT(++AmsUbA8w literal 0 HcmV?d00001 diff --git a/aug_helper.py b/aug_helper.py new file mode 100644 index 0000000..8764b68 --- /dev/null +++ b/aug_helper.py @@ -0,0 +1,218 @@ +import torch +import random +import torch.nn.functional as F +import numpy as np +import torch.distributed as dist +import copy +epsilon = 1e-8 + + +class AugBasic: + def __init__(self, fs): + super().__init__() + self.fs = fs + self.fft_params = {} + if fs == 22050: + self.fft_params['win_len'] = [512, 1024, 2048] + self.fft_params['hop_len'] = [128, 256, 1024] + self.fft_params['n_fft'] = [512, 1024, 2048] + elif fs == 16000: + self.fft_params['win_len'] = [256, 512, 1024] + self.fft_params['hop_len'] = [256 // 4, 512 // 4, 1024 // 4] + self.fft_params['n_fft'] = [256, 512, 1024] + 
elif fs == 8000: + self.fft_params['win_len'] = [128, 256, 512] + self.fft_params['hop_len'] = [32, 64, 128] + self.fft_params['n_fft'] = [128, 256, 512] + else: + raise ValueError + + +def count_parameters(model): + # return sum(p.numel() for p in model.parameters() if p.requires_grad) + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def make_weights_for_balanced_classes(samples, nclasses): + count = [0] * nclasses + for item in samples: + count[item[1]] += 1 + weight_per_class = [0.] * nclasses + N = float(sum(count)) + for i in range(nclasses): + weight_per_class[i] = N/float(count[i]) + weight = [0] * len(samples) + for idx, val in enumerate(samples): + weight[idx] = weight_per_class[val[1]] + return weight + + +def measure_inference_time(model, input, repetitions=300, use_16b=False): + device = torch.device("cuda") + model_= copy.deepcopy(model) + model_.eval() + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + # repetitions = 300 + timings = np.zeros((repetitions, 1)) + print(input.shape) + if use_16b: + input = input.half() + model_.half() + else: + pass + input = input.to(device) + model_.to(device) + for _ in range(10): + _ = model_(input) + with torch.no_grad(): + # GPU-WARM-UP + for rep in range(repetitions): + starter.record() + _ = model_(input) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings[rep] = curr_time + mean_syn = np.sum(timings) / repetitions + std_syn = np.std(timings) + return mean_syn, std_syn + +def collate_fn(batch): + x = [item[0] for item in batch] + y = [item[1] for item in batch] + x = torch.stack(x, dim=0).contiguous() + return (x, y) + +def files_to_list(filename): + """ + Takes a text file of filenames and makes a list of filenames + """ + with open(filename, encoding="utf-8") as f: + files = f.readlines() + + files = [f.rstrip() for f in files] + return files + + +def find_first_nnz(t, q, 
dim=1): + _, mask_max_indices = torch.max(t == q, dim=dim) + return mask_max_indices + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + with torch.no_grad(): + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [correct[:k].view(-1).float().sum(0) * 100. / batch_size for k in topk] + + +def average_precision(output, target): + # sort examples + indices = output.argsort()[::-1] + # Computes prec@i + total_count_ = np.cumsum(np.ones((len(output), 1))) + target_ = target[indices] + ind = target_ == 1 + pos_count_ = np.cumsum(ind) + total = pos_count_[-1] + pos_count_[np.logical_not(ind)] = 0 + pp = pos_count_ / total_count_ + precision_at_i_ = np.sum(pp) + precision_at_i = precision_at_i_/(total + epsilon) + return precision_at_i + + +def mAP(targs, preds): + """Returns the model's average precision for each class + Return: + ap (FloatTensor): 1xK tensor, with avg precision for each class k + """ + if np.size(preds) == 0: + return 0 + ap = np.zeros((preds.shape[1])) + # compute average precision for each class + for k in range(preds.shape[1]): + # sort scores + scores = preds[:, k] + targets = targs[:, k] + # compute average precision + ap[k] = average_precision(scores, targets) + return 100*ap.mean() + +def pad_sample_seq(x, n_samples): + if x.size(-1) >= n_samples: + max_x_start = x.size(-1) - n_samples + x_start = random.randint(0, max_x_start) + x = x[x_start: x_start + n_samples] + else: + x = F.pad( + x, (0, n_samples - x.size(-1)), "constant" + ).data + return x + + +def pad_sample_seq_batch(x, n_samples): + if x.size(0) >= n_samples: + max_x_start = x.size(0) - n_samples + x_start = random.randint(0, max_x_start) + x = x[:, x_start: x_start + n_samples] + else: + x = F.pad( + x, (0, n_samples - x.size(1)), "constant" + ).data + return x + + +def add_weight_decay(model, 
weight_decay=1e-5, skip_list=()): + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + # print(name) + if not param.requires_grad: + continue + if len(param.shape) == 1 or name in skip_list: + no_decay.append(param) + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': weight_decay}] + + +def _get_bn_param_ids(net): + bn_ids = [] + for m in net.modules(): + print(m) + if isinstance(m, torch.nn.BatchNorm1d) or isinstance(m, torch.nn.LayerNorm): + bn_ids.append(id(m.weight)) + bn_ids.append(id(m.bias)) + elif isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Linear): + if m.bias is not None: + bn_ids.append(id(m.bias)) + return bn_ids + + +def reduce_tensor(tensor, n): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= n + return rt + + +def gather_tensor(tensor, n): + rt = tensor.clone() + tensor_list = [torch.zeros(n, device=tensor.device, dtype=torch.cuda.float()) for _ in range(n)] + dist.all_gather(tensor_list, rt) + return tensor_list + + +def parse_gpu_ids(gpu_ids): #list of ints + s = ''.join(str(x) + ',' for x in gpu_ids) + s = s.rstrip().rstrip(',') + return s \ No newline at end of file diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..4ab37c5 --- /dev/null +++ b/dataset.py @@ -0,0 +1,52 @@ +from utils import AudioAugs +import os +import librosa +import pandas as pd +import numpy as np +import torch +import torchaudio + +def load_audio_files_with_torchaudio(path, file_paths, augmentor): + features = [] + for file_path in file_paths: + full_path = os.path.join(path, file_path) + waveform, sample_rate = torchaudio.load(full_path) + waveform = waveform.mean(dim=0, keepdim=True) # Ensure mono by averaging channels + augmented_waveform, _ = augmentor(waveform.squeeze(0).numpy()) + augmented_waveform = torch.tensor(augmented_waveform, dtype=torch.float32).unsqueeze(0) + mfccs = 
torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=13)(augmented_waveform) + mfccs_mean = mfccs.mean(dim=2).squeeze(0).numpy() + features.append(mfccs_mean) + return features + + +def get_dataset(data_dir, apply_augmentation=True): + """ + Load dataset and process it for classification task with optional augmentation. + """ + train_audio_path = os.path.join(data_dir, 'train_mp3s') + test_audio_path = os.path.join(data_dir, 'test_mp3s') + label_file = os.path.join(data_dir, 'train_label.txt') + + labels = pd.read_csv(label_file, header=None, names=['file', 'label']) + + train_files = os.listdir(train_audio_path) + test_files = os.listdir(test_audio_path) + + # Instantiate the augmentor + augmentor = AudioAugs(k_augs=['flip', 'tshift', 'mulaw'], fs=22050) if apply_augmentation else None + + # Load and process audio files + train_features = load_audio_files_with_augmentation(train_audio_path, train_files, augmentor) if apply_augmentation else load_audio_files(train_audio_path, train_files) + test_features = load_audio_files(test_audio_path, test_files) # Assume no augmentation for testing + + train_df = pd.DataFrame(train_features) + train_df['label'] = labels['label'].values[:len(train_features)] # Make sure labels align correctly + + test_df = pd.DataFrame(test_features) + + return train_df, test_df + +# # Example usage +# data_dir = '/scratch/hy2611/ML_Competition/dataset' +# train_data, test_data = get_dataset(data_dir) diff --git a/main.py b/main.py new file mode 100644 index 0000000..88de130 --- /dev/null +++ b/main.py @@ -0,0 +1,55 @@ +from model import SoundNetRaw +from Trainer import train_model +from dataset import get_dataset +import torch +from tqdm import tqdm +import pandas as pd + + +NUM_CLASSES = 4 +EPOCHS = 200 +BATCH_SIZE = 32 +learning_rate = 0.01 +momentum = 0.9 +weight_decay = 0.0005 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def predict(model, device, test_loader): + model.eval() + predictions = [] + + with 
torch.no_grad(): + for data in tqdm(test_loader, total=len(test_loader), desc="Predicting"): + images = data[0].to(device) + outputs = model(images) + _, predicted = torch.max(outputs, 1) + predictions.extend(predicted.cpu().numpy()) + + return predictions + +def save_predictions_to_csv(predictions, file_name): + df = pd.DataFrame({'id': range(len(predictions)), 'category': predictions}) + df.to_csv(file_name, index=False) + + +if __name__ == '__main__': + model = SoundNetRaw( + nf=32, # Number of filters in the initial convolution layer + clip_length=66150 // 256, # Total samples (66150 for 3s at 22050 Hz) divided by the product of the downsampling factors + embed_dim=128, # Embedding dimension + n_layers=4, # Number of layers + nhead=8, # Number of attention heads + factors=[4, 4, 4, 4], # Downsampling factors for each layer + n_classes=4, # Number of classes (adjust based on your specific task) + dim_feedforward=512 # Dimensionality of the feedforward network within the transformer layers + ) + model.to(device) + data_dir = '/scratch/hy2611/ML_Competition/dataset' + train_model(data_dir, model, device) + + torch.save(model, "Limbo.pth") + + # test_loader = load_data(BATCH_SIZE,)[2] + _, test_loader = get_dataset(data_dir) + predictions = predict(model, device, test_loader) + save_predictions_to_csv(predictions, 'predictions.csv') \ No newline at end of file diff --git a/model.py b/model.py new file mode 100644 index 0000000..89eb159 --- /dev/null +++ b/model.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.loss import _WeightedLoss + + +class LabelSmoothCrossEntropyLoss(_WeightedLoss): + def __init__(self, weight=None, reduction='mean', smoothing=0.0): + super().__init__(weight=weight, reduction=reduction) + self.smoothing = smoothing + self.weight = weight + self.reduction = reduction + + @staticmethod + def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0): + assert 0 <= 
smoothing < 1 + with torch.no_grad(): + targets = torch.empty(size=(targets.size(0), n_classes), + device=targets.device) \ + .fill_(smoothing / (n_classes - 1)) \ + .scatter_(1, targets.data.unsqueeze(1), 1. - smoothing) + return targets + + def forward(self, inputs, targets): + targets = LabelSmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1), + self.smoothing) + lsm = F.log_softmax(inputs, -1) + + if self.weight is not None: + lsm = lsm * self.weight.unsqueeze(0) + + loss = -(targets * lsm).sum(-1) + + if self.reduction == 'sum': + loss = loss.sum() + elif self.reduction == 'mean': + loss = loss.mean() + return loss + + + + +class ResBlock1dTF(nn.Module): + def __init__(self, dim, dilation=1, kernel_size=3): + super().__init__() + self.block_t = nn.Sequential( + nn.ReflectionPad1d(dilation * (kernel_size//2)), + nn.Conv1d(dim, dim, kernel_size=kernel_size, stride=1, bias=False, dilation=dilation, groups=dim), + nn.BatchNorm1d(dim), + nn.LeakyReLU(0.2, True) + ) + self.block_f = nn.Sequential( + nn.Conv1d(dim, dim, 1, 1, bias=False), + nn.BatchNorm1d(dim), + nn.LeakyReLU(0.2, True) + ) + self.shortcut = nn.Conv1d(dim, dim, 1, 1) + def forward(self, x): + return self.shortcut(x) + self.block_f(x) + self.block_t(x) + + +class TAggregate(nn.Module): + def __init__(self, clip_length=None, embed_dim=64, n_layers=6, nhead=6, n_classes=None, dim_feedforward=512): + super(TAggregate, self).__init__() + self.num_tokens = 1 + drop_rate = 0.1 + enc_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, activation="gelu", dim_feedforward=dim_feedforward, dropout=drop_rate) + self.transformer_enc = nn.TransformerEncoder(enc_layer, num_layers=n_layers, norm=nn.LayerNorm(embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, clip_length + self.num_tokens, embed_dim)) + self.fc = nn.Linear(embed_dim, n_classes) + self.apply(self._init_weights) + + def _init_weights(self, m): + if 
# NOTE(review): this span of the patch covers the tail of model.py and the
# whole of utils.py (the diff headers for submit.sh / test.py / utils.py fell
# inside it).  Reconstructed below as clean Python with the diff '+' markers
# stripped.  The truncated TAggregate methods that begin before this span are
# not reproduced here (their class header is outside this view).

import torch
import torch.nn as nn
import torch.nn.functional as F


class AADownsample(nn.Module):
    """Anti-aliased 1-D downsampling.

    Applies a depthwise low-pass FIR filter (symmetric triangular window,
    unit DC gain) before striding, so decimation does not alias.

    Args:
        filt_size: FIR length (odd sizes give a symmetric filter).
        stride: decimation factor.
        channels: number of input channels (one filter copy per channel).
    """

    def __init__(self, filt_size=3, stride=2, channels=None):
        super(AADownsample, self).__init__()
        self.filt_size = filt_size
        self.stride = stride
        self.channels = channels
        # Triangular window 1, 2, ..., peak, ..., 2, 1 normalised to sum 1.
        half = torch.arange(1, filt_size // 2 + 1 + 1, 1)
        win = torch.cat((half, half.flip(dims=[-1, ])[1:])).float()
        win = win / win.sum()
        # Depthwise conv weight of shape (channels, 1, filt_size).
        filt = win[None, :]
        self.register_buffer('filt', filt[None, :, :].repeat((self.channels, 1, 1)))

    def forward(self, x):
        # Reflect padding keeps edge samples plausible and the length relation
        # out_len = ceil(in_len / stride).
        x_pad = F.pad(x, (self.filt_size // 2, self.filt_size // 2), "reflect")
        return F.conv1d(x_pad, self.filt, stride=self.stride, padding=0, groups=x.shape[1])


class Down(nn.Module):
    """Conv + BN + LeakyReLU followed by anti-aliased downsampling by `d`.

    Doubles the channel count; `k` is the AA filter length.
    """

    def __init__(self, channels, d=2, k=3):
        super().__init__()
        kk = d + 1  # conv kernel sized to the downsampling factor
        self.down = nn.Sequential(
            nn.ReflectionPad1d(kk // 2),
            nn.Conv1d(channels, channels * 2, kernel_size=kk, stride=1, bias=False),
            nn.BatchNorm1d(channels * 2),
            nn.LeakyReLU(0.2, True),
            AADownsample(channels=channels * 2, stride=d, filt_size=k)
        )

    def forward(self, x):
        return self.down(x)


class SoundNetRaw(nn.Module):
    """Raw-waveform classifier: conv front end -> downsampling pyramid ->
    transformer aggregator (TAggregate) producing class logits.

    Input is expected as (batch, 1, samples); output is (batch, n_classes).
    Relies on ResBlock1dTF and TAggregate defined earlier in this file.
    """

    def __init__(self, nf=32, clip_length=None, embed_dim=128, n_layers=4, nhead=8,
                 factors=[4, 4, 4, 4], n_classes=None, dim_feedforward=512):
        super().__init__()
        # Stem: wide receptive field on the raw waveform.
        stem = [
            nn.ReflectionPad1d(3),
            nn.Conv1d(1, nf, kernel_size=7, stride=1, bias=False),
            nn.BatchNorm1d(nf),
            nn.LeakyReLU(0.2, True),
        ]
        self.start = nn.Sequential(*stem)

        # First pyramid: one Down per factor, residual block every other stage.
        layers = []
        for i, f in enumerate(factors):
            layers += [Down(channels=nf, d=f, k=f * 2 + 1)]
            nf *= 2
            if i % 2 == 0:
                layers += [ResBlock1dTF(dim=nf, dilation=1, kernel_size=15)]
        self.down = nn.Sequential(*layers)

        # Second pyramid: dilated residual stacks interleaved with 2x Downs.
        factors2 = [2, 2]
        layers = []
        for f in factors2:
            for j in range(3):
                layers += [ResBlock1dTF(dim=nf, dilation=3 ** j, kernel_size=15)]
            layers += [Down(channels=nf, d=f, k=f * 2 + 1)]
            nf *= 2
        self.down2 = nn.Sequential(*layers)

        self.project = nn.Conv1d(nf, embed_dim, 1)  # to transformer width
        self.clip_length = clip_length
        self.tf = TAggregate(embed_dim=embed_dim, clip_length=clip_length,
                             n_layers=n_layers, nhead=nhead, n_classes=n_classes,
                             dim_feedforward=dim_feedforward)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Small-variance init for convs, identity-like init for BN.
        if isinstance(m, nn.Conv1d):
            with torch.no_grad():
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.start(x)
        x = self.down(x)
        x = self.down2(x)
        x = self.project(x)
        return self.tf(x)


if __name__ == '__main__':
    pass


# ---------------------------------------------------------------------------
# utils.py — waveform augmentations
# ---------------------------------------------------------------------------
import numpy as np
import torchaudio
import random
import scipy
# BUGFIX: `import scipy` alone does not guarantee the `signal` submodule is
# loaded; RandomLPHPFilter calls scipy.signal.firwin.
import scipy.signal
from scipy.sparse import coo_matrix
from aug_helper import AugBasic


class RandomRIR(AugBasic):
    """With probability p, convolve the sample with a synthetic room impulse
    response (image-source method in a random shoebox room).

    NOTE(review): when applied this returns a tuple (sample, h) while the
    skipped path returns the bare sample — inconsistent with the other
    augmentations; left unchanged because the __main__ demo relies on the
    tuple.  AudioAugs never registers 'rir', so the pipeline is unaffected.
    """

    def __init__(self, fs, p=0.5):
        self.p = p
        self.fs = fs

    def rir(self, mic, n, r, rm, src):
        """Image-source RIR for reflection order n, reflection coeff r,
        room dimensions rm, source src and microphone mic (metres)."""
        orders = np.arange(-n, n + 1, 1).astype(np.float32)
        signs = np.power(-1, orders)
        refl = orders + 0.5 - 0.5 * signs
        # Image-source coordinates relative to the mic, per axis.
        xi = signs * src[0] + refl * rm[0] - mic[0]
        yj = signs * src[1] + refl * rm[1] - mic[1]
        zk = signs * src[2] + refl * rm[2] - mic[2]
        [i, j, k] = np.meshgrid(xi, yj, zk)
        d = np.sqrt(i ** 2 + j ** 2 + k ** 2)
        # Arrival time in samples (343 m/s speed of sound), 1-based for coo.
        t = np.round(self.fs * d / 343.) + 1
        [e, f, g] = np.meshgrid(orders, orders, orders)
        c = np.power(r, np.abs(e) + np.abs(f) + np.abs(g))  # wall attenuation
        amp = c / d  # 1/d spherical spreading
        # Accumulate all images arriving at the same sample via sparse matrix.
        cols = np.ones_like(d).reshape(-1).astype(np.int32)
        rows = t.reshape(-1).astype(np.int32)
        vals = amp.reshape(-1)
        h = coo_matrix((vals, (rows, cols))).todense()[:, 1]
        h = np.array(h).ravel()
        h = h / np.abs(h).max()
        if h.shape[0] % 2 == 0:  # force odd length so reflect-pad centres it
            h = h[:-1]
        return h

    def __call__(self, sample):
        if random.random() < self.p:
            r = 2 * np.random.rand(1) - 1  # reflection coefficient in [-1, 1)
            n = 3                          # reflection order

            # Random room up to 20 x 20 x 4 m; mic and source inside it.
            x = 20 * np.random.rand(1)
            y = 20 * np.random.rand(1)
            z = 4 * np.random.rand(1)
            rm = np.array([x, y, z])

            x = rm[0] * np.random.rand(1)
            y = rm[1] * np.random.rand(1)
            z = rm[2] * np.random.rand(1)
            mic = np.array([x, y, z])

            x = rm[0] * np.random.rand(1)
            y = rm[1] * np.random.rand(1)
            z = rm[2] * np.random.rand(1)
            src = np.array([x, y, z])

            h = self.rir(mic, n, r, rm, src)
            h = torch.from_numpy(h).float()
            sample = sample[None, None, :]
            sample = F.pad(sample, (h.shape[-1] // 2, h.shape[-1] // 2), "reflect")
            sample = F.conv1d(sample, h[None, None, :], bias=None, stride=1,
                              padding=0, dilation=1, groups=sample.shape[1])
            return sample, h
        return sample


class RandomLPHPFilter(AugBasic):
    """With probability p, apply a random FIR low-pass (cutoff in
    [0.5, 0.75) of Nyquist) or high-pass (cutoff in [0, 0.25)) filter."""

    def __init__(self, fs, p=0.5, fc_lp=None, fc_hp=None):
        self.p = p
        self.fs = fs
        self.fc_lp = fc_lp
        self.fc_hp = fc_hp
        self.num_taps = 15

    def __call__(self, sample):
        if random.random() < self.p:
            # (removed unused local `a = 0.25`)
            if random.random() < 0.5:
                fc = 0.5 + random.random() * 0.25
                filt = scipy.signal.firwin(self.num_taps, fc, window='hamming')
            else:
                fc = random.random() * 0.25
                filt = scipy.signal.firwin(self.num_taps, fc, window='hamming',
                                           pass_zero=False)
            filt = torch.from_numpy(filt).float()
            filt = filt / filt.sum()  # unity DC gain
            sample = F.pad(sample.view(1, 1, -1),
                           (filt.shape[0] // 2, filt.shape[0] // 2), mode="reflect")
            sample = F.conv1d(sample, filt.view(1, 1, -1), stride=1, groups=1)
            sample = sample.view(-1)
        return sample


class RandomTimeShift(AugBasic):
    """With probability p, shift the waveform by an integer number of samples
    (zero-padded) plus a fractional sub-sample shift implemented as a linear
    phase ramp in the frequency domain."""

    def __init__(self, p=0.5, max_time_shift=None):
        self.p = p
        self.max_time_shift = max_time_shift

    def __call__(self, sample):
        if random.random() < self.p:
            # BUGFIX: use a local default instead of mutating
            # self.max_time_shift, which pinned the shift range to the length
            # of the first sample ever seen.
            max_shift = (self.max_time_shift if self.max_time_shift is not None
                         else sample.shape[-1] // 10)
            int_d = 2 * random.randint(0, max_shift) - max_shift
            frac_d = np.round(100 * (random.random() - 0.5)) / 100
            if int_d + frac_d == 0:
                return sample
            if int_d > 0:
                pad = torch.zeros(int_d, dtype=sample.dtype)
                sample = torch.cat((pad, sample[:-int_d]), dim=-1)
            elif int_d < 0:
                pad = torch.zeros(-int_d, dtype=sample.dtype)
                sample = torch.cat((sample[-int_d:], pad), dim=-1)
            if frac_d == 0:
                return sample
            # Build the DFT frequency grid (positive then negative bins) and
            # apply e^{-j w frac_d} — a pure delay of frac_d samples.
            n = sample.shape[-1]
            dw = 2 * np.pi / n
            if n % 2 == 1:
                wp = torch.arange(0, np.pi, dw)
                wn = torch.arange(-dw, -np.pi, -dw).flip(dims=(-1,))
            else:
                wp = torch.arange(0, np.pi, dw)
                wn = torch.arange(-dw, -np.pi - dw, -dw).flip(dims=(-1,))
            w = torch.cat((wp, wn), dim=-1)
            phi = frac_d * w
            sample = torch.fft.ifft(torch.fft.fft(sample) * torch.exp(-1j * phi)).real
        return sample


class RandomTimeMasking(AugBasic):
    """With probability p, replace a random contiguous segment with
    near-silent noise (randn * 1e-6)."""

    def __init__(self, p=0.5, n_mask=None):
        self.n_mask = n_mask
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            # BUGFIX: compute the mask length locally (and only when the
            # augmentation fires) instead of mutating self.n_mask, which froze
            # the default at 5% of the first sample's length.
            n_mask = (self.n_mask if self.n_mask is not None
                      else int(0.05 * sample.shape[-1]))
            max_start = sample.size(-1) - n_mask
            idx_rand = random.randint(0, max_start)
            sample[idx_rand:idx_rand + n_mask] = torch.randn(n_mask) * 1e-6
        return sample


class RandomMuLawCompression(AugBasic):
    """With probability p, round-trip the sample through mu-law
    encode/decode, introducing quantisation distortion."""

    def __init__(self, p=0.5, n_channels=256):
        self.n_channels = n_channels
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            e = torchaudio.functional.mu_law_encoding(sample, self.n_channels)
            sample = torchaudio.functional.mu_law_decoding(e, self.n_channels)
        return sample


class RandomAmp(AugBasic):
    """With probability p, scale the whole sample (in place) by a random
    gain drawn uniformly from [low, high]."""

    def __init__(self, low, high, p=0.5):
        self.low = low
        self.high = high
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            amp = torch.FloatTensor(1).uniform_(self.low, self.high)
            sample.mul_(amp)
        return sample


class RandomFlip(AugBasic):
    """With probability p, reverse the sample in time."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            sample.data = torch.flip(sample.data, dims=[-1, ])
        return sample


class RandomAdd180Phase(AugBasic):
    """With probability p, negate the sample (180-degree phase flip)."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            sample.mul_(-1)
        return sample


class RandomAdditiveWhiteGN(AugBasic):
    """With probability p, add white Gaussian noise at a random SNR drawn
    from [min_snr_db, snr_db]."""

    def __init__(self, p=0.5, snr_db=30):
        self.snr_db = snr_db
        self.min_snr_db = 30
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            s = torch.sqrt(torch.mean(sample ** 2))  # signal RMS
            snr_db = self.min_snr_db + torch.rand(1) * (self.snr_db - self.min_snr_db)
            sgm = s * 10 ** (-snr_db / 20.)
            w = torch.randn_like(sample).mul_(sgm)
            sample.add_(w)
        return sample


class RandomAdditiveUN(AugBasic):
    """With probability p, add zero-mean uniform noise at a random SNR; the
    sqrt(3) factor converts the target RMS to the uniform half-width."""

    def __init__(self, snr_db=35, p=0.5):
        self.snr_db = snr_db
        self.min_snr_db = 30
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            s = torch.sqrt(torch.mean(sample ** 2))
            snr_db = self.min_snr_db + torch.rand(1) * (self.snr_db - self.min_snr_db)
            sgm = s * 10 ** (-snr_db / 20.) * np.sqrt(3)
            w = torch.rand_like(sample).mul_(2 * sgm).add_(-sgm)  # U(-sgm, sgm)
            sample.add_(w)
        return sample


class RandomAdditivePinkGN(AugBasic):
    """With probability p, add pink (1/f) Gaussian noise at a random SNR.
    White noise is shaped in the frequency domain by 1/sqrt(k)."""

    def __init__(self, snr_db=35, p=0.5):
        self.snr_db = snr_db
        self.min_snr_db = 30
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            s = torch.sqrt(torch.mean(sample ** 2))
            n = sample.shape[-1]
            w = torch.randn(n)
            nn_half = n // 2 + 1
            k = torch.arange(1, nn_half + 1, 1).float()
            W = torch.fft.fft(w)
            W = W[:nn_half] / k.sqrt()  # 1/f power spectrum
            # Mirror with conjugate symmetry so the ifft is (near-)real.
            W = torch.cat((W, W.flip(dims=(-1,))[1:-1].conj()), dim=-1)
            w = torch.fft.ifft(W).real
            w.add_(w.mean()).div_(w.std())  # normalise to unit std
            snr_db = self.min_snr_db + torch.rand(1) * (self.snr_db - self.min_snr_db)
            sgm = s * 10 ** (-snr_db / 20.)
            sample.add_(w.mul_(sgm))
        return sample


class RandomAdditiveVioletGN(AugBasic):
    """With probability p, add violet (f^2 power) Gaussian noise at a random
    SNR.  White noise is shaped by multiplying the spectrum by k."""

    def __init__(self, p=0.5, snr_db=35):
        self.snr_db = snr_db
        self.min_snr_db = 30
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            s = torch.sqrt(torch.mean(sample ** 2))
            n = sample.shape[-1]
            w = torch.randn(n)
            nn_half = n // 2 + 1
            k = torch.arange(1, nn_half + 1, 1).float()
            W = torch.fft.fft(w)
            W = W[:nn_half] * k
            W = torch.cat((W, W.flip(dims=(-1,))[1:-1].conj()), dim=-1)
            w = torch.fft.ifft(W).real
            w.add_(w.mean()).div_(w.std())
            snr_db = self.min_snr_db + torch.rand(1) * (self.snr_db - self.min_snr_db)
            sgm = s * 10 ** (-snr_db / 20.)
            sample.add_(w.mul_(sgm))
        return sample


class RandomAdditiveRedGN(AugBasic):
    """With probability p, add red/brown (1/f^2 power) Gaussian noise at a
    random SNR.  White noise is shaped by dividing the spectrum by k."""

    def __init__(self, p=0.5, snr_db=35):
        self.snr_db = snr_db
        self.min_snr_db = 30
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            s = torch.sqrt(torch.mean(sample ** 2))
            n = sample.shape[-1]
            w = torch.randn(n)
            nn_half = n // 2 + 1
            k = torch.arange(1, nn_half + 1, 1).float()
            W = torch.fft.fft(w)
            W = W[:nn_half] / k
            W = torch.cat((W, W.flip(dims=(-1,))[1:-1].conj()), dim=-1)
            w = torch.fft.ifft(W).real
            w.add_(w.mean()).div_(w.std())
            snr_db = self.min_snr_db + torch.rand(1) * (self.snr_db - self.min_snr_db)
            sgm = s * 10 ** (-snr_db / 20.)
            sample.add_(w.mul_(sgm))
        return sample


class RandomAdditiveBlueGN(AugBasic):
    """With probability p, add blue (f power) Gaussian noise at a random SNR.
    White noise is shaped by multiplying the spectrum by sqrt(k)."""

    def __init__(self, p=0.5, snr_db=35):
        self.snr_db = snr_db
        self.min_snr_db = 30
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            s = torch.sqrt(torch.mean(sample ** 2))
            n = sample.shape[-1]
            w = torch.randn(n)
            nn_half = n // 2 + 1
            k = torch.arange(1, nn_half + 1, 1).float()
            W = torch.fft.fft(w)
            W = W[:nn_half] * k.sqrt()
            W = torch.cat((W, W.flip(dims=(-1,))[1:-1].conj()), dim=-1)
            w = torch.fft.ifft(W).real
            w.add_(w.mean()).div_(w.std())
            snr_db = self.min_snr_db + torch.rand(1) * (self.snr_db - self.min_snr_db)
            sgm = s * 10 ** (-snr_db / 20.)
            sample.add_(w.mul_(sgm))
        return sample


class RandomFreqShift(AugBasic):
    """With probability p, shift all frequencies by a random offset via an
    STFT analysis window modulated by the shift frequency.

    fft_params (win/hop/n_fft tables) come from AugBasic.
    """

    def __init__(self, sgm, fs, p=0.5):
        super().__init__(fs=fs)
        self.sgm = sgm
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            win_idx = random.randint(0, len(self.fft_params['win_len']) - 1)
            df = self.fs / self.fft_params['win_len'][win_idx]  # bin spacing
            f_shift = torch.randn(1).mul_(self.sgm * df)
            t = torch.arange(0, self.fft_params['win_len'][win_idx], 1).float()
            # NOTE(review): real part of a complex exponential -> cosine
            # analysis window; presumably intentional, verify against
            # the original author's design.
            w = torch.real(torch.exp(-1j * 2 * np.pi * t * f_shift))
            X = torch.stft(sample,
                           win_length=self.fft_params['win_len'][win_idx],
                           hop_length=self.fft_params['hop_len'][win_idx],
                           n_fft=self.fft_params['n_fft'][win_idx],
                           window=w,
                           return_complex=True)
            sample = torch.istft(X,
                                 win_length=self.fft_params['win_len'][win_idx],
                                 hop_length=self.fft_params['hop_len'][win_idx],
                                 n_fft=self.fft_params['n_fft'][win_idx])
        return sample


class RandomAddSine(AugBasic):
    """With probability p, add a random low-frequency sine (hum-like tone)
    at a random SNR."""

    def __init__(self, fs, snr_db=35, max_freq=50, p=0.5):
        self.snr_db = snr_db
        self.max_freq = max_freq
        self.min_snr_db = 30
        self.p = p
        self.fs = fs

    def __call__(self, sample):
        if random.random() < self.p:
            # Draw the tone frequency only when the augmentation fires.
            n = torch.arange(0, sample.shape[-1], 1)
            f = self.max_freq * torch.rand(1) + 3 * torch.randn(1)
            snr_db = self.min_snr_db + torch.rand(1) * (self.snr_db - self.min_snr_db)
            t = n * 1. / self.fs
            s = (sample ** 2).mean().sqrt()
            # sqrt(2) converts target RMS to sine amplitude.
            sgm = s * np.sqrt(2) * 10 ** (-snr_db / 20.)
            b = sgm * torch.sin(2 * np.pi * f * t + torch.rand(1) * np.pi)
            sample.add_(b)
        return sample


class RandomAmpSegment(AugBasic):
    """With probability p, scale a random segment (in place) by a random
    gain drawn uniformly from [low, high]."""

    def __init__(self, low, high, max_len=None, p=0.5):
        self.low = low
        self.high = high
        self.max_len = max_len
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            # BUGFIX: local default instead of mutating self.max_len, which
            # froze the segment length at 10% of the first sample's length.
            max_len = (self.max_len if self.max_len is not None
                       else sample.shape[-1] // 10)
            idx = random.randint(0, max_len)
            amp = torch.FloatTensor(1).uniform_(self.low, self.high)
            sample[idx: idx + max_len].mul_(amp)
        return sample


class RandomPhNoise(AugBasic):
    """With probability p, perturb STFT phases with small uniform noise and
    resynthesize the waveform."""

    def __init__(self, fs, sgm=0.01, p=0.5):
        super().__init__(fs=fs)
        self.sgm = sgm
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            win_idx = random.randint(0, len(self.fft_params['win_len']) - 1)
            sgm_noise = self.sgm + 0.01 * torch.rand(1)
            X = torch.stft(sample,
                           win_length=self.fft_params['win_len'][win_idx],
                           hop_length=self.fft_params['hop_len'][win_idx],
                           n_fft=self.fft_params['n_fft'][win_idx],
                           return_complex=True)
            # BUGFIX: draw REAL phase offsets.  torch.rand_like(X) on a
            # complex tensor yields complex values, so exp(1j*w) perturbed
            # magnitude as well as phase.
            w = sgm_noise * torch.rand(X.shape)
            phn = torch.exp(1j * w)
            X.mul_(phn)
            sample = torch.istft(X,
                                 win_length=self.fft_params['win_len'][win_idx],
                                 hop_length=self.fft_params['hop_len'][win_idx],
                                 n_fft=self.fft_params['n_fft'][win_idx])
        return sample


class RandomCyclicShift(AugBasic):
    """With probability p, rotate the waveform by a random number of samples
    (content wraps around; nothing is lost)."""

    def __init__(self, max_time_shift=None, p=0.5):
        self.max_time_shift = max_time_shift
        self.p = p

    def __call__(self, sample):
        if random.random() < self.p:
            # BUGFIX: local default instead of mutating self.max_time_shift,
            # which pinned the rotation range to the first sample's length.
            max_shift = (self.max_time_shift if self.max_time_shift is not None
                         else sample.shape[-1])
            int_d = random.randint(0, max_shift - 1)
            if int_d > 0:
                sample = torch.cat((sample[-int_d:], sample[:-int_d]), dim=-1)
        return sample


class AudioAugs():
    """Compose a random augmentation chain from a list of augmentation keys.

    All signal-transforming augs are applied (in random order); at most ONE
    noise-adding aug (from noise_vec) is appended per call, so noises never
    stack.

    Args:
        k_augs: iterable of augmentation keys (see the mapping below).
        fs: sample rate, forwarded to augs that need it.
        p: per-augmentation application probability.
        snr_db: upper SNR bound for the additive-noise augs.

    Raises:
        ValueError: on an unknown augmentation key.
    """

    def __init__(self, k_augs, fs, p=0.5, snr_db=30):
        self.noise_vec = ['awgn', 'abgn', 'apgn', 'argn', 'avgn', 'aun', 'phn', 'sine']
        augs = {}
        for aug in k_augs:
            if aug == 'amp':
                augs['amp'] = RandomAmp(p=p, low=0.5, high=1.3)
            elif aug == 'flip':
                augs['flip'] = RandomFlip(p)
            elif aug == 'neg':
                augs['neg'] = RandomAdd180Phase(p)
            elif aug == 'awgn':
                augs['awgn'] = RandomAdditiveWhiteGN(p=p, snr_db=snr_db)
            elif aug == 'abgn':
                augs['abgn'] = RandomAdditiveBlueGN(p=p, snr_db=snr_db)
            elif aug == 'argn':
                augs['argn'] = RandomAdditiveRedGN(p=p, snr_db=snr_db)
            elif aug == 'avgn':
                augs['avgn'] = RandomAdditiveVioletGN(p=p, snr_db=snr_db)
            elif aug == 'apgn':
                augs['apgn'] = RandomAdditivePinkGN(p=p, snr_db=snr_db)
            elif aug == 'mulaw':
                augs['mulaw'] = RandomMuLawCompression(p=p, n_channels=256)
            elif aug == 'tmask':
                augs['tmask'] = RandomTimeMasking(p=p, n_mask=int(0.1 * fs))
            elif aug == 'tshift':
                augs['tshift'] = RandomTimeShift(p=p, max_time_shift=int(0.1 * fs))
            elif aug == 'sine':
                augs['sine'] = RandomAddSine(p=p, fs=fs)
            elif aug == 'cycshift':
                augs['cycshift'] = RandomCyclicShift(p=p, max_time_shift=None)
            elif aug == 'ampsegment':
                augs['ampsegment'] = RandomAmpSegment(p=p, low=0.5, high=1.3,
                                                      max_len=int(0.1 * fs))
            elif aug == 'aun':
                augs['aun'] = RandomAdditiveUN(p=p, snr_db=snr_db)
            elif aug == 'phn':
                augs['phn'] = RandomPhNoise(p=p, fs=fs, sgm=0.01)
            elif aug == 'fshift':
                augs['fshift'] = RandomFreqShift(fs=fs, sgm=1, p=p)
            else:
                raise ValueError("{} not supported".format(aug))
        self.augs = augs
        self.augs_signal = [a for a in augs if a not in self.noise_vec]
        self.augs_noise = [a for a in augs if a in self.noise_vec]

    def __call__(self, sample, **kwargs):
        augs = self.augs_signal.copy()
        augs_noise = self.augs_noise
        random.shuffle(augs)
        if len(augs_noise) > 0:
            # Pick exactly one noise source per call.
            i = random.randint(0, len(augs_noise) - 1)
            augs.append(augs_noise[i])
        for aug in augs:
            sample = self.augs[aug](sample)
        return sample


if __name__ == "__main__":
    # Quick visual sanity check: convolve an impulse with a random RIR.
    r = RandomRIR(fs=22050, p=1)
    x = torch.zeros(22050)
    x[0:100] = 1
    y = r(x)
    import matplotlib.pyplot as plt
    plt.plot(x)
    plt.plot(y[0].view(-1), 'r')
    plt.show()