# UniMERNet-Demo / app.py
# Last commit 1320fad by wanderkid: "fix infer bug" (verified)
import os
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import base64
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces
# == download weights ==
# Fetch the tiny/small/base UniMERNet checkpoints from the Hugging Face Hub
# into ./models/; each *_model_dir holds the local folder path returned.
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
# Log each downloaded directory's contents as a quick sanity check.
for _listing_cmd in (
    "ls -l models/unimernet_tiny",
    "ls -l models/unimernet_small",
    "ls -l models/unimernet_base",
):
    os.system(_listing_cmd)
# == download weights ==
# Prepend the parent of the working directory to sys.path (presumably where
# the `unimernet` package imported below lives — confirm for the deployment).
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
# HTML page that latex2html() fills in and the app shows inside an <iframe>.
# latex2html() inserts a JavaScript string expression directly after the
# "const text =" marker below. On load, the page fetches mathpix-markdown-it
# from jsDelivr, calls window.loadMathJax(), renders `text` with
# window.render(), and replaces the #content-text placeholder with the result.
template_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
<meta charset="UTF-8">
<title>Title</title>
<script>
const text =
</script>
<style>
#content {
max-width: 800px;
margin: auto;
}
</style>
<script>
let script = document.createElement('script');
script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
document.head.append(script);
script.onload = function() {
const isLoaded = window.loadMathJax();
if (isLoaded) {
console.log('Styles loaded!')
}
const el = window.document.getElementById('content-text');
if (el) {
const options = {
htmlTags: true
};
const html = window.render(text, options);
el.outerHTML = html;
}
};
</script>
</head>
<body>
<div id="content"><div id="content-text"></div></div>
</body>
</html>
"""
def latex2html(latex_code, template=None):
    r"""Build a standalone HTML page that renders ``latex_code``.

    The formula is wrapped in display-math delimiters ``\[ ... \]``, its
    ``\left``/``\right`` sizing commands are reduced to the bare delimiters,
    and the result is embedded as a JavaScript string expression directly
    after the ``const text =`` marker in ``template``.

    Args:
        latex_code: LaTeX source, e.g. the model's ``pred_str`` output.
        template: HTML page containing a ``const text =`` marker. Defaults to
            the module-level ``template_html``, whose script renders ``text``
            with mathpix-markdown-it.

    Returns:
        The filled-in HTML document as a string.

    Raises:
        ValueError: if ``template`` contains no ``const text =`` marker.
    """
    if template is None:
        template = template_html

    latex_code = '\\[' + latex_code + '\\]'

    # Drop \left/\right sizing commands and keep only the delimiter.
    # Backslashes are escaped explicitly: '\l' is not a valid Python escape.
    sizing_replacements = (
        ('\\left(', '('), ('\\right)', ')'),
        ('\\left[', '['), ('\\right]', ']'),
        ('\\left{', '{'), ('\\right}', '}'),
        ('\\left|', '|'), ('\\right|', '|'),
        ('\\left.', '.'), ('\\right.', '.'),
    )
    for sized, bare in sizing_replacements:
        latex_code = latex_code.replace(sized, bare)

    # '"' would terminate the JS string literal early, so it becomes the LaTeX
    # quote ``. '$' is stripped (presumably to avoid stray inline-math
    # delimiters in the markdown renderer — confirm).
    latex_code = latex_code.replace('"', '``').replace('$', '')

    js_parts = []
    for line in latex_code.split('\n'):
        # Escape backslashes for the JS string literal first, then break up any
        # "</" so text such as "</script>" cannot close the enclosing <script>
        # element ("<\/" evaluates to "</" in JavaScript).
        escaped = line.replace('\\', '\\\\').replace('</', '<\\/')
        js_parts.append('"' + escaped + '\\n"')
    # One quoted literal per source line, concatenated with "+" across lines.
    js_text = '+\n'.join(js_parts)

    head, marker, tail = template.partition('const text =')
    if not marker:
        raise ValueError("template has no 'const text =' marker")
    return head + marker + js_text + tail
def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its evaluation image processor.

    Args:
        cfg_path: Path to a config file (the app passes cfg_*.yaml files).

    Returns:
        ``(model, vis_processor)``: the model built by the config's task, and
        the ``formula_image_eval`` processor created from the config's
        ``datasets.formula_rec_eval.vis_processor.eval`` section.
    """
    # Config expects an argparse-style namespace, as if parsed from a CLI.
    cfg = Config(argparse.Namespace(cfg_path=cfg_path, options=None))
    model = tasks.setup_task(cfg).build_model(cfg)
    eval_processor_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    vis_processor = load_processor('formula_image_eval', eval_processor_cfg)
    return model, vis_processor
@spaces.GPU
def recognize_image(input_img, model_type):
    """Recognize the formula in one image and render it as HTML.

    Args:
        input_img: Value of the ``gr.Image`` input. This is a numpy array
            (2-D, or 3-D with channels last), or ``None`` when no image has
            been uploaded or selected.
        model_type: Value of the "Model Type" radio. ``"base"`` and
            ``"small"`` select those models; anything else selects tiny.

    Returns:
        ``(latex_code, iframe)``: the predicted LaTeX string and an
        ``<iframe>`` snippet whose base64 data URI is the page produced by
        ``latex2html``.

    Raises:
        gr.Error: if no image was provided. Gradio shows this to the user
            instead of an AttributeError traceback.
    """
    if input_img is None:
        raise gr.Error("Please upload or select an image first.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model_base/model_small/model_tiny are module globals created in the
    # __main__ section; only the selected one is moved to `device`.
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)
    model.eval()

    if input_img.ndim == 3:
        # NOTE(review): this reverses the channel axis (e.g. RGB -> BGR), while
        # Image.fromarray interprets 3-channel arrays as RGB. Confirm that the
        # swap is intentional. .copy() makes the reversed view contiguous.
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        output = model.generate({"image": image})
    latex_code = output["pred_str"][0]

    html_code = latex2html(latex_code)
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe_src = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
def gradio_reset():
    """Reset the Clear button's three outputs to empty.

    Returns:
        A 3-tuple of ``gr.update(value=None)``, one for each output component
        (input image, LaTeX textbox, rendered HTML).
    """
    return tuple(gr.update(value=None) for _ in range(3))
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
    # These names are module globals that recognize_image() reads at call time.
    print("load tiny model ...")
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    print("load small model ...")
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    print("load base model ...")
    # NOTE(review): each call rebinds `vis_processor`, so the base config's
    # processor is the one used for all three models. Confirm that the three
    # configs share identical eval preprocessing.
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models done. ==")
    # == load model ==

    with open(os.path.join(root_path, "header.html"), "r", encoding="utf-8") as file:
        header = file.read()

    # Sort so the example gallery order does not depend on the filesystem's
    # directory-listing order.
    example_root = os.path.join(os.path.dirname(__file__), "examples")
    example_files = sorted(
        os.path.join(example_root, name)
        for name in os.listdir(example_root)
        if name.endswith("png")
    )

    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            # Left column: model choice, input image, actions, and examples.
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
                with gr.Accordion("Examples:"):
                    gr.Examples(
                        examples=example_files,
                        inputs=input_img,
                    )
            # Right column: predicted LaTeX and its rendered preview.
            with gr.Column():
                # A disabled button used as a section heading.
                gr.Button(value="Predict Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predict Latex', interactive=False)
                output_html = gr.HTML(label="Rendered html", show_label=True)
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)