import torch import torchaudio import logging import os from demucs.pretrained import get_model from demucs.apply import apply_model from typing import Tuple logger = logging.getLogger(__name__) class DemucsProcessor: def __init__(self, model_name="htdemucs"): try: self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {self.device}") self.model = get_model(model_name) print(f"Model name: {model_name}") print(f"Model sources: {self.model.sources}") # This will show available stems print(f"Model sample rate: {self.model.samplerate}") self.model.to(self.device) print(f"Model loaded successfully on {self.device}") except Exception as e: print(f"Error initializing model: {str(e)}") raise def separate_stems(self, audio_path: str, progress=None) -> Tuple[torch.Tensor, int]: try: if progress: progress(0.1, "Loading audio file...") # Load audio waveform, sample_rate = torchaudio.load(audio_path) print(f"Audio loaded - Shape: {waveform.shape}") if progress: progress(0.3, "Processing stems...") # Input validation and logging: Check waveform dimensions if waveform.dim() not in (1, 2): raise ValueError(f"Invalid waveform dimensions: Expected 1D or 2D, got {waveform.dim()}") # Handle mono input by duplicating to stereo if waveform.dim() == 1: waveform = waveform.unsqueeze(0) if waveform.shape[0] == 1: waveform = waveform.repeat(2, 1) print("Converted mono to stereo by duplication") # Ensure 3D tensor for apply_model (batch, channels, time) waveform = waveform.unsqueeze(0) print(f"Waveform shape before apply_model: {waveform.shape}") # Process with torch.no_grad(): sources = apply_model(self.model, waveform.to(self.device)) print(f"Sources shape after processing: {sources.shape}") print(f"Available stems: {self.model.sources}") if progress: progress(0.8, "Finalizing separation...") return sources, sample_rate except Exception as e: print(f"Error in stem separation: {str(e)}") raise def save_stem(self, stem: torch.Tensor, stem_name: str, output_path: str, sample_rate: int): try: torchaudio.save( f"{output_path}/{stem_name}.wav", stem.cpu(), sample_rate ) except Exception as e: print(f"Error saving stem: {str(e)}") raise