File size: 3,928 Bytes
6aca51e
 
 
 
5283045
6aca51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a87778a
6aca51e
 
 
 
76d593a
6aca51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5283045
6aca51e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
import soundfile as sf
from speechbrain.inference.separation import SepformerSeparation as separator
import os


# Defining the model class
class SepformerFineTune(torch.nn.Module):
    """Wrapper that freezes a separation model except for its mask-network output layers.

    On construction every parameter of ``model`` has ``requires_grad`` turned
    off, after which the parameters of the sub-modules named
    ``mods.masknet.output.0`` and ``mods.masknet.output_gate`` are switched
    back on. Sub-modules with those names that do not exist are silently
    skipped.

    Args:
        model: Module exposing ``parameters()``, ``named_modules()`` and a
            ``separate_batch(mix)`` method.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

        # Step 1: freeze the whole network.
        for weight in self.model.parameters():
            weight.requires_grad = False

        # Step 2: re-enable gradients only for the final mask-network layers.
        trainable_modules = ("mods.masknet.output.0", "mods.masknet.output_gate")
        for module_name, module in model.named_modules():
            if module_name not in trainable_modules:
                continue
            for weight in module.parameters():
                weight.requires_grad = True

    def forward(self, mix):
        """Separate ``mix`` and return the first two estimated sources as a tuple.

        The result of ``separate_batch`` is indexed on its third axis, so it is
        presumably laid out as (batch, time, sources) -- TODO confirm.
        """
        est_sources = self.model.separate_batch(mix)
        # NOTE: Working with 2 sources ONLY
        first_source = est_sources[:, :, 0]
        second_source = est_sources[:, :, 1]
        return first_source, second_source


class SourceSeparationApp:
    """Gradio application that splits a mixed recording into two sources.

    The app loads the pretrained ``speechbrain/sepformer-wsj03mix`` separator,
    wraps it in ``SepformerFineTune``, restores fine-tuned weights from a
    checkpoint and serves a single-input / three-output Gradio interface.
    """

    def __init__(self, model_path, device="cpu"):
        """Store the target device and load the fine-tuned model.

        Args:
            model_path: Path of the checkpoint file passed to ``torch.load``.
            device: Device (string or ``torch.device``) used for inference.
        """
        # The device has to be stored first: load_model() reads self.device.
        self.device = device
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        """Build the separator and restore the fine-tuned weights.

        Args:
            model_path: Path of a checkpoint whose ``"model"`` entry holds the
                state dict of a ``SepformerFineTune`` instance.

        Returns:
            The ``SepformerFineTune`` module with the checkpoint weights loaded.
        """
        model = separator.from_hparams(
            source="speechbrain/sepformer-wsj03mix",
            savedir='pretrained_models/sepformer-wsj03mix',
            # Use the instance's device instead of a module-level global so the
            # class also works when this file is imported rather than executed.
            run_opts={"device": str(self.device)},
        )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
        fine_tuned_model = SepformerFineTune(model)
        fine_tuned_model.load_state_dict(checkpoint["model"])
        return fine_tuned_model

    def separate_sources(self, audio_file):
        """Separate one recording and write the two sources to WAV files.

        Args:
            audio_file: ``(sample_rate, samples)`` pair as produced by the
                ``gr.Audio`` input, or ``None`` when nothing was provided.

        Returns:
            A 3-tuple ``(status, path_source1, path_source2)`` matching the
            three interface outputs. On failure both paths are ``None``.
        """
        if self.model is None:
            return "Error: Model not loaded.", None, None
        if audio_file is None:
            return "Error: No audio provided.", None, None

        sr, samples = audio_file[0], audio_file[1]

        # NOTE(review): samples are forwarded at their original amplitude scale
        # and sample rate (no int-PCM rescaling, no resampling) -- confirm this
        # matches what the model was fine-tuned on.
        waveform = torch.tensor(samples, dtype=torch.float)
        if waveform.numel() == 0:
            return "Error: No audio provided.", None, None
        if waveform.dim() == 2:
            # Multi-channel input arrives as (samples, channels); the model is
            # fed a single waveform per batch item, so downmix to mono.
            waveform = waveform.mean(dim=1)
        input_audio_tensor = waveform.unsqueeze(0).to(self.device)

        # Source separation using the loaded model
        self.model.to(self.device)
        self.model.eval()
        with torch.inference_mode():
            source1, source2 = self.model(input_audio_tensor)

        # Save separated sources
        sf.write("source1.wav", source1.squeeze().cpu().numpy(), sr)
        sf.write("source2.wav", source2.squeeze().cpu().numpy(), sr)

        return "Separation completed", "source1.wav", "source2.wav"

    def run(self):
        """Build the Gradio interface and launch it."""
        # A missing or empty examples directory simply means "no examples"
        # instead of preventing the app from starting.
        examples_dir = "examples"
        if os.path.isdir(examples_dir):
            examples = [
                [os.path.join(examples_dir, example)]
                for example in sorted(os.listdir(examples_dir))
            ]
        else:
            examples = []

        audio_input = gr.Audio(label="Upload or record audio")
        output_text = gr.Label(label="Status:")
        audio_output1 = gr.Audio(label="Source 1", type="filepath",)
        audio_output2 = gr.Audio(label="Source 2", type="filepath",)
        gr.Interface(
            fn=self.separate_sources,
            inputs=audio_input,
            outputs=[output_text, audio_output1, audio_output2],
            title="Audio Source Separation",
            description="Separate sources from a mixed audio signal.",
            examples=examples or None,
            # "never" is the documented string form of disabling flagging.
            allow_flagging="never",
        ).launch()


if __name__ == "__main__":
    # Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Checkpoint holding the fine-tuned weights.
    model_path = "fine_tuned_sepformer-wsj03mix-7sec.ckpt"  # Replace with your model path

    app = SourceSeparationApp(model_path, device=device)
    app.run()