# -*- coding: utf-8 -*-

import os
import re
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    T5ForConditionalGeneration,
    MBartForConditionalGeneration,
    AutoModelForSeq2SeqLM,
)
from tqdm.auto import tqdm
import streamlit as st
from typing import Dict, List


@st.cache_resource
def load_model(model_name, device):
    """Build a seq2seq model and overwrite its weights with a local fine-tuned checkpoint.

    The base architecture is obtained via ``AutoModelForSeq2SeqLM.from_pretrained``
    (downloaded into the ``./cache`` directory, which is created if missing) and
    moved to ``device``. Its parameters are then replaced by the state dict stored
    at ``models/<short-name>-best_loss.bin``, where ``<short-name>`` is the part of
    ``model_name`` after its last ``/``. Results are memoised by
    ``st.cache_resource``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``
            (e.g. ``"org/model"``).
        device: Target device; used both for ``.to`` and as ``map_location``.

    Returns:
        The model with the checkpoint weights loaded.
    """
    print(f"Using model {model_name}")
    hub_cache_dir = "cache"
    os.makedirs(hub_cache_dir, exist_ok=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=hub_cache_dir)
    seq2seq.to(device)

    # Checkpoints are stored under the model's short name (no organisation prefix).
    short_name = model_name.split("/")[-1]
    checkpoint_path = os.path.join("models", f"{short_name}-best_loss.bin")
    print(f"Loading model from {checkpoint_path}")
    # NOTE: torch.load unpickles the file, so only point this at trusted checkpoints.
    weights = torch.load(checkpoint_path, map_location=torch.device(device))
    seq2seq.load_state_dict(weights)

    return seq2seq


@st.cache_resource
def load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` (memoised by ``st.cache_resource``).

    When the name contains ``"mbart"`` (case-insensitive), the tokenizer is
    created with both source and target language set to ``"vi_VN"``.
    Otherwise it is loaded with default options.
    """
    print(f"Loading tokenizer {model_name}")
    language_options = {}
    if "mbart" in model_name.lower():
        # mBART tokenizers take explicit language codes; both sides are Vietnamese here.
        language_options = {"src_lang": "vi_VN", "tgt_lang": "vi_VN"}
    return AutoTokenizer.from_pretrained(model_name, **language_options)


def prepare_batch_model_inputs(batch, tokenizer, max_len, is_train=False, device="cpu"):
    """Tokenize a batch and move every resulting tensor to ``device``.

    Args:
        batch: Mapping with a ``"src"`` entry and, when ``is_train`` is true,
            a ``"tgt"`` entry. ``"tgt"`` is not read otherwise.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer).
        max_len: Truncation length passed as ``max_length``.
        is_train: If true, targets are tokenized via ``text_target``.
        device: Destination for each tensor (``.to(device)``).

    Returns:
        The tokenizer's output mapping, with each value moved to ``device``.
        Padding is to the longest sequence and tensors are PyTorch (``"pt"``).
    """
    sources = batch["src"]
    targets = batch["tgt"] if is_train else None
    encoded = tokenizer(
        sources,
        text_target=targets,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )

    for name, tensor in list(encoded.items()):
        encoded[name] = tensor.to(device)
    return encoded


def prepare_single_model_inputs(src, tokenizer, max_len, device="cpu"):
    """Tokenize source text for inference and move every tensor to ``device``.

    Args:
        src: Text (or texts) handed straight to ``tokenizer``.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer).
        max_len: Truncation length passed as ``max_length``.
        device: Destination for each tensor (``.to(device)``).

    Returns:
        The tokenizer's output mapping with each value moved to ``device``.
    """
    tokenizer_options = dict(
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    encoded = tokenizer(src, **tokenizer_options)

    for name, tensor in list(encoded.items()):
        encoded[name] = tensor.to(device)
    return encoded


def make_input_sentence_from_strings(data):
    """Serialise one KPI report record into the flat key/value text fed to the model.

    The result looks like the body of a JSON object, with ten ``"key": value``
    pairs joined by ``", "``. For example::

        "CHỈ TIÊU": "...", "ĐƠN VỊ": "%", "ĐIỀU KIỆN": ">=",
        "KPI mục tiêu tháng": 95.0, "Tháng 9.2022": 97.5, "Đánh giá": "Đạt",
        "T8.2022": 96.6, "So sánh T8.2022 Tăng giảm": 1.0,
        "T9.2021": 96.8, "So sánh T9.2021 Tăng giảm": 0.8

    Values are inserted with their plain ``format()`` text. String fields are
    wrapped in double quotes but are NOT escaped.

    NOTE(review): unlike a JSON object, there are no surrounding braces.
    Presumably this matches the training-data format; confirm before changing.

    Keys read from ``data``:
        "CHỈ TIÊU": objective name (quoted in the output).
        "ĐƠN VỊ": unit (quoted).
        "ĐIỀU KIỆN": comparison condition, e.g. ">=" (quoted).
        "KPI mục tiêu tháng": monthly KPI target.
        "Đánh giá": evaluation result (quoted).
        "Thời gian báo cáo": reporting period. Periods are indexed so that
            ``p[1].p[0]`` forms the label, e.g. ``(2022, 9)`` -> ``"9.2022"``.
        "T<label> thực tế": actual value for the reporting period.
        "Previous month value key" / "Previous year value key": name of the
            key that holds the value for that earlier period.
        "Previous month compare key" / "Previous year compare key": name of
            the key that holds the change versus that earlier period.
        "Previous month" / "Previous year": earlier periods, indexed like
            "Thời gian báo cáo". Output labels are rebuilt from these, not
            taken from the "... value key" strings.

    Raises:
        KeyError: if any required key (direct or indirect) is missing.
    """

    def label(period):
        # Render a period as "<period[1]>.<period[0]>", e.g. (2022, 9) -> "9.2022".
        return f"{period[1]}.{period[0]}"

    prev_m_value_key = data["Previous month value key"]
    prev_y_value_key = data["Previous year value key"]
    objective = data["CHỈ TIÊU"]
    unit = data["ĐƠN VỊ"]
    condition = data["ĐIỀU KIỆN"]
    target = data["KPI mục tiêu tháng"]
    now = label(data["Thời gian báo cáo"])
    actual = data[f"T{now} thực tế"]
    verdict = data["Đánh giá"]
    prev_m_value = data[prev_m_value_key]
    prev_y_value = data[prev_y_value_key]
    prev_m_delta = data[data["Previous month compare key"]]
    prev_y_delta = data[data["Previous year compare key"]]
    prev_m = label(data["Previous month"])
    prev_y = label(data["Previous year"])

    fields = [
        f'"CHỈ TIÊU": "{objective}"',
        f'"ĐƠN VỊ": "{unit}"',
        f'"ĐIỀU KIỆN": "{condition}"',
        f'"KPI mục tiêu tháng": {target}',
        f'"Tháng {now}": {actual}',
        f'"Đánh giá": "{verdict}"',
        f'"T{prev_m}": {prev_m_value}',
        f'"So sánh T{prev_m} Tăng giảm": {prev_m_delta}',
        f'"T{prev_y}": {prev_y_value}',
        f'"So sánh T{prev_y} Tăng giảm": {prev_y_delta}',
    ]
    return ", ".join(fields)


def _vi_lang_token_id(tokenizer):
    """Return the token id of the ``vi_VN`` language code for an mBART-style tokenizer.

    Raises:
        KeyError: if the tokenizer does not know the ``vi_VN`` code.
    """
    lang_code_to_id = getattr(tokenizer, "lang_code_to_id", None)
    if lang_code_to_id:
        return lang_code_to_id["vi_VN"]
    # ``lang_code_to_id`` has been deprecated/removed on some tokenizer classes in
    # recent transformers releases. Language codes are still registered tokens,
    # so look the id up directly. An unknown token maps to ``unk_token_id``; treat
    # that as missing to keep the original KeyError contract.
    token_id = tokenizer.convert_tokens_to_ids("vi_VN")
    if token_id is None or token_id == getattr(tokenizer, "unk_token_id", None):
        raise KeyError("vi_VN")
    return token_id


@torch.no_grad()
def generate_description(
    input_string, model, tokenizer, device, max_len, model_name, beam_size
):
    """Generate description text for ``input_string`` with beam search.

    Runs without gradient tracking (``torch.no_grad``). The model is put in eval
    mode and moved to ``device``.

    Args:
        input_string: Source text, passed to ``prepare_single_model_inputs``.
        model: Seq2seq model exposing ``generate``.
        tokenizer: Tokenizer used for both encoding and ``batch_decode``.
        device: Device for the model and input tensors.
        max_len: Used both as the input truncation length and as the
            ``max_length`` for generation.
        model_name: Model identifier. If it contains ``"mbart"``
            (case-insensitive), decoding is forced to start with the ``vi_VN``
            language token.
        beam_size: Number of beams (``num_beams``).

    Returns:
        The list of decoded strings from ``tokenizer.batch_decode``, with
        special tokens skipped and tokenization spaces cleaned up.
    """
    model.eval()
    model = model.to(device)
    inputs = prepare_single_model_inputs(
        input_string, tokenizer, max_len=max_len, device=device
    )

    # Keep generation options separate from the encoded model inputs.
    generate_kwargs = {"max_length": max_len, "num_beams": beam_size}
    if "mbart" in model_name.lower():
        generate_kwargs["forced_bos_token_id"] = _vi_lang_token_id(tokenizer)

    outputs = model.generate(**inputs, **generate_kwargs)
    return tokenizer.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )