from typing import Optional
import os
os.environ["WANDB_DISABLED"] = "true"  # disable Weights & Biases logging

import gradio as gr
import torch
import torch.nn as nn

from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from torchvision.io import ImageReadMode, read_image

from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoTokenizer,
    CLIPModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model

class Transform(torch.nn.Module):
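    """Scriptable image preprocessing: resize (bicubic), center-crop, convert to float, and normalize with the given mean/std."""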
    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        """`x` should be an instance of `PIL.Image.Image`"""
        with torch.no_grad():
            x = self.transforms(x)
        return x



class VisionTextDualEncoderModel(nn.Module):
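    """Dual encoder: XLM-RoBERTa text features and CLIP ViT image features are concatenated and fed to a linear sentiment classifier."""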
    def __init__(self, num_classes):
        super().__init__()

        # Text encoder: multilingual sentiment-tuned XLM-RoBERTa
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )

        # Vision encoder: CLIP ViT-B/32 (only its vision tower is used in forward())
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size

        # Classification head over the concatenated text and vision features
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # Encode text inputs
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        ).pooler_output

        # Encode vision inputs
        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Concatenate text and vision features
        combined_features = torch.cat(
            (text_outputs, vision_outputs.pooler_output), dim=1
        )

        # Forward through a linear layer for classification
        logits = self.fc(combined_features)

        return {"logits": logits}

id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}

tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual")

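# Instantiate the dual encoder; its CLIP vision config is reused below for image preprocessing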
model = VisionTextDualEncoderModel(num_classes=3)
config = model.vision_encoder.config

# Fine-tuned M2SA checkpoint: https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")

load_model(model, sf_filename)
model.eval()

image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Build and script the image preprocessing pipeline once, rather than on every request
image_transformations = Transform(
    config.vision_config.image_size,
    image_processor.image_mean,
    image_processor.image_std,
)
image_transformations = torch.jit.script(image_transformations)

def predict_sentiment(text, image):
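    """Tokenize the text, preprocess the image, run the dual encoder, and return the predicted sentiment label."""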
    # Load the image as a (C, H, W) uint8 tensor
    image = read_image(image, mode=ImageReadMode.RGB)

    # Tokenize the text input
    text_inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Preprocess the image and add a batch dimension
    text_inputs["pixel_values"] = image_transformations(image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(**text_inputs)
        prediction = torch.argmax(outputs["logits"], dim=-1)

    return id2label[prediction[0].item()]


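# Gradio UI: a text box plus an image upload, returning the predicted sentiment label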
interface = gr.Interface(
    fn=predict_sentiment,
    inputs=[gr.Textbox(), gr.Image(type="filepath")],
    outputs="text",
    title="Multilingual Multimodal Sentiment Analysis",
    examples=[["I am enjoying", "A_Sep20_14_1189155141.jpg"]],
    description="Get the positive/neutral/negative sentiment for the given input.",
)

interface.launch(inline=False)