import os
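# Pin transformers at startup; installing at import time is a common workaround
# on Hugging Face Spaces when the preinstalled version doesn't match.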
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import base64
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# == download weights ==
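# snapshot_download fetches each checkpoint from the Hugging Face Hub into
# ./models (cached, so restarts don't re-download).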
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
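# Quick sanity check that each snapshot actually landed on disk.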
os.system("ls -l models/unimernet_tiny")
os.system("ls -l models/unimernet_small")
os.system("ls -l models/unimernet_base")
# == download weights ==

# Make the parent directory importable so the local `unimernet` package resolves.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


# HTML shell for rendering the predicted LaTeX with mathpix-markdown-it. The
# dangling `const text =` below is intentional: latex2html() splits the
# template on that exact string and splices the escaped LaTeX in between.
template_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
  <meta charset="UTF-8">
  <title>UniMERNet</title>
  <script>
    const text = 
  </script>
  <style>
    #content {
      max-width: 800px;
      margin: auto;
    }
  </style>
  <script>
    let script = document.createElement('script');
    script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
    document.head.append(script);

    script.onload = function() {
      const isLoaded = window.loadMathJax();
      if (isLoaded) {
        console.log('Styles loaded!')
      }

      const el = window.document.getElementById('content-text');
      if (el) {
        const options = {
          htmlTags: true
        };
        const html = window.render(text, options);
        el.outerHTML = html;
      }
    };
  </script>
</head>
<body>
  <div id="content"><div id="content-text"></div></div>
</body>
</html>
"""

def latex2html(latex_code):
    # Wrap the prediction in display-math delimiters for MathJax.
    latex_code = '\\[' + latex_code + '\\]'

    # Strip \left / \right sizing prefixes; the bare delimiters render fine and
    # this avoids MathJax errors on unbalanced pairs. (The original mixed raw
    # and escaped backslashes, which triggers invalid-escape warnings.)
    for delim in ['(', ')', '[', ']', '\\{', '\\}', '|']:
        latex_code = latex_code.replace('\\left' + delim, delim).replace('\\right' + delim, delim)
    # \left. and \right. are invisible null delimiters, so drop them entirely.
    latex_code = latex_code.replace('\\left.', '').replace('\\right.', '')

    # Escape for a JS double-quoted string literal, and drop stray $ signs
    # since the code is already wrapped in \[ ... \].
    latex_code = latex_code.replace('"', '``').replace('$', '')

    # Re-emit each line as a JS string literal joined with `+`, doubling
    # backslashes and appending a literal \n so line breaks survive.
    gt = ''
    for out in latex_code.split('\n'):
        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
    gt = gt[:-2]  # drop the trailing `+` and newline

    # Splice the escaped text into the template at the `const text =` stub.
    head, tail = template_html.split("const text =")
    return head + 'const text = ' + gt + tail
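
# Example: an input of `x^2` yields `const text = "\\[x^2\\]\n"` in the
# generated page (a sketch of the escaping; real output usually spans lines).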

def load_model_and_processor(cfg_path):
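    # Build the model and its eval-time image processor from a unimernet config.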
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
    return model, vis_processor

# On ZeroGPU Spaces, `@spaces.GPU` borrows a GPU for the duration of each call;
# outside Spaces the decorator is effectively a pass-through.
@spaces.GPU
def recognize_image(input_img, model_type):
    # Guard against "Recognize" being clicked before an image is uploaded.
    if input_img is None:
        return "", ""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
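    # Pick the checkpoint chosen in the UI and move it onto the active device.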
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)

    model.eval()

    # Flip the channel order of color inputs (RGB <-> BGR), as the original
    # pipeline does, before converting back to a PIL image for the processor.
    if len(input_img.shape) == 3:
        input_img = input_img[:, :, ::-1].copy()

    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)  # add batch dimension
    output = model.generate({"image": image})
    latex_code = output["pred_str"][0]
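
    # Render the LaTeX with the mathpix-markdown-it template and ship it as a
    # base64 data-URI iframe so Gradio can display it without a temp file.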
    html_code = latex2html(latex_code)
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe_src = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
    
def gradio_reset():
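    # Reset the image input, predicted LaTeX, and rendered HTML outputs.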
    return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
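    # Each loader also returns a vis_processor; the three configs are assumed
    # to share the same eval-time processor, so only the last assignment is kept.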
    print("load tiny model ...")
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    print("load small model ...")
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    print("load base model ...")
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models done. ==")
    # == load model ==
    
    with open("header.html", "r") as file:
        header = file.read()
    with gr.Blocks() as demo:
        gr.HTML(header)
        
        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
                    
                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, fname)
                                  for fname in os.listdir(example_root)
                                  if fname.endswith(".png")],
                        inputs=input_img,
                    )
            with gr.Column():
                gr.Button(value="Prediction Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predicted LaTeX', interactive=False)
                output_html = gr.HTML(label="Rendered HTML", show_label=True)
    
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])
    
    # Bind to 0.0.0.0:7860 so the app is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)