Spaces:
Running
Running
fix: add `no-det-res.jpg` to replace 'None' (which causes an error when caching examples)
49d27ca
# coding: utf-8 | |
# [Pix2Text](https://github.com/breezedeus/pix2text): an Open-Source Alternative to Mathpix. | |
# Copyright (C) 2022-2024, [Breezedeus](https://www.breezedeus.com). | |
import os | |
import json | |
import functools | |
import random | |
import shutil | |
import string | |
import tempfile | |
import time | |
import zipfile | |
from pathlib import Path | |
import yaml | |
import gradio as gr | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
# from cnstd.utils import pil_to_numpy, imsave | |
from pix2text import Pix2Text | |
from pix2text.utils import set_logger, merge_line_texts | |
logger = set_logger()

# Mapping loaded once at import time: its keys are offered as the language
# choices in the UI and its values are what gets passed to Pix2Text.
# The context manager closes the file handle deterministically instead of
# leaving it to the garbage collector.
with open('languages.yaml', 'r', encoding='utf-8') as _languages_file:
    LANGUAGES = yaml.safe_load(_languages_file)['languages']

# Directory that receives the detection images written for "mixed" requests.
OUTPUT_RESULT_DIR = Path('./output-results')
OUTPUT_RESULT_DIR.mkdir(exist_ok=True)
def prepare_mfd_model():
    """Return the path of the YoloV7 MFD weights, downloading them if absent.

    When the weights file is not on disk yet, the zip archive is fetched from
    the Hugging Face Hub (using the ``HF_TOKEN`` environment variable as the
    access token, if set) and unpacked into the current directory.

    Returns:
        The relative path of the weights file as a string. Its existence
        after extraction is not re-checked here.
    """
    weights_fp = './yolov7-model/mfd-yolov7-epoch224-20230613.pt'
    if not os.path.exists(weights_fp):
        archive_fp = hf_hub_download(
            repo_id='breezedeus/paid-models',
            subfolder='cnstd/1.2',
            filename='yolov7-model-20230613.zip',
            repo_type="model",
            cache_dir='./',
            token=os.environ.get('HF_TOKEN'),
        )
        with zipfile.ZipFile(archive_fp) as archive:
            archive.extractall('./')
    return weights_fp
def get_p2t_model(lan_list: list, mfd_model_name: str, mfr_model_name: str):
    """Create a ``Pix2Text`` instance for the chosen languages and models.

    Args:
        lan_list: Languages forwarded to ``Pix2Text`` as ``languages``.
        mfd_model_name: Selected MFD model label. Unless it contains
            ``'yolov7_tiny'``, the YoloV7 weights are prepared and used.
        mfr_model_name: Selected MFR model label. If it contains
            ``'mfr-pro'``, the ``mfr-pro`` model with the ONNX backend is used.

    Returns:
        The constructed ``Pix2Text`` object.
    """
    if 'yolov7_tiny' in mfd_model_name:
        # Empty config: Pix2Text falls back to its own analyzer defaults.
        analyzer_config = {}
    else:
        # Initialisation arguments for the layout analyzer: use the YoloV7
        # model (not YoloV7_Tiny) stored at the prepared weights path.
        analyzer_config = {
            'model_name': 'mfd',
            'model_type': 'yolov7',
            'model_fp': prepare_mfd_model(),
        }

    # Initialisation arguments for the formula recogniser.
    formula_config = (
        {'model_name': 'mfr-pro', 'model_backend': 'onnx'}
        if 'mfr-pro' in mfr_model_name
        else {}
    )

    return Pix2Text(
        languages=lan_list,
        analyzer_config=analyzer_config,
        formula_config=formula_config,
    )
def latex_render(latex_str):
    """Wrap *latex_str* in ``$$`` display-math delimiters, one per line."""
    return "$$\n{}\n$$".format(latex_str)
def _prune_output_dir(max_files=1000):
    """Remove the first ``max_files`` entries (by name) once the dir exceeds that count."""
    file_names = os.listdir(OUTPUT_RESULT_DIR)
    if len(file_names) <= max_files:
        return
    # Result files are named ``out-det-<timestamp>-<suffix>.jpg``, so sorting
    # by name puts the oldest results first.
    for name in sorted(file_names)[:max_files]:
        try:
            os.remove(OUTPUT_RESULT_DIR / name)
        except FileNotFoundError:
            # Already gone (e.g. removed by another request) - the goal of
            # the cleanup is met either way.
            pass


def recognize(
    lang_list, mfd_model_name, mfr_model_name, rec_type, resized_shape, image_file
):
    """Run Pix2Text on an image and return the text plus a result image path.

    Args:
        lang_list: Language names selected in the UI; each must be a key of
            ``LANGUAGES``.
        mfd_model_name: Selected MFD model label.
        mfr_model_name: Selected MFR model label.
        rec_type: One of ``'mixed'``, ``'formula'`` or ``'text'``.
        resized_shape: Forwarded to ``Pix2Text.recognize`` (``'mixed'`` only).
        image_file: The image to recognise.

    Returns:
        A ``(text, image_path)`` tuple. For ``'mixed'`` the path points to the
        analysis image saved under ``OUTPUT_RESULT_DIR``; for ``'formula'``
        and ``'text'`` it is the fixed ``'./docs/no-det-res.jpg'`` image.
        Formula output is wrapped with ``latex_render``.

    Raises:
        gr.Error: If no image was supplied or ``rec_type`` is not supported.
        KeyError: If a name in ``lang_list`` is not a key of ``LANGUAGES``.
    """
    # Validate before the (expensive) model construction so that the user gets
    # a readable message instead of a failure deep inside the model, or a
    # silent ``None`` that cannot be unpacked into the two output components.
    if image_file is None:
        raise gr.Error('Please upload an image first.')
    if rec_type not in ('mixed', 'formula', 'text'):
        raise gr.Error(f'Unsupported image type: {rec_type!r}')

    lang_list = [LANGUAGES[name] for name in lang_list]
    p2t = get_p2t_model(lang_list, mfd_model_name, mfr_model_name)

    if rec_type == 'mixed':
        # Six distinct random letters keep file names unique for requests
        # that arrive at the same timestamp.
        suffix = ''.join(random.sample(string.ascii_letters, 6))
        out_det_fp = OUTPUT_RESULT_DIR / f'out-det-{time.time()}-{suffix}.jpg'

        # Keep the output directory from growing without bound.
        _prune_output_dir()

        outs = p2t.recognize(
            image_file,
            resized_shape=resized_shape,
            save_analysis_res=out_det_fp,
        )
        # Reduce the structured output to its text content only.
        only_text = merge_line_texts(outs, auto_line_break=True)
        return only_text, str(out_det_fp)

    if rec_type == 'formula':
        only_text = p2t.recognize_formula(image_file)
        return latex_render(only_text), './docs/no-det-res.jpg'

    # rec_type == 'text'
    only_text = p2t.recognize_text(image_file)
    return only_text, './docs/no-det-res.jpg'
def example_func(lang_list, rec_type, image_file):
    """Run ``recognize`` for a bundled example with fixed model settings.

    Only the languages, the image type and the image vary per example; the
    models and ``resized_shape`` are pinned to the values below.
    """
    fixed_settings = {
        'mfd_model_name': 'yolov7 (paid)',
        'mfr_model_name': 'mfr-pro (paid)',
        'resized_shape': 768,
    }
    return recognize(
        lang_list, rec_type=rec_type, image_file=image_file, **fixed_settings
    )
def main():
    """Build the Gradio Blocks UI for the demo and launch it."""
    # NOTE(review): the nesting of the `with` blocks below was inferred from
    # the statement order of a whitespace-collapsed view - confirm the layout
    # against the deployed file.

    # Language names offered in the dropdown, sorted case-insensitively.
    langs = list(LANGUAGES.keys())
    langs.sort(key=lambda x: x.lower())
    title = 'Demo'
    # Each example row is [languages, image type, image path], matching the
    # `inputs` given to `gr.Examples` further down.
    examples = [
        [
            ['English'],
            'mixed',
            'docs/examples/mixed-en.jpg',
        ],
        [
            ['English', 'Chinese Simplified'],
            'mixed',
            'docs/examples/mixed-ch_sim.jpg',
        ],
        [
            ['English', 'Chinese Traditional'],
            'mixed',
            'docs/examples/mixed-ch_tra.jpg',
        ],
        [
            ['English', 'Vietnamese'],
            'mixed',
            'docs/examples/mixed-vietnamese.jpg',
        ],
        [
            ['English'],
            'formula',
            'docs/examples/formula1.png'
        ],
        [
            ['English'],
            'formula',
            'docs/examples/formula2.jpg'
        ],
        [
            ['English'],
            'formula',
            'docs/examples/hw-formula.png'
        ],
        [
            ['English', 'Chinese Simplified'],
            'text',
            'docs/examples/pure-text.jpg',
        ],
    ]
    # Markdown/HTML shown in the right-hand column (logo, badges, link table).
    # NOTE(review): the layout inside this literal (blank lines, cell padding,
    # the empty header row) was reconstructed from a whitespace-collapsed view
    # - confirm against the deployed file before merging.
    table_desc = """
<div align="center">
<img src="https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F9341931a-53f0-48e1-b026-0f1ad17b457c%2Fd0e55da8-36a5-482c-bea6-c389e2fcacea%2FUntitled.png?table=block&id=caebb37a-e23f-49ab-9687-2cba3801992e" width="120px"/>
[![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbreezedeus%2Fpix2text-demo&labelColor=%23697689&countColor=%23f5c791&style=flat&labelStyle=upper)](https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbreezedeus%2FCnOCR-Demo)
[![Discord](https://img.shields.io/discord/1200765964434821260?logo=discord&label=Discord)](https://discord.gg/H9FmDSMA)
| | |
| ------------------------------- | --------------------------------------- |
| 🏄 **Online Service** | [p2t.breezedeus.com](https://p2t.breezedeus.com) |
| 💬 **Discord** | [Pix2Text @ Discord](https://discord.gg/tGuFEybd) |
| 📀 **Code** | [Github](https://github.com/breezedeus/pix2text) |
| 🤗 **MFR Model** | [breezedeus/pix2text-mfr](https://huggingface.co/breezedeus/pix2text-mfr) |
| 📄 **More Infos** | [breezedeus.com/pix2text](https://www.breezedeus.com/pix2text) |
| 👨🏻💻 **Author** | [Breezedeus](https://www.breezedeus.com) |
If useful, please help to **star 🌟 [Pix2Text](https://github.com/breezedeus/pix2text)** 🙏
</div>
"""
    with gr.Blocks() as demo:
        # Page heading linking to the project repository.
        gr.HTML(
            f'<h1 style="text-align: center; margin-bottom: 1rem;"><a href="https://github.com/breezedeus/pix2text" target="_blank">Pix2Text V1.0</a> {title}</h1>'
        )
        # Top row: settings | image upload + submit | project links.
        with gr.Row(equal_height=False):
            with gr.Column(min_width=200, variant='panel', scale=3):
                gr.Markdown('### Settings')
                lang_list = gr.Dropdown(
                    label='Text Languages',
                    choices=langs,
                    value=['English', 'Chinese Simplified'],
                    multiselect=True,
                    # info='Which languages to be recognized as Texts.',
                )
                mfd_model_name = gr.Dropdown(
                    label='MFD Models',
                    choices=['yolov7_tiny (free)', 'yolov7 (paid)'],
                    value='yolov7 (paid)',
                )
                mfr_model_name = gr.Dropdown(
                    label='MFR Models',
                    choices=['mfr (free)', 'mfr-pro (paid)'],
                    value='mfr-pro (paid)',
                )
                rec_type = gr.Dropdown(
                    label='Image Type',
                    choices=['mixed', 'formula', 'text'],
                    value='mixed',
                    # info='Which type of image to be recognized.',
                )
                with gr.Accordion('More Options', open=False):
                    resized_shape = gr.Slider(
                        label='resized_shape',
                        minimum=512,
                        maximum=2048,
                        value=768,
                        step=32,
                    )
            with gr.Column(scale=6, variant='compact'):
                gr.Markdown('### Upload Image to be Recognized')
                image_file = gr.Image(
                    label='Image', type="pil", image_mode='RGB', show_label=False
                )
                sub_btn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=2, variant='compact'):
                gr.Markdown(table_desc)
        # Bottom row: detection image on the left, recognised text on the right.
        with gr.Row(equal_height=False):
            with gr.Column(scale=1, variant='compact'):
                gr.Markdown('**Detection Result**')
                det_result = gr.Image(
                    label='Detection Result', scale=1, show_label=False
                )
            with gr.Column(scale=1, variant='compact'):
                gr.Markdown(
                    '**Recognition Results (Paste them into the [P2T Online Service](https://p2t.breezedeus.com) to view rendered outcomes)**'
                )
                rec_result = gr.Textbox(
                    label=f'Recognition Result ',
                    lines=5,
                    value='',
                    scale=1,
                    show_label=False,
                    show_copy_button=True,
                )
                # render_result = gr.Markdown(label=f'After Rendering', value='')
                # rec_result.change(latex_render, rec_result, render_result)
        # The Submit button runs `recognize` with every setting; the returned
        # (text, image path) pair fills the two result components.
        sub_btn.click(
            recognize,
            inputs=[
                lang_list,
                mfd_model_name,
                mfr_model_name,
                rec_type,
                resized_shape,
                image_file,
            ],
            outputs=[rec_result, det_result],
        )
        # Bundled examples go through `example_func`, which pins the model
        # choices; caching is enabled only when CACHE_EXAMPLES is set to '1'.
        gr.Examples(
            label='Examples',
            examples=examples,
            inputs=[
                lang_list,
                rec_type,
                image_file,
            ],
            outputs=[rec_result, det_result],
            fn=example_func,
            cache_examples=os.getenv('CACHE_EXAMPLES') == '1',
        )
    # Enable the request queue with max_size=10, then start the server.
    demo.queue(max_size=10)
    demo.launch()
if __name__ == '__main__': | |
main() | |