import gradio as gr
import time
import logging
import torch
from sys import platform
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available
from languages import get_language_names
from subtitle_manager import Subtitle


logging.basicConfig(level=logging.INFO)
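# Cache the most recently loaded model name and pipeline so repeat requests reuse them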
last_model = None
pipe = None

def write_file(output_file, subtitle):
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)

def create_pipe(model, flash):
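    # Pick the best available device: CUDA GPU, Apple Silicon (mps), otherwise CPU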
    if torch.cuda.is_available():
        device = "cuda:0"
    elif platform == "darwin":
        device = "mps"
    else:
        device = "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = model

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
        # Attention implementations:
        #   eager             - manual attention implementation
        #   flash_attention_2 - Flash Attention 2 (requires the flash-attn package)
        #   sdpa              - torch.nn.functional.scaled_dot_product_attention (requires torch>=2.1.1)
    )
    model.to(device)

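    # The processor bundles the tokenizer and feature extractor for the checkpoint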
    processor = AutoProcessor.from_pretrained(model_id)

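    # Assemble an automatic-speech-recognition pipeline on the selected device and dtype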
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        # max_new_tokens=128,
        # chunk_length_s=15,
        # batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe

def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                    chunk_length_s, batch_size, progress=gr.Progress()):
    global last_model
    global pipe

    progress(0, desc="Loading Audio..")
    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")

    if last_model is None:
        logging.info("first model")
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    elif modelName != last_model:
        logging.info("new model")
        pipe = None  # drop the previous pipeline so its memory can actually be released
        torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    else:
        logging.info("Model not changed")
    last_model = modelName

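    # One subtitle writer per output format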
    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

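    # Collect every provided audio source: uploaded files, URL, and microphone recording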
    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)
    logging.info(files)

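    # English-only checkpoints (*.en) accept neither a language nor a task argument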
    generate_kwargs = {}
    if languageName != "Automatic Detection" and not modelName.endswith(".en"):
        generate_kwargs["language"] = languageName
    if not modelName.endswith(".en"):
        generate_kwargs["task"] = task

    files_out = []
    vtt = txt = ""  # keep the return values defined even if no input files were provided
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
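        # return_timestamps=True yields the timed "chunks" needed for subtitle generation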
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,  # UI default: 30
            batch_size=batch_size,  # UI default: 24
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        logging.info(f"transcribe: {time.time() - start_time:.2f} sec.")

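        # Derive output names from the input's base name and write one file per subtitle format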
        file_out = file.split('/')[-1]
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_out + ".srt", srt)
        write_file(file_out + ".vtt", vtt)
        write_file(file_out + ".txt", txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")
    
    return files_out, vtt, txt


with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
    article = "Read the [documentation here](https://github.com/Vaibhavs10/insanely-fast-whisper#cli-options)."
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]
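    # Visual styling for the audio input's waveform display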
    waveform_options=gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )

    simple_transcribe = gr.Interface(fn=transcribe_webui_simple_progress,
        description=description,
        article=article,
        inputs=[
            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive=True),
            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive=True),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive=True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input", waveform_options=waveform_options),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive=True),
            gr.Checkbox(label="Flash", info="Use Flash Attention 2"),
            gr.Number(label="chunk_length_s", value=30, interactive=True),
            gr.Number(label="batch_size", value=24, interactive=True)
        ], outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"), 
            gr.Text(label="Segments")
        ]
    )

if __name__ == "__main__":
    demo.launch()