Commit: b30e39a
Parent(s): 3f67209
Author: Arnab Das
Message: AASIST model added.

Files changed:
- app.py +5 -3
- models.py +231 -0
- orig_aasist_epoch_1.pth +3 -0
- process_data.py +4 -1
app.py
CHANGED
@@ -11,7 +11,9 @@ model_master = {
                        "model_checkpoint": "ssl_aasist_epoch_7.pth"},
         "AASIST": {"eer_threshold": 1.8018419742584229,
                    "data_process_func": "process_assist_input",
-                   "note": "This model is trained on ASVSpoof 2024 training data."}
+                   "note": "This model is trained on ASVSpoof 2024 training data.",
+                   "model_class": "AASIST_Model",
+                   "model_checkpoint": "orig_aasist_epoch_1.pth"}
 }
 
 model = MOD.Model(None, "cpu")
@@ -21,8 +23,6 @@ loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
 
 
 def process(file, type):
-    if type == "AASIST":
-        return "Model AASIST is not yet implemented."
     global model
     global loaded_model
    inp = getattr(PD, model_master[type]["data_process_func"])(file)
@@ -54,6 +54,8 @@ file_proc = gr.Interface(
     examples=[
         ["./bonafide.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
         ["./fake.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
+        ["./bonafide.flac", "AASIST"],
+        ["./fake.flac", "AASIST"],
     ],
     cache_examples=True,
     allow_flagging="never",
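With the "not yet implemented" stub removed, selecting "AASIST" in the UI flows through the same dynamic dispatch as the SSL model: the model_master entry now names both a model class and a checkpoint. A minimal sketch of how such an entry could drive loading; the load_model helper is hypothetical (not part of this commit), and it assumes each model class takes (args, device) as MOD.Model(None, "cpu") above suggests, and that the .pth file holds a plain state_dict.

import torch
import models as MOD

def load_model(entry, device="cpu"):
    # Hypothetical helper: resolve the class named in model_master,
    # e.g. MOD.AASIST_Model, then load its checkpoint onto the device.
    model_cls = getattr(MOD, entry["model_class"])
    model = model_cls(None, device)
    state = torch.load(entry["model_checkpoint"], map_location=device)
    model.load_state_dict(state)  # assumes the .pth stores a bare state_dict
    return model.eval()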
models.py
CHANGED
@@ -1,6 +1,9 @@
 import torch
+import random
 import fairseq
+import numpy as np
 import torch.nn as nn
+from torch import Tensor
 from typing import Union
 import torch.nn.functional as F
 
@@ -633,3 +636,231 @@ class Model(nn.Module):
         output = self.out_layer(last_hidden)
 
         return output
+
+
+class CONV(nn.Module):
+    @staticmethod
+    def to_mel(hz):
+        return 2595 * np.log10(1 + hz / 700)
+
+    @staticmethod
+    def to_hz(mel):
+        return 700 * (10**(mel / 2595) - 1)
+
+    def __init__(self,
+                 out_channels,
+                 kernel_size,
+                 sample_rate=16000,
+                 in_channels=1,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 bias=False,
+                 groups=1,
+                 mask=False):
+        super().__init__()
+        if in_channels != 1:
+            msg = "SincConv only supports one input channel (here, in_channels = {%i})" % (
+                in_channels)
+            raise ValueError(msg)
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.sample_rate = sample_rate
+
+        # Force the filters to be odd-length (i.e., perfectly symmetric)
+        if kernel_size % 2 == 0:
+            self.kernel_size = self.kernel_size + 1
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.mask = mask
+        if bias:
+            raise ValueError('SincConv does not support bias.')
+        if groups > 1:
+            raise ValueError('SincConv does not support groups.')
+
+        # Mel-spaced band edges for the fixed sinc band-pass filterbank
+        NFFT = 512
+        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
+        fmel = self.to_mel(f)
+        fmelmax = np.max(fmel)
+        fmelmin = np.min(fmel)
+        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
+        filbandwidthsf = self.to_hz(filbandwidthsmel)
+
+        self.mel = filbandwidthsf
+        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
+                                  (self.kernel_size - 1) / 2 + 1)
+        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
+        for i in range(len(self.mel) - 1):
+            fmin = self.mel[i]
+            fmax = self.mel[i + 1]
+            hHigh = (2 * fmax / self.sample_rate) * \
+                np.sinc(2 * fmax * self.hsupp / self.sample_rate)
+            hLow = (2 * fmin / self.sample_rate) * \
+                np.sinc(2 * fmin * self.hsupp / self.sample_rate)
+            hideal = hHigh - hLow
+
+            self.band_pass[i, :] = Tensor(np.hamming(
+                self.kernel_size)) * Tensor(hideal)
+
+    def forward(self, x, mask=False):
+        band_pass_filter = self.band_pass.clone().to(x.device)
+        if mask:
+            # Frequency masking: zero out a random contiguous block of filters
+            A = int(np.random.uniform(0, 20))
+            A0 = random.randint(0, band_pass_filter.shape[0] - A)
+            band_pass_filter[A0:A0 + A, :] = 0
+
+        self.filters = band_pass_filter.view(self.out_channels, 1,
+                                             self.kernel_size)
+
+        return F.conv1d(x,
+                        self.filters,
+                        stride=self.stride,
+                        padding=self.padding,
+                        dilation=self.dilation,
+                        bias=None,
+                        groups=1)
+
+
+class AASIST_Model(nn.Module):
+    def __init__(self, args, device):
+        super().__init__()
+
+        filts = [70, [1, 32], [32, 32], [32, 64], [64, 64]]
+        gat_dims = [64, 32]
+        pool_ratios = [0.5, 0.7, 0.5, 0.5]
+        temperatures = [2.0, 2.0, 100.0, 100.0]
+
+        self.conv_time = CONV(out_channels=filts[0],
+                              kernel_size=128,
+                              in_channels=1)
+        self.first_bn = nn.BatchNorm2d(num_features=1)
+
+        self.drop = nn.Dropout(0.5, inplace=True)
+        self.drop_way = nn.Dropout(0.2, inplace=True)
+        self.selu = nn.SELU(inplace=True)
+
+        self.encoder = nn.Sequential(
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[1], first=True)),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[2])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[3])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[4])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[4])),
+            nn.Sequential(Residual_block_aasist(nb_filts=filts[4])))
+
+        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
+        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+
+        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
+                                               gat_dims[0],
+                                               temperature=temperatures[0])
+        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
+                                               gat_dims[0],
+                                               temperature=temperatures[1])
+
+        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
+            gat_dims[0], gat_dims[1], temperature=temperatures[2])
+        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
+            gat_dims[1], gat_dims[1], temperature=temperatures[2])
+
+        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
+            gat_dims[0], gat_dims[1], temperature=temperatures[2])
+
+        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
+            gat_dims[1], gat_dims[1], temperature=temperatures[2])
+
+        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
+        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
+        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+        self.out_layer = nn.Linear(5 * gat_dims[1], 2)
+
+    def forward(self, x, Freq_aug=False):
+        x = x.unsqueeze(1)
+        x = self.conv_time(x, mask=Freq_aug)
+        x = x.unsqueeze(dim=1)
+        x = F.max_pool2d(torch.abs(x), (3, 3))
+        x = self.first_bn(x)
+        x = self.selu(x)
+
+        # get embeddings using encoder
+        # (#bs, #filt, #spec, #seq)
+        e = self.encoder(x)
+
+        # spectral GAT (GAT-S)
+        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
+        e_S = e_S.transpose(1, 2) + self.pos_S
+
+        gat_S = self.GAT_layer_S(e_S)
+        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
+
+        # temporal GAT (GAT-T)
+        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
+        e_T = e_T.transpose(1, 2)
+
+        gat_T = self.GAT_layer_T(e_T)
+        out_T = self.pool_T(gat_T)
+
+        # learnable master node
+        master1 = self.master1.expand(x.size(0), -1, -1)
+        master2 = self.master2.expand(x.size(0), -1, -1)
+
+        # inference 1
+        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
+            out_T, out_S, master=self.master1)
+
+        out_S1 = self.pool_hS1(out_S1)
+        out_T1 = self.pool_hT1(out_T1)
+
+        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
+            out_T1, out_S1, master=master1)
+        out_T1 = out_T1 + out_T_aug
+        out_S1 = out_S1 + out_S_aug
+        master1 = master1 + master_aug
+
+        # inference 2
+        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
+            out_T, out_S, master=self.master2)
+        out_S2 = self.pool_hS2(out_S2)
+        out_T2 = self.pool_hT2(out_T2)
+
+        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
+            out_T2, out_S2, master=master2)
+        out_T2 = out_T2 + out_T_aug
+        out_S2 = out_S2 + out_S_aug
+        master2 = master2 + master_aug
+
+        # dropout on each branch, then element-wise max across the two branches
+        out_T1 = self.drop_way(out_T1)
+        out_T2 = self.drop_way(out_T2)
+        out_S1 = self.drop_way(out_S1)
+        out_S2 = self.drop_way(out_S2)
+        master1 = self.drop_way(master1)
+        master2 = self.drop_way(master2)
+
+        out_T = torch.max(out_T1, out_T2)
+        out_S = torch.max(out_S1, out_S2)
+        master = torch.max(master1, master2)
+
+        # readout: max + mean over nodes, concatenated with the master node
+        T_max, _ = torch.max(torch.abs(out_T), dim=1)
+        T_avg = torch.mean(out_T, dim=1)
+
+        S_max, _ = torch.max(torch.abs(out_S), dim=1)
+        S_avg = torch.mean(out_S, dim=1)
+
+        last_hidden = torch.cat(
+            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
+
+        last_hidden = self.drop(last_hidden)
+        output = self.out_layer(last_hidden)
+
+        return last_hidden, output
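The new CONV layer is a fixed, mel-spaced sinc band-pass filterbank (a SincConv front end), and AASIST_Model wires it into the spectro-temporal graph-attention backbone, reusing the Residual_block_aasist, GraphAttentionLayer, HtrgGraphAttentionLayer, and GraphPool classes already defined earlier in models.py. A minimal smoke-test sketch, assuming those existing classes behave as in the reference AASIST implementation and that the input is about 4 s of 16 kHz audio (64600 samples, the length implied by pos_S's 23 spectral nodes); this snippet is illustrative and not part of the commit:

import torch
from models import AASIST_Model

model = AASIST_Model(None, "cpu")  # args is unused; device kept for symmetry
model.eval()                       # BatchNorm needs eval mode for batch size 1

wav = torch.randn(1, 64600)        # (batch, samples) of raw 16 kHz waveform
with torch.no_grad():
    last_hidden, logits = model(wav)  # forward returns (embedding, logits)

print(last_hidden.shape)  # expected (1, 160), i.e. 5 * gat_dims[1]
print(logits.shape)       # expected (1, 2): bona fide vs. spoof logits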
orig_aasist_epoch_1.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f26ad87bca3d47e97ecfeac6fee6fcae93f62673a484d447c081d45911e3a027
+size 1276136
process_data.py
CHANGED
@@ -17,4 +17,7 @@ def process_ssl_assist_input(filepath):
     X_pad = pad(X)
     x_inp = Tensor(X_pad)
     x_inp = x_inp.unsqueeze(0)
-    return x_inp
+    return x_inp
+
+def process_assist_input(filepath):
+    return process_ssl_assist_input(filepath)
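Since the classic AASIST front end consumes the same padded waveform tensor as the SSL variant, process_assist_input can simply alias process_ssl_assist_input. A quick illustrative check, using the example file bundled with the Space:

import process_data as PD

x = PD.process_assist_input("./bonafide.flac")
print(x.shape)  # a (1, num_samples) float tensor, ready to pass to the model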