File size: 5,303 Bytes
5017efb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torchaudio
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
import torch
from BigVGAN import bigvgan
from BigVGAN.meldataset import get_mel_spectrogram
from voice_restore import VoiceRestore
import argparse
from model import OptimizedAudioRestorationModel
import librosa
from inference_long import apply_overlap_windowing_waveform, reconstruct_waveform_from_windows

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Configuration class for VoiceRestore
class VoiceRestoreConfig(PretrainedConfig):
    """Inference hyper-parameters for the ``VoiceRestore`` model.

    Args:
        steps: Number of sampling steps forwarded to the restoration model /
            ``voice_restore.sample``.
        cfg_strength: Guidance strength (``cfg_strength``) forwarded to the
            sampler.
        window_size_sec: Window length, in seconds, used by long-form
            (windowed) restoration.
        overlap: Overlap between consecutive windows in long-form restoration,
            forwarded to the windowing helpers (presumably a fraction of the
            window given the 0.5 default — confirm against
            ``apply_overlap_windowing_waveform``).
        **kwargs: Any remaining options, forwarded unchanged to
            ``transformers.PretrainedConfig``.
    """

    model_type = "voice_restore"

    def __init__(
        self,
        *,
        steps=16,
        cfg_strength=0.5,
        window_size_sec=5.0,
        overlap=0.5,
        **kwargs,
    ):
        # Keyword-only, exactly like the original kwargs-only signature, but
        # with the defaults visible in the signature.
        super().__init__(**kwargs)
        self.steps = steps
        self.cfg_strength = cfg_strength
        self.window_size_sec = window_size_sec
        self.overlap = overlap

# Model class for VoiceRestore
class VoiceRestore(PreTrainedModel):
    """Hugging Face wrapper around the VoiceRestore audio-restoration pipeline.

    On construction it loads the pretrained ``nvidia/bigvgan_v2_24khz_100band_256x``
    BigVGAN vocoder, wraps it in an ``OptimizedAudioRestorationModel`` and loads
    the restoration weights from a local checkpoint. Calling the model reads an
    audio file from disk, restores it and writes the result to another file.

    NOTE(review): this class shadows the ``VoiceRestore`` imported from
    ``voice_restore`` at the top of the file; that import appears unused.
    """

    config_class = VoiceRestoreConfig

    def __init__(self, config: VoiceRestoreConfig):
        """Build the vocoder and restoration model and load their weights.

        Args:
            config: Supplies ``steps``, ``cfg_strength``, ``window_size_sec`` and
                ``overlap``. An optional ``checkpoint_path`` attribute overrides
                the restoration checkpoint location (default
                ``./pytorch_model.bin``, relative to the working directory).
        """
        super().__init__(config)
        self.steps = config.steps
        self.cfg_strength = config.cfg_strength
        self.window_size_sec = config.window_size_sec
        self.overlap = config.overlap

        # Vocoder: used inside the restoration model and, in long-form mode,
        # directly to turn restored mel spectrograms back into waveforms.
        self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(
            'nvidia/bigvgan_v2_24khz_100band_256x',
            use_cuda_kernel=False,
            force_download=False
        ).to(device)
        # Fold weight normalisation into the weights for inference.
        self.bigvgan_model.remove_weight_norm()

        # Optimized restoration model
        self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
        save_path = getattr(config, "checkpoint_path", "./pytorch_model.bin")
        # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(save_path, map_location=device)
        if 'model_state_dict' in state_dict:
            # Some checkpoints nest the weights under 'model_state_dict'.
            state_dict = state_dict['model_state_dict']

        self.optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
        self.optimized_model.eval()

    def forward(self, input_path, output_path, short=True):
        """Restore the audio in ``input_path`` and write it to ``output_path``.

        Args:
            input_path: Path of the degraded audio file to read.
            output_path: Path the restored audio is written to.
            short: If True, process the whole file in one pass
                (``restore_audio_short``); otherwise process it in overlapping
                windows (``restore_audio_long``).
        """
        if short:
            self.restore_audio_short(self.optimized_model, input_path, output_path, self.steps, self.cfg_strength)
        else:
            self.restore_audio_long(self.optimized_model, input_path, output_path, self.steps, self.cfg_strength, self.window_size_sec, self.overlap)

    def restore_audio_short(self, model, input_path, output_path, steps, cfg_strength):
        """Restore a whole file in a single pass through ``model``.

        The file is loaded with torchaudio, resampled to
        ``model.target_sample_rate`` when needed, and averaged across channels
        to mono. The output is saved at ``model.target_sample_rate``.

        Args:
            model: The restoration model (``self.optimized_model`` when called
                from ``forward``).
            input_path: Path of the audio file to restore.
            output_path: Path the restored audio is written to.
            steps: Sampling steps forwarded to ``model``.
            cfg_strength: Guidance strength forwarded to ``model``.
        """
        device_type = device.type  # torch.autocast wants "cuda"/"cpu", not a torch.device
        audio, sr = torchaudio.load(input_path)
        if sr != model.target_sample_rate:
            audio = torchaudio.functional.resample(audio, sr, model.target_sample_rate)

        audio = audio.mean(dim=0, keepdim=True) if audio.dim() > 1 else audio  # Convert to mono if stereo

        with torch.inference_mode():
            with torch.autocast(device_type):
                restored_wav = model(audio, steps=steps, cfg_strength=cfg_strength)
                restored_wav = restored_wav.squeeze(0).float().cpu()  # Move to CPU after processing

        torchaudio.save(output_path, restored_wav, model.target_sample_rate)

    def restore_audio_long(self, model, input_path, output_path, steps, cfg_strength, window_size_sec, overlap):
        """Restore a file window-by-window and stitch the windows back together.

        The audio is loaded as 24 kHz mono with librosa and split with
        ``apply_overlap_windowing_waveform``. Each window is converted to a mel
        spectrogram, restored with ``model.voice_restore.sample`` and vocoded
        with BigVGAN. The windows are then merged with
        ``reconstruct_waveform_from_windows``.

        Args:
            model: The restoration model exposing ``voice_restore.sample``.
            input_path: Path of the audio file to restore.
            output_path: Path the restored audio is written to (24 kHz).
            steps: Sampling steps forwarded to the sampler.
            cfg_strength: Guidance strength forwarded to the sampler.
            window_size_sec: Window length in seconds.
            overlap: Window overlap, forwarded to the windowing helpers.

        Raises:
            ValueError: If the windowing helper produces no windows.
        """
        wav, sr = librosa.load(input_path, sr=24000, mono=True)
        wav = torch.FloatTensor(wav).unsqueeze(0)  # Shape: [1, num_samples]

        window_size_samples = int(window_size_sec * sr)
        wav_windows = apply_overlap_windowing_waveform(wav, window_size_samples, overlap)

        restored_wav_windows = []
        for wav_window in wav_windows:
            wav_window = wav_window.to(device)
            processed_mel = get_mel_spectrogram(wav_window, self.bigvgan_model.h).to(device)

            with torch.no_grad():
                # torch.autocast requires a device-type string ("cuda"/"cpu");
                # passing the torch.device object itself raises an error.
                with torch.autocast(device.type):
                    restored_mel = model.voice_restore.sample(processed_mel.transpose(1, 2), steps=steps, cfg_strength=cfg_strength)
                    restored_mel = restored_mel.squeeze(0).transpose(0, 1)

                # BigVGAN is called outside the autocast region, so cast the mel
                # back to float32 in case autocast produced a reduced-precision
                # tensor (no-op when it is already float32).
                restored_wav = self.bigvgan_model(restored_mel.float().unsqueeze(0)).squeeze(0).float().cpu()
                restored_wav_windows.append(restored_wav)

            if device.type == "cuda":
                torch.cuda.empty_cache()

        if not restored_wav_windows:
            raise ValueError(
                f"No windows were produced from {input_path!r} with a window of "
                f"{window_size_samples} samples; the input may be shorter than one "
                f"window. Try short=True instead."
            )

        restored_wav_windows = torch.stack(restored_wav_windows)
        restored_wav = reconstruct_waveform_from_windows(restored_wav_windows, window_size_samples, overlap)

        # sr is the rate librosa resampled to above (24000).
        torchaudio.save(output_path, restored_wav.unsqueeze(0), sr)


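# # Function to load the model using AutoModel
# from transformers import AutoModel

# NOTE: AutoModel can only resolve model_type "voice_restore" after
# AutoConfig.register("voice_restore", VoiceRestoreConfig) and
# AutoModel.register(VoiceRestoreConfig, VoiceRestore) have been called.
# from_pretrained expects a Hub repo id or a directory written by
# save_pretrained, not a bare .pt checkpoint file.

# def load_voice_restore_model(checkpoint_path: str):
#     model = AutoModel.from_pretrained(checkpoint_path, config=VoiceRestoreConfig())
#     return model

# # Example Usage
# model = load_voice_restore_model("./checkpoints/voice-restore-20d-16h-optim.pt")
# model("test_input.wav", "test_output.wav")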