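"""Gradio app for two-speaker audio source separation.

Wraps a SpeechBrain SepFormer (sepformer-wsj03mix) model whose final
mask-net layers were fine-tuned, and serves it through a web interface.
"""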
import gradio as gr
import torch
import soundfile as sf
from speechbrain.inference.separation import SepformerSeparation as separator
# defining the model wrapper class
class SepformerFineTune(torch.nn.Module):
    def __init__(self, model):
        super(SepformerFineTune, self).__init__()
        self.model = model
        # disable gradient computation for all parameters
        for param in self.model.parameters():
            param.requires_grad = False
        # re-enable gradient computation for the last layers only
        named_layers = dict(model.named_modules())
        for name, layer in named_layers.items():
            if name in ("mods.masknet.output.0", "mods.masknet.output_gate"):
                for param in layer.parameters():
                    param.requires_grad = True

    def forward(self, mix):
        est_sources = self.model.separate_batch(mix)
        # NOTE: working with 2 sources ONLY
        return est_sources[:, :, 0], est_sources[:, :, 1]
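
# Gradio application wrapping the fine-tuned separator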
class SourceSeparationApp:
    def __init__(self, model_path, device="cpu"):
        # set the device first: load_model needs it
        self.device = device
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        model = separator.from_hparams(
            source="speechbrain/sepformer-wsj03mix",
            savedir="pretrained_models/sepformer-wsj03mix",
            run_opts={"device": self.device},
        )
        # load the fine-tuned checkpoint on CPU, then restore it into the wrapper
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
        fine_tuned_model = SepformerFineTune(model)
        fine_tuned_model.load_state_dict(checkpoint["model"])
        return fine_tuned_model

    def separate_sources(self, audio_file):
        # gr.Audio passes a (sample_rate, data) tuple in numpy mode
        sr, input_audio_tensor = audio_file[0], audio_file[1]
        if self.model is None:
            return "Error: Model not loaded.", None, None
        # convert the input audio to a float PyTorch tensor with a batch dimension
        input_audio_tensor = torch.tensor(input_audio_tensor, dtype=torch.float).unsqueeze(0)
        input_audio_tensor = input_audio_tensor.to(self.device)
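        # NOTE: Gradio delivers integer PCM samples by default; if the model was
        # trained on waveforms normalized to [-1, 1], rescaling may be needed here.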
        # run source separation with the loaded model
        self.model.to(self.device)
        self.model.eval()
        with torch.inference_mode():
            source1, source2 = self.model(input_audio_tensor)
        # save the separated sources to disk
        sf.write("source1.wav", source1.squeeze().cpu().numpy(), sr)
        sf.write("source2.wav", source2.squeeze().cpu().numpy(), sr)
        return "Separation completed", "source1.wav", "source2.wav"
    def run(self):
        audio_input = gr.Audio(label="Upload or record audio")
        output_text = gr.Label(label="Status:")
        audio_output1 = gr.Audio(label="Source 1", type="filepath")
        audio_output2 = gr.Audio(label="Source 2", type="filepath")
        gr.Interface(
            fn=self.separate_sources,
            inputs=audio_input,
            outputs=[output_text, audio_output1, audio_output2],
            title="Audio Source Separation",
            description="Separate sources from a mixed audio signal.",
            allow_flagging="never",
        ).launch()

if __name__ == "__main__":
    # use a device string so it is valid for both torch and SpeechBrain run_opts
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = "fine_tuned_sepformer-wsj03mix-7sec.ckpt"  # Replace with your model path
    app = SourceSeparationApp(model_path, device=device)
    app.run()