wufan committed on
Commit
8cf32c2
·
verified ·
1 Parent(s): 8a8d449

add latex render

Browse files
Files changed (1) hide show
  1. app.py +163 -105
app.py CHANGED
@@ -1,106 +1,164 @@
1
- import os
2
- os.system('pip install -U transformers==4.44.2')
3
- import sys
4
- import shutil
5
- import torch
6
- import argparse
7
- import gradio as gr
8
- import numpy as np
9
- from PIL import Image
10
- from huggingface_hub import snapshot_download
11
- import spaces
12
-
13
- # == download weights ==
14
- tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
15
- small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
16
- base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
17
- os.system("ls -l models/unimernet_tiny")
18
- # os.system(f"sed -i 's/MODEL_DIR/{tiny_model_dir}/g' cfg_tiny.yaml")
19
- # os.system(f"sed -i 's/MODEL_DIR/{small_model_dir}/g' cfg_small.yaml")
20
- # os.system(f"sed -i 's/MODEL_DIR/{base_model_dir}/g' cfg_base.yaml")
21
- # root_path = os.path.abspath(os.getcwd())
22
- # os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
23
- # shutil.move(tiny_model_dir, os.path.join(root_path, "models", "unimernet_tiny"))
24
- # shutil.move(small_model_dir, os.path.join(root_path, "models", "unimernet_small"))
25
- # shutil.move(base_model_dir, os.path.join(root_path, "models", "unimernet_base"))
26
- # == download weights ==
27
-
28
- sys.path.insert(0, os.path.join(os.getcwd(), ".."))
29
- from unimernet.common.config import Config
30
- import unimernet.tasks as tasks
31
- from unimernet.processors import load_processor
32
-
33
-
34
- def load_model_and_processor(cfg_path):
35
- args = argparse.Namespace(cfg_path=cfg_path, options=None)
36
- cfg = Config(args)
37
- task = tasks.setup_task(cfg)
38
- model = task.build_model(cfg)
39
- vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
40
- return model, vis_processor
41
-
42
- @spaces.GPU
43
- def recognize_image(input_img, model_type):
44
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
- if model_type == "base":
46
- model = model_base.to(device)
47
- elif model_type == "small":
48
- model = model_small.to(device)
49
- else:
50
- model = model_tiny.to(device)
51
-
52
- if len(input_img.shape) == 3:
53
- input_img = input_img[:, :, ::-1].copy()
54
-
55
- img = Image.fromarray(input_img)
56
- image = vis_processor(img).unsqueeze(0).to(device)
57
- output = model.generate({"image": image})
58
- latex_code = output["pred_str"][0]
59
- return latex_code
60
-
61
- def gradio_reset():
62
- return gr.update(value=None), gr.update(value=None)
63
-
64
-
65
- if __name__ == "__main__":
66
- root_path = os.path.abspath(os.getcwd())
67
- # == load model ==
68
- model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
69
- model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
70
- model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
71
- print("== load all models ==")
72
- # == load model ==
73
-
74
- with open("header.html", "r") as file:
75
- header = file.read()
76
- with gr.Blocks() as demo:
77
- gr.HTML(header)
78
-
79
- with gr.Row():
80
- with gr.Column():
81
- model_type = gr.Radio(
82
- choices=["tiny", "small", "base"],
83
- value="tiny",
84
- label="Model Type",
85
- interactive=True,
86
- )
87
- input_img = gr.Image(label=" ", interactive=True)
88
- with gr.Row():
89
- clear = gr.Button("Clear")
90
- predict = gr.Button(value="Recognize", interactive=True, variant="primary")
91
-
92
- with gr.Accordion("Examples:"):
93
- example_root = os.path.join(os.path.dirname(__file__), "examples")
94
- gr.Examples(
95
- examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
96
- _.endswith("png")],
97
- inputs=input_img,
98
- )
99
- with gr.Column():
100
- gr.Button(value="Predict Latex:", interactive=False)
101
- pred_latex = gr.Textbox(label='Latex', interactive=False)
102
-
103
- clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
104
- predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])
105
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
 
1
+ import os
2
+ os.system('pip install -U transformers==4.44.2')
3
+ import sys
4
+ import shutil
5
+ import torch
6
+ import base64
7
+ import argparse
8
+ import gradio as gr
9
+ import numpy as np
10
+ from PIL import Image
11
+ from huggingface_hub import snapshot_download
12
+ import spaces
13
+
14
+ # == download weights ==
15
+ tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
16
+ small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
17
+ base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
18
+ # == download weights ==
19
+
20
+ sys.path.insert(0, os.path.join(os.getcwd(), ".."))
21
+ from unimernet.common.config import Config
22
+ import unimernet.tasks as tasks
23
+ from unimernet.processors import load_processor
24
+
25
+
26
# HTML scaffold for the result iframe: it loads mathpix-markdown-it from a
# CDN and renders the JS string bound at `const text =` (filled in by
# latex2html below) into the #content div.
template_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
<meta charset="UTF-8">
<title>Title</title>
<script>
const text =
</script>
<style>
#content {
max-width: 800px;
margin: auto;
}
</style>
<script>
let script = document.createElement('script');
script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
document.head.append(script);

script.onload = function() {
const isLoaded = window.loadMathJax();
if (isLoaded) {
console.log('Styles loaded!')
}

const el = window.document.getElementById('content-text');
if (el) {
const options = {
htmlTags: true
};
const html = window.render(text, options);
el.outerHTML = html;
}
};
</script>
</head>
<body>
<div id="content"><div id="content-text"></div></div>
</body>
</html>
"""

def latex2html(latex_code):
    """Embed *latex_code* into the mathpix-markdown-it HTML template.

    Parameters
    ----------
    latex_code : str
        Raw LaTeX predicted by the model; may span multiple lines.

    Returns
    -------
    str
        A complete standalone HTML document that renders the formula.
    """
    # Unbalanced \left/\right pairs make MathJax fail to render the whole
    # formula; in that case strip the sizing commands and keep the bare
    # delimiters. Raw strings avoid invalid escape sequences like '\l'
    # (SyntaxWarning since Python 3.12). The escaped-brace forms
    # \left\{ / \right\} are the tokens LaTeX actually emits; the bare
    # \left{ / \right} forms are kept for robustness.
    if latex_code.count(r'\right') != latex_code.count(r'\left'):
        for sized, bare in (
            (r'\left(', '('), (r'\right)', ')'),
            (r'\left[', '['), (r'\right]', ']'),
            (r'\left\{', r'\{'), (r'\right\}', r'\}'),
            (r'\left{', '{'), (r'\right}', '}'),
            (r'\left|', '|'), (r'\right|', '|'),
            (r'\left.', '.'), (r'\right.', '.'),
        ):
            latex_code = latex_code.replace(sized, bare)

    # Double quotes would terminate the generated JS string literal, and
    # stray dollar signs would open/close math mode inside the renderer.
    latex_code = latex_code.replace('"', '``').replace('$', '')

    # Re-emit the LaTeX as concatenated JS string literals, one per source
    # line, doubling backslashes and appending an explicit "\n" to each.
    js_literals = [
        '"' + line.replace('\\', '\\\\') + r'\n' + '"'
        for line in latex_code.split('\n')
    ]
    js_text = '+\n'.join(js_literals)

    # Splice the generated literal chain into the template after
    # `const text =`.
    head, tail = template_html.split('const text =')
    return head + 'const text =' + js_text + tail
86
+
87
def load_model_and_processor(cfg_path):
    """Build a UniMERNet model plus its eval image processor from a config.

    Parameters
    ----------
    cfg_path : str
        Path to a YAML config file understood by unimernet's Config.

    Returns
    -------
    tuple
        (model, vis_processor) ready for formula recognition.
    """
    namespace = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(namespace)
    task = tasks.setup_task(cfg)
    built_model = task.build_model(cfg)
    eval_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    processor = load_processor('formula_image_eval', eval_cfg)
    return built_model, processor
94
+
95
@spaces.GPU
def recognize_image(input_img, model_type):
    """Recognize the formula in *input_img* with the selected checkpoint.

    Parameters
    ----------
    input_img : numpy.ndarray
        Image from the Gradio image widget.
    model_type : str
        One of "tiny", "small", "base"; unknown values fall back to tiny.

    Returns
    -------
    tuple
        (latex_code, iframe_html) for the textbox and HTML panes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Select the requested checkpoint; anything unrecognized uses tiny.
    chosen = {"base": model_base, "small": model_small}.get(model_type, model_tiny)
    model = chosen.to(device)

    # Reverse the channel order of color images (RGB <-> BGR);
    # presumably matches the training-time preprocessing — TODO confirm.
    if input_img.ndim == 3:
        input_img = input_img[:, :, ::-1].copy()

    pil_img = Image.fromarray(input_img)
    batch = vis_processor(pil_img).unsqueeze(0).to(device)
    latex_code = model.generate({"image": batch})["pred_str"][0]

    # Render the LaTeX in an isolated data-URI iframe so the CDN-loaded
    # MathJax scripts cannot clash with Gradio's own page scripts.
    html_code = latex2html(latex_code)
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe = f'<iframe src="data:text/html;base64,{encoded_html}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
117
+
118
def gradio_reset():
    """Clear the input image, the predicted-LaTeX textbox, and the HTML pane."""
    cleared = [gr.update(value=None) for _ in range(3)]
    return cleared[0], cleared[1], cleared[2]
120
+
121
+
122
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
    # All three checkpoints are loaded up-front so recognize_image can switch
    # between them without reloading.
    # NOTE(review): vis_processor is rebound on each call, so inference always
    # uses the processor built from cfg_base.yaml — confirm all three configs
    # share the same eval preprocessing.
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models ==")
    # == load model ==

    # Page header markup shipped alongside the app.
    with open("header.html", "r") as file:
        header = file.read()
    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            with gr.Column():
                # Checkpoint selector; defaults to the fastest model.
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                # Bundled sample formula images (PNG files only).
                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                  _.endswith("png")],
                        inputs=input_img,
                    )
            with gr.Column():
                # Non-interactive button doubles as a section label.
                gr.Button(value="Predict Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predict Latex', interactive=False)
                output_html = gr.HTML(label='Output Html')

        # Wire the buttons: Clear resets all three panes; Recognize feeds the
        # image and model choice into recognize_image and fills both outputs.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)