from pyannote.audio import Pipeline
from pydub import AudioSegment
import os
from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
import torchaudio
import torch
import re
import spaces


device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32


MODEL_NAME = "openai/whisper-large-v3"
CKPT = "projecte-aina/whisper-large-v3-tiny-caesar"
BATCH_SIZE = 1
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype).to(device)
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
pipeline_vad = Pipeline.from_pretrained("./pyannote/config.yaml")  # local pyannote VAD pipeline config
threshold = 10000  # maximum accumulated chunk duration before transcribing, in milliseconds
segments_dir = "."  # where temporary WAV chunks are written

# general-purpose ASR pipeline, used when the VAD-chunked path is disabled
pipe = pipeline(
    task="automatic-speech-recognition",
    model=CKPT,
    chunk_length_s=30,
    device=device,
    token=os.getenv("HF_TOKEN"),
)

def post_process_transcription(transcription, max_repeats=2):
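    """Collapse character-group repetitions inside tokens and keep at most
    `max_repeats` consecutive copies of the same token, then normalise whitespace."""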
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)

    cleaned_tokens = []
    repetition_count = 0
    previous_token = None

    for token in tokens:
        # collapse runs of a repeated short group, e.g. "lalala" -> "la"
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)

        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)

        previous_token = reduced_token

    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()

    return cleaned_transcription


def convert_forced_to_tokens(forced_decoder_ids):
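    """Return `forced_decoder_ids` with each token id decoded to its string form;
    used only to count how many positions to strip from the generated output."""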
    forced_decoder_tokens = []
    for i, (idx, token) in enumerate(forced_decoder_ids):
        if token is not None:
            forced_decoder_tokens.append([idx, processor.tokenizer.decode(token)])
        else:
            forced_decoder_tokens.append([idx, token])
    return forced_decoder_tokens

def generate_1st_chunk(audio):
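    """Transcribe the first VAD chunk, conditioning Whisper on an example 112-call
    dialogue passed as a <|startofprev|> prompt before the <|ca|>/<|es|>/<|transcribe|> prefix."""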

    input_audio, sample_rate = torchaudio.load(audio)
    input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
    
    input_speech = input_audio[0]

    input_features = processor(input_speech,
                               sampling_rate=16_000,
                               return_tensors="pt").input_features.to(device, dtype=torch_dtype)

    forced_decoder_ids = []
    forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
    forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
    forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']

    # keep the lang/task prefix; forced_decoder_ids is rebuilt below from the prompt
    forced_decoder_ids_modified = forced_decoder_ids
    idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
    forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
    prompt = "Antes de 'digui'm', '112'. 112, digui'm. Hola, puc parlar en castellà? Sí, digui, diga. Sí, mire: a veces al abrir la puerta de mi piso tengo una persona ahí. Vale, avisamos a la Guàrdia Urbana, ¿de acuerdo? Vale, perfecto. Gracias. Gracias. Buen día."
    prompt_tokens = processor.tokenizer(prompt, add_special_tokens=False).input_ids

    # we need to force these tokens
    forced_decoder_ids = []
    for idx, token in enumerate(prompt_tokens):
        # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
        forced_decoder_ids.append([idx + 1, token])
            
    # now we add the SOS token at the end
    offset = len(forced_decoder_ids)
    forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])

    # now we need to append the rest of the prefix tokens (lang, task, timestamps)
    offset = len(forced_decoder_ids)
    for idx, token in forced_decoder_ids_modified:
        forced_decoder_ids.append([idx + offset, token])

    model.generation_config.forced_decoder_ids = forced_decoder_ids

    pred_ids = model.generate(input_features, 
                                    return_timestamps=True,
                                    max_new_tokens=128,
                                    decoder_start_token_id=forced_bos_token_id)
    #exclude prompt from output
    forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
    output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)

    return output[1:]  # drop the leading space of the first decoded token

def generate_2nd_chunk(audio):
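    """Transcribe a follow-up VAD chunk; identical to generate_1st_chunk except that
    the conditioning prompt omits the leading context sentence."""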

    input_audio, sample_rate = torchaudio.load(audio)
    input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
    
    input_speech = input_audio[0]

    input_features = processor(input_speech,
                               sampling_rate=16_000,
                               return_tensors="pt").input_features.to(device, dtype=torch_dtype)
    forced_decoder_ids = []

    forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
    forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
    forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']

    # keep the lang/task prefix; forced_decoder_ids is rebuilt below from the prompt
    forced_decoder_ids_modified = forced_decoder_ids
    idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
    forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
        
    prompt = "112, digui'm. Hola, puc parlar en castellà? Sí, digui, diga. Sí, mire: a veces al abrir la puerta de mi piso tengo una persona ahí. Vale, avisamos a la Guàrdia Urbana, ¿de acuerdo? Vale, perfecto. Gracias. Gracias. Buen día."  
    prompt_tokens = processor.tokenizer(prompt, add_special_tokens=False).input_ids

    # we need to force these tokens
    forced_decoder_ids = []
    for idx, token in enumerate(prompt_tokens):
        # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
        forced_decoder_ids.append([idx + 1, token])
            
    # now we add the SOS token at the end
    offset = len(forced_decoder_ids)
    forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])

    # now we need to append the rest of the prefix tokens (lang, task, timestamps)
    offset = len(forced_decoder_ids)
    for idx, token in forced_decoder_ids_modified:
        forced_decoder_ids.append([idx + offset, token])

    model.generation_config.forced_decoder_ids = forced_decoder_ids

    pred_ids = model.generate(input_features, 
                                    return_timestamps=True,
                                    max_new_tokens=128,
                                    decoder_start_token_id=forced_bos_token_id)
    #exclude prompt from output
    forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
    output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)

    return output[1:]  # drop the leading space of the first decoded token

def processing_vad_threshold(audio, output_vad, threshold, max_duration, concatenated_segment):
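    """Walk the VAD speech regions, grouping them into chunks of at most `threshold`
    milliseconds, transcribe each chunk (the first one with its own prompt) and
    return the concatenated transcription."""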

    transcription_audio = ""
    is_first_chunk = True
    for speech in output_vad.get_timeline().support():
        start, end = speech.start, speech.end
        segment_duration = (end - start) * 1000
        segment_audio = audio[start * 1000:end * 1000]

        if max_duration + segment_duration < threshold:
            concatenated_segment += segment_audio
            max_duration += segment_duration

        else:
            if len(concatenated_segment) > 0:
                temp_segment_path = os.path.join(segments_dir, "temp_segment.wav")
                concatenated_segment.export(temp_segment_path, format="wav")

                if is_first_chunk:
                    output = generate_1st_chunk(temp_segment_path)
                    is_first_chunk = False
                else:
                    output = generate_2nd_chunk(temp_segment_path)
                transcription_audio = transcription_audio + output + " "
            # start a new chunk with the segment that did not fit
            max_duration = segment_duration
            concatenated_segment = segment_audio
        
    # Process any remaining audio in the concatenated_segment
    if len(concatenated_segment) > 0:
        temp_segment_path = os.path.join(segments_dir, "temp_segment.wav")
        concatenated_segment.export(temp_segment_path, format="wav")

        if is_first_chunk:
            output = generate_1st_chunk(temp_segment_path)
        else:
            output = generate_2nd_chunk(temp_segment_path)
        transcription_audio = transcription_audio + output

    return transcription_audio

def format_audio(audio_path):
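    """Load an audio file, downmix multichannel audio to mono, resample to 16 kHz
    and return it as a 1-D numpy array."""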
    input_audio, sample_rate = torchaudio.load(audio_path)

    if input_audio.shape[0] > 1:  # downmix multichannel to mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)

    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze().numpy()
    return input_audio


def transcribe_pipeline(audio, task):
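    """Run the fine-tuned Whisper pipeline over the whole recording in 30 s chunks."""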
    text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
    return text

def generate(audio_path, use_v5):
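    """Entry point: load the WAV, downmix to mono if needed, then transcribe either
    through the VAD-chunked prompted path (use_v5) or the plain ASR pipeline,
    and post-process the result."""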
    audio = AudioSegment.from_wav(audio_path)
    
    temp_mono_path = None
    if audio.channels != 1:  # stereo2mono
        audio = audio.set_channels(1)
        temp_mono_path = "temp_mono.wav"
        audio.export(temp_mono_path, format="wav")
        audio_path = temp_mono_path
        
    output_vad = pipeline_vad(audio_path)
    concatenated_segment = AudioSegment.empty()
    max_duration = 0

    if use_v5:
        output = processing_vad_threshold(audio, output_vad, threshold, max_duration, concatenated_segment)
    else:  
        task = "transcribe"
        output = transcribe_pipeline(format_audio(audio_path), task)
        
    clean_output = post_process_transcription(output)
    
    if temp_mono_path and os.path.exists(temp_mono_path):
        os.remove(temp_mono_path)
        
    return clean_output