Arnab Das committed
Commit b30e39a
Parent: 3f67209

AASIST model added.

Files changed (4):
  1. app.py +5 -3
  2. models.py +231 -0
  3. orig_aasist_epoch_1.pth +3 -0
  4. process_data.py +4 -1
app.py CHANGED
@@ -11,7 +11,9 @@ model_master = {
                "model_checkpoint": "ssl_aasist_epoch_7.pth"},
     "AASIST": {"eer_threshold": 1.8018419742584229,
                "data_process_func": "process_assist_input",
-               "note": "This model is trained on ASVSpoof 2024 training data."}
+               "note": "This model is trained on ASVSpoof 2024 training data.",
+               "model_class": "AASIST_Model",
+               "model_checkpoint": "orig_aasist_epoch_1.pth"}
 }
 
 model = MOD.Model(None, "cpu")
@@ -21,8 +23,6 @@ loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
 
 
 def process(file, type):
-    if type == "AASIST":
-        return "Model AASIST is not yet implemented."
     global model
     global loaded_model
     inp = getattr(PD, model_master[type]["data_process_func"])(file)
@@ -54,6 +54,8 @@ file_proc = gr.Interface(
     examples=[
         ["./bonafide.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
        ["./fake.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
+        ["./bonafide.flac", "AASIST"],
+        ["./fake.flac", "AASIST"],
     ],
     cache_examples=True,
     allow_flagging="never",
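
With `model_class` and `model_checkpoint` now in the registry, `process` no longer needs the "not yet implemented" early return: the AASIST entry can be resolved and loaded the same way as the SSL model. The reload logic itself is not shown in these hunks, so the following is only a sketch of how the new keys are presumably consumed; `load_model_for` is a hypothetical helper, and `MOD` is the `models.py` alias used by `app.py`.

import torch
import models as MOD

def load_model_for(type, model_master, device="cpu"):
    # Hypothetical helper: resolve the class named in the registry and
    # restore its checkpoint weights (both keys added in this commit).
    cls = getattr(MOD, model_master[type]["model_class"])
    model = cls(None, device)
    state = torch.load(model_master[type]["model_checkpoint"],
                       map_location=device)
    model.load_state_dict(state)
    return model.eval()

# e.g. inside process(), when `type` differs from `loaded_model`:
#     model = load_model_for(type, model_master)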
models.py CHANGED
@@ -1,6 +1,9 @@
 import torch
+import random
 import fairseq
+import numpy as np
 import torch.nn as nn
+from torch import Tensor
 from typing import Union
 import torch.nn.functional as F
 
@@ -633,3 +636,231 @@ class Model(nn.Module):
         output = self.out_layer(last_hidden)
 
         return output
+
+
+class CONV(nn.Module):
+    @staticmethod
+    def to_mel(hz):
+        return 2595 * np.log10(1 + hz / 700)
+
+    @staticmethod
+    def to_hz(mel):
+        return 700 * (10**(mel / 2595) - 1)
+
+    def __init__(self,
+                 out_channels,
+                 kernel_size,
+                 sample_rate=16000,
+                 in_channels=1,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 bias=False,
+                 groups=1,
+                 mask=False):
+        super().__init__()
+        if in_channels != 1:
+
+            msg = "SincConv only supports one input channel (here, in_channels = {%i})" % (
+                in_channels)
+            raise ValueError(msg)
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.sample_rate = sample_rate
+
+        # Forcing the filters to be odd (i.e., perfectly symmetric)
+        if kernel_size % 2 == 0:
+            self.kernel_size = self.kernel_size + 1
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.mask = mask
+        if bias:
+            raise ValueError('SincConv does not support bias.')
+        if groups > 1:
+            raise ValueError('SincConv does not support groups.')
+
+        NFFT = 512
+        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
+        fmel = self.to_mel(f)
+        fmelmax = np.max(fmel)
+        fmelmin = np.min(fmel)
+        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
+        filbandwidthsf = self.to_hz(filbandwidthsmel)
+
+        self.mel = filbandwidthsf
+        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
+                                  (self.kernel_size - 1) / 2 + 1)
+        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
+        for i in range(len(self.mel) - 1):
+            fmin = self.mel[i]
+            fmax = self.mel[i + 1]
+            hHigh = (2 * fmax / self.sample_rate) * \
+                np.sinc(2 * fmax * self.hsupp / self.sample_rate)
+            hLow = (2 * fmin / self.sample_rate) * \
+                np.sinc(2 * fmin * self.hsupp / self.sample_rate)
+            hideal = hHigh - hLow
+
+            self.band_pass[i, :] = Tensor(np.hamming(
+                self.kernel_size)) * Tensor(hideal)
+
+    def forward(self, x, mask=False):
+        band_pass_filter = self.band_pass.clone().to(x.device)
+        if mask:
+            A = np.random.uniform(0, 20)
+            A = int(A)
+            A0 = random.randint(0, band_pass_filter.shape[0] - A)
+            band_pass_filter[A0:A0 + A, :] = 0
+        else:
+            band_pass_filter = band_pass_filter
+
+        self.filters = (band_pass_filter).view(self.out_channels, 1,
+                                               self.kernel_size)
+
+        return F.conv1d(x,
+                        self.filters,
+                        stride=self.stride,
+                        padding=self.padding,
+                        dilation=self.dilation,
+                        bias=None,
+                        groups=1)
+
+
+class AASIST_Model(nn.Module):
+    def __init__(self, args, device):
+        super().__init__()
+
+        filts = [70, [1, 32], [32, 32], [32, 64], [64, 64]]
+        gat_dims = [64, 32]
+        pool_ratios = [0.5, 0.7, 0.5, 0.5]
+        temperatures = [2.0, 2.0, 100.0, 100.0]
+
+        self.conv_time = CONV(out_channels=filts[0],
+                              kernel_size=128,
+                              in_channels=1)
+        self.first_bn = nn.BatchNorm2d(num_features=1)
+
+        self.drop = nn.Dropout(0.5, inplace=True)
+        self.drop_way = nn.Dropout(0.2, inplace=True)
+        self.selu = nn.SELU(inplace=True)
+
+        self.encoder = nn.Sequential(
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[1], first=True)),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[2])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[3])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[4])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[4])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[4])))
+
+        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
+        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+
+        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
+                                               gat_dims[0],
+                                               temperature=temperatures[0])
+        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
+                                               gat_dims[0],
+                                               temperature=temperatures[1])
+
+        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
+            gat_dims[0], gat_dims[1], temperature=temperatures[2])
+        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
+            gat_dims[1], gat_dims[1], temperature=temperatures[2])
+
+        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
+            gat_dims[0], gat_dims[1], temperature=temperatures[2])
+
+        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
+            gat_dims[1], gat_dims[1], temperature=temperatures[2])
+
+        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
+        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
+        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+        self.out_layer = nn.Linear(5 * gat_dims[1], 2)
+
+    def forward(self, x, Freq_aug=False):
+
+        x = x.unsqueeze(1)
+        x = self.conv_time(x, mask=Freq_aug)
+        x = x.unsqueeze(dim=1)
+        x = F.max_pool2d(torch.abs(x), (3, 3))
+        x = self.first_bn(x)
+        x = self.selu(x)
+
+        # get embeddings using encoder
+        # (#bs, #filt, #spec, #seq)
+        e = self.encoder(x)
+
+        # spectral GAT (GAT-S)
+        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
+        e_S = e_S.transpose(1, 2) + self.pos_S
+
+        gat_S = self.GAT_layer_S(e_S)
+        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
+
+        # temporal GAT (GAT-T)
+        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
+        e_T = e_T.transpose(1, 2)
+
+        gat_T = self.GAT_layer_T(e_T)
+        out_T = self.pool_T(gat_T)
+
+        # learnable master node
+        master1 = self.master1.expand(x.size(0), -1, -1)
+        master2 = self.master2.expand(x.size(0), -1, -1)
+
+        # inference 1
+        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
+            out_T, out_S, master=self.master1)
+
+        out_S1 = self.pool_hS1(out_S1)
+        out_T1 = self.pool_hT1(out_T1)
+
+        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
+            out_T1, out_S1, master=master1)
+        out_T1 = out_T1 + out_T_aug
+        out_S1 = out_S1 + out_S_aug
+        master1 = master1 + master_aug
+
+        # inference 2
+        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
+            out_T, out_S, master=self.master2)
+        out_S2 = self.pool_hS2(out_S2)
+        out_T2 = self.pool_hT2(out_T2)
+
+        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
+            out_T2, out_S2, master=master2)
+        out_T2 = out_T2 + out_T_aug
+        out_S2 = out_S2 + out_S_aug
+        master2 = master2 + master_aug
+
+        out_T1 = self.drop_way(out_T1)
+        out_T2 = self.drop_way(out_T2)
+        out_S1 = self.drop_way(out_S1)
+        out_S2 = self.drop_way(out_S2)
+        master1 = self.drop_way(master1)
+        master2 = self.drop_way(master2)
+
+        out_T = torch.max(out_T1, out_T2)
+        out_S = torch.max(out_S1, out_S2)
+        master = torch.max(master1, master2)
+
+        T_max, _ = torch.max(torch.abs(out_T), dim=1)
+        T_avg = torch.mean(out_T, dim=1)
+
+        S_max, _ = torch.max(torch.abs(out_S), dim=1)
+        S_avg = torch.mean(out_S, dim=1)
+
+        last_hidden = torch.cat(
+            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
+
+        last_hidden = self.drop(last_hidden)
+        output = self.out_layer(last_hidden)
+
+        return last_hidden, output
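
The new `CONV` layer is a fixed sinc-filter front end: the band edges are spaced uniformly on the mel scale (mel = 2595 * log10(1 + f/700), inverted by `to_hz`), and each of the 70 output channels is a Hamming-windowed ideal band-pass built as the difference of two sinc low-passes. With `mask=True` (driven by `Freq_aug` in `AASIST_Model.forward`), a random block of up to 20 adjacent filters is zeroed, acting as frequency masking at train time. Note also that `AASIST_Model.forward` returns `(last_hidden, output)`, unlike the existing `Model`, which returns only `output`. A minimal shape check, assuming the usual AASIST input length of 64,600 samples (~4 s at 16 kHz, an assumption, since the padding code is not shown in this commit):

import torch
from models import CONV

conv_time = CONV(out_channels=70, kernel_size=128)  # kernel is forced odd -> 129
wave = torch.randn(1, 1, 64600)     # (batch, channel, samples), assumed length
out = conv_time(wave)               # valid conv: 64600 - 129 + 1 samples
print(out.shape)                    # torch.Size([1, 70, 64472])

# Train-time frequency masking: zeroes a random block of adjacent filters.
aug = conv_time(wave, mask=True)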
orig_aasist_epoch_1.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f26ad87bca3d47e97ecfeac6fee6fcae93f62673a484d447c081d45911e3a027
+size 1276136
process_data.py CHANGED
@@ -17,4 +17,7 @@ def process_ssl_assist_input(filepath):
     X_pad = pad(X)
     x_inp = Tensor(X_pad)
     x_inp = x_inp.unsqueeze(0)
-    return x_inp
+    return x_inp
+
+def process_assist_input(filepath):
+    return process_ssl_assist_input(filepath)
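
`process_assist_input` simply delegates to the SSL preprocessing, so both models receive the same fixed-length waveform tensor. The top of `process_data.py` sits outside this hunk; the sketch below is a plausible reconstruction following the standard AASIST data prep, and the `librosa.load` call and the `max_len=64600` value are assumptions, not part of the diff.

import librosa
import numpy as np
from torch import Tensor

def pad(x, max_len=64600):
    # Assumed helper (defined above the hunk): repeat-pad short clips
    # to max_len samples (~4 s at 16 kHz) and crop longer ones.
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    num_repeats = int(max_len / x_len) + 1
    return np.tile(x, num_repeats)[:max_len]

def process_ssl_assist_input(filepath):
    X, _ = librosa.load(filepath, sr=16000)  # mono waveform at 16 kHz (assumed)
    X_pad = pad(X)
    x_inp = Tensor(X_pad)
    return x_inp.unsqueeze(0)                # add batch dimension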