import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoProcessor, AutoTokenizer, XCLIPVisionModel, AutoModelForSequenceClassification

import numpy as np
import cv2
import opensmile
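
# Inference components for a three-branch emotion classifier: a BERT-based text
# model, an X-CLIP-based video model and a 1-D CNN over openSMILE audio features,
# whose last hidden states are concatenated and classified by a small fusion MLP.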


class TextClassificationModel:
    """Inference wrapper: runs the text classifier under eval()/no_grad on the given device."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, input_ids, attn_mask, return_last_hidden_state=False):
        self.model.eval()
        with torch.no_grad():
            input_ids = input_ids.to(self.device)
            attn_mask = attn_mask.to(self.device)
            output = self.model(input_ids=input_ids, attention_mask=attn_mask,
                                output_hidden_states=return_last_hidden_state)
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
            if return_last_hidden_state:
                hidden_states = output['hidden_states']
        if return_last_hidden_state:
            return pred, hidden_states[-1][:, 0, :]  # [CLS] embedding from the last hidden layer
        else:
            return pred


class XCLIPClassificationModel(nn.Module):
    """Video classifier: X-CLIP vision encoder with patch- and frame-level pooling and a linear head."""

    def __init__(self, num_labels):
        super().__init__()
        self.base_model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
        self.num_labels = num_labels
        hidden_size = self.base_model.config.hidden_size
        self.fc_norm = nn.LayerNorm(hidden_size)  # not used in forward()
        self.classifier = nn.Linear(hidden_size, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.pool1 = nn.AdaptiveAvgPool1d(1)
        self.pool2 = nn.AdaptiveAvgPool1d(1)

    def forward(self, pixel_values, labels=None, return_last_hidden_state=False):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        # Fold frames into the batch so every frame is encoded independently by the vision tower.
        pixel_values = pixel_values.reshape(-1, num_channels, height, width)
        out = self.base_model(pixel_values)[0]  # [B*F, num_patches+1, hidden]
        out = torch.transpose(out, 1, 2)  # [B*F, hidden, num_patches+1]
        out = self.pool1(out)  # [B*F, hidden, 1] -- average over patch tokens
        out = torch.transpose(out, 1, 2)  # [B*F, 1, hidden]
        out = out.squeeze(1)  # [B*F, hidden]
        hidden_out = out.view(batch_size, num_frames, -1)  # [B, F, hidden]
        hidden_out = torch.transpose(hidden_out, 1, 2)  # [B, hidden, F]
        pooled_out = self.pool2(hidden_out)  # [B, hidden, 1] -- average over frames
        pooled_out = torch.transpose(pooled_out, 1, 2)  # [B, 1, hidden]
        pooled_out = pooled_out[:, 0, :]  # [B, hidden]
        logits = self.classifier(pooled_out)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if return_last_hidden_state:
            return {'logits': logits, 'loss': loss, 'last_hidden_state': pooled_out}
        else:
            return {'logits': logits, 'loss': loss}


class VideoClassificationModel:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, pixel_values, return_last_hidden_state=False):
        self.model.eval()
        with torch.no_grad():
            pixel_values = pixel_values.to(self.device)
            output = self.model(pixel_values, return_last_hidden_state=return_last_hidden_state)
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
            if return_last_hidden_state:
                hidden_states = output['last_hidden_state']
        if return_last_hidden_state:
            return pred, hidden_states
        else:
            return pred


class ConvNet(nn.Module):
    """1-D CNN over a 6191-dimensional openSMILE functionals vector (input shape [B, 1, 6191])."""

    def __init__(self, num_labels, n_input=1, n_channel=32):
        super().__init__()
        self.ln0 = nn.LayerNorm((1, 6191))
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=3)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(2)
        # Two valid convolutions shrink the sequence 6191 -> 6189 -> 6187; max-pooling halves it to 3093.
        self.fc1 = nn.Linear(n_channel*3093, 3093)
        self.fc2 = nn.Linear(3093, num_labels)
        self.flat = nn.Flatten()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, return_last_hidden_state=False):
        x = self.ln0(x)
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool1(x)
        x = self.dropout(x)
        x = self.flat(x)
        hid = F.relu(self.fc1(x))
        x = self.fc2(hid)
        if not return_last_hidden_state:
            return {'logits': F.log_softmax(x, dim=1)}
        else:
            return {'logits': F.log_softmax(x, dim=1), 'last_hidden_state': hid}


class AudioClassificationModel:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, features, return_last_hidden_state=False):
        self.model.eval()
        with torch.no_grad():
            features = torch.as_tensor(features, dtype=torch.float).to(self.device)
            output = self.model(features, return_last_hidden_state=return_last_hidden_state)
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
            if return_last_hidden_state:
                hidden_state = output['last_hidden_state']
        if return_last_hidden_state:
            return pred, hidden_state
        else:
            return pred


class MultimodalClassificationModel(nn.Module):
    """Late-fusion head: concatenates the unimodal last hidden states and classifies them with a two-layer MLP."""

    def __init__(self, text_model, video_model, audio_model, num_labels, input_size, hidden_size=256):
        super().__init__()
        self.text_model = text_model
        self.video_model = video_model
        self.audio_model = audio_model
        self.num_labels = num_labels
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, self.num_labels)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout()
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, batch, labels=None):
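        # Each unimodal model returns (prediction, last_hidden_state); the predictions are
        # discarded and only the hidden states are fused.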
        text_pred, text_last_hidden = self.text_model(
            batch['text']['input_ids'].squeeze(1),
            batch['text']['attention_mask'].squeeze(1),
            return_last_hidden_state=True
        )
        video_pred, video_last_hidden = self.video_model(
            batch['video']['pixel_values'].squeeze(1),
            return_last_hidden_state=True
        )
        audio_pred, audio_last_hidden = self.audio_model(
            batch['audio'],
            return_last_hidden_state=True
        )
        concat_input = torch.cat((text_last_hidden, video_last_hidden, audio_last_hidden), dim=1)
        hidden_state = self.linear1(concat_input)
        hidden_state = self.drop1(self.relu1(hidden_state))
        logits = self.linear2(hidden_state)
        loss = None
        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.num_labels), labels.view(-1))
        return {'logits': logits, 'loss': loss}


class MainModel:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, batch):
        self.model.eval()
        with torch.no_grad():
            output = self.model(batch)
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
        return pred
    
def prepare_models(num_labels: int,
                   text_model_path: str,
                   video_model_path: str,
                   audio_model_path: str,
                   device: str = 'cpu'):
    """Load the three unimodal classifiers from their checkpoints and wrap them for inference."""
    # TEXT
    text_model_name = 'bert-large-uncased'
    text_base_model = AutoModelForSequenceClassification.from_pretrained(
        text_model_name,
        num_labels=num_labels
    )
    state_dict = torch.load(text_model_path, map_location=torch.device('cpu'))
    text_base_model.load_state_dict(state_dict, strict=False)
    text_model = TextClassificationModel(text_base_model, device=device)

    # VIDEO
    video_base_model = XCLIPClassificationModel(num_labels)
    state_dict = torch.load(video_model_path, map_location=torch.device('cpu'))
    video_base_model.load_state_dict(state_dict, strict=False)
    video_model = VideoClassificationModel(video_base_model, device=device)

    # AUDIO
    audio_base_model = ConvNet(num_labels)
    checkpoint = torch.load(audio_model_path, map_location=torch.device('cpu'))
    audio_base_model.load_state_dict(checkpoint['model_state_dict'])
    audio_model = AudioClassificationModel(audio_base_model, device=device)

    return text_model, video_model, audio_model

def sample_frame_indices(seg_len, clip_len=16, frame_sample_rate=4, mode="video"):
    # seg_len           -- total number of frames available in the clip
    # clip_len          -- number of frame indices to return (in "video" mode)
    # frame_sample_rate -- spacing factor between sampled frames
    converted_len = int(clip_len * frame_sample_rate)
    converted_len = min(converted_len, seg_len - 1)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
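    # Worked example with the defaults: clip_len=16 and frame_sample_rate=4 give a 64-frame
    # window at a random offset, from which np.linspace picks 16 evenly spaced indices.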
    if mode == "video":
        indices = np.linspace(start_idx, end_idx, num=clip_len)
    else:
        indices = np.linspace(start_idx, end_idx, num=clip_len*frame_sample_rate)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices

def get_frames(file_path, clip_len=16):
    cap = cv2.VideoCapture(file_path)
    v_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    indices = sample_frame_indices(v_len, clip_len=clip_len)

    frames = []
    for fn in range(v_len):
        success, frame = cap.read()
        if not success:
            continue
        if fn in indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Crop the frame borders, then resize to the 224x224 input expected by X-CLIP.
            res = cv2.resize(frame[90:-80, 60:-100], dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
            frames.append(res)
    cap.release()

    # If decoding yielded fewer frames than requested, pad by repeating the last frame.
    if len(frames) < clip_len:
        add_num = clip_len - len(frames)
        frames_to_add = [frames[-1]] * add_num
        frames.extend(frames_to_add)

    return frames

def prepare_data_input(text: str,
                       video_path: str):
    """Build the model input dict: tokenized text, preprocessed video frames and openSMILE audio features."""
    # VIDEO
    video_frames = get_frames(video_path)
    video_model_name = "microsoft/xclip-base-patch32"
    video_feature_extractor = AutoProcessor.from_pretrained(video_model_name)
    video_encoding = video_feature_extractor(videos=video_frames, return_tensors="pt")
    # AUDIO
    smile = opensmile.Smile(
        feature_set=opensmile.FeatureSet.ComParE_2016,
        feature_level=opensmile.FeatureLevel.Functionals,
        sampling_rate=16000,
        resample=True,
        num_workers=5,
        verbose=True,
    )
    # openSMILE extracts the ComParE 2016 functionals from the file's audio track.
    audio_features = smile.process_files([video_path])
    # Drop the features marked as redundant during training, leaving the 6191 columns
    # that ConvNet expects.
    with open('files/redundant_feat.txt') as f:
        redundant_feat = f.read().split(',')
    audio_features.drop(columns=redundant_feat, inplace=True)
    # TEXT 
    text_model_name = 'bert-large-uncased'
    tokenizer = AutoTokenizer.from_pretrained(text_model_name)
    text_encoding = tokenizer(text,
                              padding='max_length',
                              truncation=True,
                              max_length=128,
                              return_tensors='pt')
    return {'text': text_encoding, 'video': video_encoding, 'audio': audio_features.values.reshape((1, 1, 6191))}

def infer_multimodal_model(text: str,
                           video_path: str,
                           model_paths: dict):
    """Run the full multimodal pipeline on one (transcript, video) pair and return the predicted emotion label."""
    label2id = {'anger': 0, 'disgust': 1, 'fear': 2, 'joy': 3, 'neutral': 4, 'sadness': 5, 'surprise': 6}
    id2label = {v: k for k, v in label2id.items()}
    num_labels = 7
    text_model, video_model, audio_model = prepare_models(num_labels,
                                                          model_paths['text_model_path'],
                                                          model_paths['video_model_path'],
                                                          model_paths['audio_model_path'])
    multi_model = MultimodalClassificationModel(
        text_model,
        video_model,
        audio_model,
        num_labels,
        input_size=4885,  # 1024 (BERT-large) + 768 (X-CLIP) + 3093 (audio ConvNet hidden)
        hidden_size=512
    )
    checkpoint = torch.load(model_paths['multimodal_model_path'], map_location=torch.device('cpu'))
    multi_model.load_state_dict(checkpoint)
    device = 'cpu'
    final_model = MainModel(multi_model, device=device)
    batch = prepare_data_input(text, video_path)
    label = final_model(batch).detach().cpu().tolist()
    return id2label[label[0]]
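

if __name__ == '__main__':
    # Minimal usage sketch of the end-to-end pipeline. The transcript, the video file
    # and every checkpoint path below are placeholders -- point them at your own
    # fine-tuned weights and data before running.
    example_paths = {
        'text_model_path': 'checkpoints/text_model.pt',
        'video_model_path': 'checkpoints/video_model.pt',
        'audio_model_path': 'checkpoints/audio_model.pt',
        'multimodal_model_path': 'checkpoints/multimodal_model.pt',
    }
    predicted_emotion = infer_multimodal_model(
        "I can't believe you did that!",  # transcript of the utterance
        'data/example_clip.mp4',          # video clip containing the utterance
        example_paths,
    )
    print(predicted_emotion)  # one of: anger, disgust, fear, joy, neutral, sadness, surprise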