# generate_audio.py

import spaces
import pickle
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
# from transformers import BarkModel, AutoProcessor  # only needed if the commented-out Bark path below is re-enabled
from parler_tts import ParlerTTSForConditionalGeneration
from scipy.io import wavfile
from pydub import AudioSegment
import io
import ast


class TTSGenerator:
    """
    A class to generate podcast-style audio from a transcript using ParlerTTS.
    A Bark-based path for Speaker 2 is kept in the code, commented out.
    """

    def __init__(self, transcript_file_path, output_audio_path):
        """
        Initialize the TTS generator with the path to the rewritten transcript file.
        
        Args:
            transcript_file_path (str): Path to the file containing the rewritten transcript.
            output_audio_path (str): Path where the generated MP3 file will be written.
        """
        self.transcript_file_path = transcript_file_path
        self.output_audio_path = output_audio_path
        
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load Parler model and tokenizer (used for both speakers)
        self.parler_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(self.device)
        self.parler_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
        self.speaker1_description = """
        Laura's voice is expressive and dramatic in delivery, speaking at a moderately fast pace with a very close recording that almost has no background noise and very clear audio.
        """
        self.speaker2_description = """
        Gary's voice is expressive and dramatic in delivery, speaking at a moderately fast pace with a very close recording that almost has no background noise and very clear audio.
        """
        
        # Optional: Bark model and processor for Speaker 2 (disabled; Speaker 2 currently also uses Parler)
        # self.bark_processor = AutoProcessor.from_pretrained("suno/bark-small")
        # self.bark_model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16).to(self.device)
        # self.bark_sampling_rate = 24000
        # self.voice_preset = "v2/en_speaker_6"

    def load_transcript(self):
        """
        Loads the rewritten transcript from the specified file.
        
        Returns:
            list: The content of the transcript as a list of tuples (speaker, text).
        """
        with open(self.transcript_file_path, 'rb') as f:
            # The pickle stores the transcript as a string repr of a list,
            # so it is unpickled first and then parsed with ast.literal_eval.
            return ast.literal_eval(pickle.load(f))
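
    # A minimal sketch of the payload load_transcript expects -- assuming the
    # upstream step pickled the transcript as a *string* repr of a list, which
    # is why ast.literal_eval is applied after unpickling. The file name and
    # dialogue below are hypothetical:
    #
    #   import pickle
    #   transcript = "[('Speaker 1', 'Welcome to the show!'), ('Speaker 2', 'Thanks for having me.')]"
    #   with open("podcast_ready_data.pkl", "wb") as f:
    #       pickle.dump(transcript, f)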
    
    def generate_speaker1_audio(self, text):
        """
        Generate audio for Speaker 1 using ParlerTTS.
        
        Args:
            text (str): Text to be synthesized for Speaker 1.
        
        Returns:
            np.array: Audio array.
            int: Sampling rate.
        """
        with torch.no_grad():
            # Tokenize the voice description and the prompt once each, reusing
            # both input_ids and attention_mask from the same tokenizer call.
            desc_inputs = self.parler_tokenizer(self.speaker1_description, return_tensors="pt").to(self.device)
            prompt_inputs = self.parler_tokenizer(text, return_tensors="pt").to(self.device)

            # Pass the attention masks explicitly to generate() for reliable behavior
            generation = self.parler_model.generate(
                input_ids=desc_inputs.input_ids,
                attention_mask=desc_inputs.attention_mask,
                prompt_input_ids=prompt_inputs.input_ids,
                prompt_attention_mask=prompt_inputs.attention_mask,
            )
            audio_arr = generation.cpu().numpy().squeeze()
        return audio_arr, self.parler_model.config.sampling_rate

    def generate_speaker2_audio(self, text):
        """
        Generate audio for Speaker 2 using ParlerTTS.
        (A Bark-based alternative is kept below, commented out.)

        Args:
            text (str): Text to be synthesized for Speaker 2.

        Returns:
            np.array: Audio array.
            int: Sampling rate.
        """
        with torch.no_grad():
            # Tokenize the voice description and the prompt once each, reusing
            # both input_ids and attention_mask from the same tokenizer call.
            desc_inputs = self.parler_tokenizer(self.speaker2_description, return_tensors="pt").to(self.device)
            prompt_inputs = self.parler_tokenizer(text, return_tensors="pt").to(self.device)

            # Pass the attention masks explicitly to generate() for reliable behavior
            generation = self.parler_model.generate(
                input_ids=desc_inputs.input_ids,
                attention_mask=desc_inputs.attention_mask,
                prompt_input_ids=prompt_inputs.input_ids,
                prompt_attention_mask=prompt_inputs.attention_mask,
            )
            audio_arr = generation.cpu().numpy().squeeze()

        # Bark alternative (requires the Bark setup commented out in __init__):
        # inputs = self.bark_processor(text, voice_preset=self.voice_preset).to(self.device)
        # speech_output = self.bark_model.generate(**inputs, temperature=0.9, semantic_temperature=0.8)
        # audio_arr = speech_output[0].cpu().numpy()
        # return audio_arr, self.bark_sampling_rate
        return audio_arr, self.parler_model.config.sampling_rate

    
    @staticmethod
    def numpy_to_audio_segment(audio_arr, sampling_rate):
        """
        Convert numpy array to AudioSegment.
        
        Args:
            audio_arr (np.array): Numpy array of audio data.
            sampling_rate (int): Sampling rate of the audio.
        
        Returns:
            AudioSegment: Converted audio segment.
        """
        # Scale float audio in [-1, 1] to the int16 range; clip first to avoid overflow wrap-around
        audio_int16 = (np.clip(audio_arr, -1.0, 1.0) * 32767).astype(np.int16)
        byte_io = io.BytesIO()
        wavfile.write(byte_io, sampling_rate, audio_int16)
        byte_io.seek(0)
        return AudioSegment.from_wav(byte_io)
    
    @spaces.GPU(duration=300)
    def generate_audio(self):
        """
        Converts the transcript into audio and saves it to a file.
        
        Returns:
            str: Path to the saved audio file.
        """
        transcript = self.load_transcript()
        final_audio = None

        for speaker, text in tqdm(transcript, desc="Generating podcast segments", unit="segment"):
            if speaker == "Speaker 1":
                audio_arr, rate = self.generate_speaker1_audio(text)
            else:  # Speaker 2
                audio_arr, rate = self.generate_speaker2_audio(text)
            
            # Convert to AudioSegment
            audio_segment = self.numpy_to_audio_segment(audio_arr, rate)
            
            # Add segment to final audio
            if final_audio is None:
                final_audio = audio_segment
            else:
                final_audio += audio_segment
            if self.device == "cuda":
                torch.cuda.empty_cache()

        if final_audio is None:
            raise ValueError("Transcript is empty; no audio segments were generated.")

        # Export final audio to MP3
        final_audio.export(self.output_audio_path, format="mp3", bitrate="192k", parameters=["-q:a", "0"])
        return self.output_audio_path
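

# A minimal usage sketch, assuming a transcript pickle produced by an upstream
# step; both file names here are hypothetical placeholders.
if __name__ == "__main__":
    generator = TTSGenerator(
        transcript_file_path="podcast_ready_data.pkl",  # hypothetical input path
        output_audio_path="podcast.mp3",                # hypothetical output path
    )
    print(f"Podcast audio saved to {generator.generate_audio()}")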