import os
from typing import Union

import torch
import torchaudio
from modules.Denoiser.AudioDenoiser import AudioDenoiser

from modules.utils.constants import MODELS_DIR

from modules.devices import devices

import soundfile as sf

# Module-level cache for the shared AudioDenoiser instance. It starts as None
# and is filled in by TTSAudioDenoiser.load_ad() the first time it is called.
ad: Union[AudioDenoiser, None] = None


class TTSAudioDenoiser:
    """Thin wrapper around a lazily created, module-wide ``AudioDenoiser``.

    The denoiser is stored in the module-level ``ad`` global, so every
    instance of this class shares the same underlying model object.
    """

    def load_ad(self):
        """Return the shared ``AudioDenoiser``, creating it on first use.

        The model is loaded from
        ``MODELS_DIR/Denoise/audio-denoiser-512-32-v1`` and moved to
        ``devices.device``.

        Returns:
            The cached ``AudioDenoiser`` instance.
        """
        global ad
        if ad is not None:
            return ad

        model_dir = os.path.join(
            MODELS_DIR,
            "Denoise",
            "audio-denoiser-512-32-v1",
        )
        ad = AudioDenoiser(model_dir, device=devices.device)
        ad.model.to(devices.device)
        return ad

    def denoise(self, audio_data, sample_rate, auto_scale=False):
        """Denoise a waveform with the shared model.

        Args:
            audio_data: Waveform passed unchanged to
                ``AudioDenoiser.process_waveform`` (presumably a torch
                tensor; confirm against the denoiser implementation).
            sample_rate: Sample rate of ``audio_data``.
            auto_scale: Forwarded to ``process_waveform`` as its
                ``auto_scale`` keyword argument.

        Returns:
            A ``(sample_rate, waveform)`` tuple, where ``sample_rate`` is the
            denoiser's ``model_sample_rate`` and ``waveform`` is whatever
            ``process_waveform`` returned.
        """
        denoiser = self.load_ad()
        output_sr = denoiser.model_sample_rate
        # Forward auto_scale by keyword rather than by position, so it cannot
        # be bound to a different positional flag of process_waveform.
        # NOTE(review): the upstream audio-denoiser API has
        # ``return_cpu_tensor`` as the third positional parameter -- confirm
        # the local AudioDenoiser matches.
        denoised = denoiser.process_waveform(
            audio_data,
            sample_rate,
            auto_scale=auto_scale,
        )
        return output_sr, denoised


if __name__ == "__main__":
    # Manual smoke test: denoise ./test.wav and write the result to
    # ./denoised.wav.
    tts_deno = TTSAudioDenoiser()

    # always_2d=True makes soundfile return (frames, channels) even for mono
    # files, so a single transpose gives the channels-first
    # (channels, frames) layout for mono and multi-channel input alike.
    # Calling unsqueeze(0) on the raw array is only correct for mono: a
    # stereo file would become a 3-D (1, frames, channels) tensor.
    samples, input_sr = sf.read("test.wav", always_2d=True)
    audio_tensor = torch.from_numpy(samples).T.contiguous().float()
    print(audio_tensor)

    # Alternative loader: torchaudio.load("test.wav") returns a channels-first
    # tensor directly; move it to devices.device before denoising if desired.

    # Keep input and output sample rates in separate names: the denoiser
    # reports its own model sample rate, which is the one to save with.
    output_sr, denoised = tts_deno.denoise(
        audio_data=audio_tensor,
        sample_rate=input_sr,
    )
    denoised = denoised.cpu()
    torchaudio.save("denoised.wav", denoised, sample_rate=output_sr)