import os

# Pin the transformers version at startup (a common Hugging Face Spaces workaround).
os.system('pip install -U transformers==4.44.2')

import sys
import shutil
import torch
import base64
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# == download weights ==
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
os.system("ls -l models/unimernet_tiny")
os.system("ls -l models/unimernet_small")
os.system("ls -l models/unimernet_base")
# == download weights ==

sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor

# NOTE: the body of this HTML template was lost in the source; the markup below
# is a minimal MathJax-based reconstruction. All it must preserve is the
# "const text =" marker that latex2html() splits on.
template_html = """<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Title</title>
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</head>
<body>
<div id="content"></div>
<script>
const text =
;
document.getElementById("content").innerText = text;
</script>
</body>
</html>
""" def latex2html(latex_code): latex_code = '\\[' + latex_code + '\\]' latex_code = latex_code.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.') latex_code = latex_code.replace('"', '``').replace('$', '') latex_code_list = latex_code.split('\n') gt= '' for out in latex_code_list: gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n' gt = gt[:-2] lines = template_html.split("const text =") new_web = lines[0] + 'const text =' + gt + lines[1] return new_web def load_model_and_processor(cfg_path): args = argparse.Namespace(cfg_path=cfg_path, options=None) cfg = Config(args) task = tasks.setup_task(cfg) model = task.build_model(cfg) vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval) return model, vis_processor @spaces.GPU def recognize_image(input_img, model_type): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if model_type == "base": model = model_base.to(device) elif model_type == "small": model = model_small.to(device) else: model = model_tiny.to(device) model.eval() if len(input_img.shape) == 3: input_img = input_img[:, :, ::-1].copy() img = Image.fromarray(input_img) image = vis_processor(img).unsqueeze(0).to(device) output = model.generate({"image": image}) latex_code = output["pred_str"][0] html_code = latex2html(latex_code) encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8') iframe_src = f"data:text/html;base64,{encoded_html}" iframe = f'' return latex_code, iframe def gradio_reset(): return gr.update(value=None), gr.update(value=None), gr.update(value=None) if __name__ == "__main__": root_path = os.path.abspath(os.getcwd()) # == load model == print("load tiny model ...") model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml")) print("load small model ...") model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml")) print("load base model ...") model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml")) print("== load all models done. ==") # == load model == with open("header.html", "r") as file: header = file.read() with gr.Blocks() as demo: gr.HTML(header) with gr.Row(): with gr.Column(): model_type = gr.Radio( choices=["tiny", "small", "base"], value="tiny", label="Model Type", interactive=True, ) input_img = gr.Image(label=" ", interactive=True) with gr.Row(): clear = gr.Button("Clear") predict = gr.Button(value="Recognize", interactive=True, variant="primary") with gr.Accordion("Examples:"): example_root = os.path.join(os.path.dirname(__file__), "examples") gr.Examples( examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if _.endswith("png")], inputs=input_img, ) with gr.Column(): gr.Button(value="Predict Result:", interactive=False) pred_latex = gr.Textbox(label='Predict Latex', interactive=False) output_html = gr.HTML(label="Rendered html", show_label=True) clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html]) predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html]) demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)