import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel
from utils import (
    convert_frames_to_gif,
    download_youtube_video,
    get_num_total_frames,
    sample_frames_from_video_file,
)
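# Zero-shot video classification demo: an X-CLIP model scores a video clip against user-provided text labels.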

# Default stride (in frames) between sampled frames; reduced automatically for short videos.
FRAME_SAMPLING_RATE = 4
# Checkpoint loaded at startup; it can be swapped at runtime via the model dropdown.
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    "microsoft/xclip-large-patch14ft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]

processor = AutoProcessor.from_pretrained(DEFAULT_MODEL)
model = AutoModel.from_pretrained(DEFAULT_MODEL)
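# NOTE: both globals above are replaced in-place by select_model() when the dropdown selection changes.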

# (YouTube URL, comma-separated candidate labels) pairs for the gr.Examples block below (currently commented out).
examples = [
    [
        "https://www.youtu.be/l1dBM8ZECao",
        "sleeping dog,cat fight club,birds of prey",
    ],
    [
        "https://youtu.be/VMj-3S1tku0",
        "programming course,eating spaghetti,playing football",
    ],
    [
        "https://youtu.be/BRw7rvLdGzU",
        "game of thrones,the lord of the rings,vikings",
    ],
]


def select_model(model_name):
    """Reload the global processor and model for the selected checkpoint."""
    global processor, model
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)


def predict(youtube_url_or_file_path, labels_text):
    """Classify a video against comma-separated labels; return per-label probabilities and a GIF preview."""
    if youtube_url_or_file_path.startswith("http"):
        video_path = download_youtube_video(youtube_url_or_file_path)
    else:
        video_path = youtube_url_or_file_path

    # adjust the sampling rate based on video length and the model's expected number of input frames
    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = model.config.vision_config.num_frames
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        # the clip is too short for the default stride; shrink it, keeping it at least 1
        frame_sampling_rate = max(1, num_total_frames // num_model_input_frames)
    else:
        frame_sampling_rate = FRAME_SAMPLING_RATE

    # parse comma-separated candidate labels, tolerating surrounding whitespace
    labels = [label.strip() for label in labels_text.split(",")]

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}

    return label_to_prob, gif_path


# Gradio UI: model selector, two input tabs (YouTube URL / local file), a GIF preview, and the predicted label probabilities.
app = gr.Blocks()
with app:
    gr.Markdown(
        "# **<p align='center'>PROGTOG VIOLENCE DETECTION</p>**"
    )

    with gr.Row():
        with gr.Column():
            model_names_dropdown = gr.Dropdown(
                choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
                label="Model:",
                show_label=True,
                value=DEFAULT_MODEL,
            )
            model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)
            with gr.Tab(label="Youtube URL"):
                gr.Markdown(
                    "### **Youtube URL**"
                )
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                youtube_url_predict_btn = gr.Button(value="Predict")
            with gr.Tab(label="Local File"):
                gr.Markdown(
                    "### **Tags**"
                )
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                local_video_predict_btn = gr.Button(value="Predict")
        with gr.Column():
            video_gif = gr.Image(
                label="Input Clip",
                show_label=True,
            )
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # gr.Markdown("**Examples:**")
    # gr.Examples(
    #     examples,
    #     [youtube_url, youtube_url_labels_text],
    #     [predictions, video_gif],
    #     fn=predict,
    #     cache_examples=True,
    # )

    youtube_url_predict_btn.click(
        predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
    )
    local_video_predict_btn.click(
        predict,
        inputs=[video_file, local_video_labels_text],
        outputs=[predictions, video_gif],
    )
    # gr.Markdown(
    #     """
    #     \n Demo created by: <a href=\"https://github.com/fcakyon\">fcakyon</a>.
    #     <br> Based on this <a href=\"https://huggingface.co/docs/transformers/main/model_doc/xclip\">HuggingFace model</a>.
    #     """
    # )

app.launch()