Spaces:
Running
Running
# -*- encoding: utf-8 -*- | |
# @Author: SWHL | |
# @Contact: liekkaskono@163.com | |
import time | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from PIL import Image | |
from rapidocr_onnxruntime import RapidOCR, VisRes | |
from streamlit_image_select import image_select | |
# Visualizer shared by every inference call; draws the detected boxes and the
# recognized text onto a copy of the input image.
vis = VisRes()

# Language tag -> font file name (joined onto ``fonts/`` in ``inference``) used
# when rendering recognized text. English reuses the Chinese font file.
font_dict = dict(
    ch="chinese_cht.ttf",
    japan="japan.ttc",
    korean="korean.ttf",
    en="chinese_cht.ttf",
)
def init_sidebar():
    """Build the sidebar: the OCR threshold sliders and the example-image picker.

    Side effects:
        * ``st.session_state["params"]`` is reset to a new dict holding the
          current ``box_thresh``, ``unclip_ratio`` and ``text_score`` slider
          values.
        * ``st.session_state["img"]`` is set to the selected example image as
          returned by ``cv2.imread`` (``None`` if the file cannot be read).

    Reads the global ``examples`` list of image paths, which this script
    defines in its ``__main__`` section before calling this function.
    """
    sidebar = st.sidebar
    st.session_state["params"] = {}
    sidebar.markdown(
        "### [🛠️ Parameter Settings](https://github.com/RapidAI/RapidOCR/wiki/config_parameter)"
    )

    # (parameter name, min, max, default, help text); every slider steps by 0.1.
    # Help texts (Chinese) in short:
    #   box_thresh   - probability a detected box is text; lower it if text is missed.
    #   unclip_ratio - enlarges detection boxes; raise it if boxes cut off characters.
    #   text_score   - recognition confidence cut-off; lower it if results are missing.
    slider_specs = (
        (
            "box_thresh",
            0.0,
            1.0,
            0.5,
            "检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        ),
        (
            "unclip_ratio",
            1.5,
            2.0,
            1.6,
            "控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
        ),
        (
            "text_score",
            0.0,
            1.0,
            0.5,
            "文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        ),
    )
    for param_name, low, high, default, help_text in slider_specs:
        st.session_state["params"][param_name] = sidebar.slider(
            param_name,
            min_value=low,
            max_value=high,
            value=default,
            step=0.1,
            help=help_text,
        )

    with sidebar.container():
        chosen_path = image_select(
            label="Examples(click to select):",
            images=examples,
            key="equation_default",
            use_container_width=True,
        )
    st.session_state["img"] = cv2.imread(chosen_path)
def _rec_image_shape(rec_model_path):
    """Return the recognizer input shape to configure for ``rec_model_path``.

    Paths mentioning "v2", "korean" or "japan" use a 32-pixel input height;
    every other recognition model uses 48.
    """
    if any(tag in rec_model_path for tag in ("v2", "korean", "japan")):
        return [3, 32, 320]
    return [3, 48, 320]


def _lang_name(rec_model_path):
    """Return the ``font_dict`` key used to render text for ``rec_model_path``.

    Chinese and English models (and anything unrecognized) use "ch".
    """
    # Order matters: the "ch"/"en" substring test wins over the others.
    if "ch" in rec_model_path or "en" in rec_model_path:
        return "ch"
    if "japan" in rec_model_path:
        return "japan"
    if "korean" in rec_model_path:
        return "korean"
    return "ch"


def _get_ocr_engine(det_model_path, rec_model_path, rec_image_shape):
    """Return a RapidOCR engine for the given models, reusing this session's last one.

    Constructing RapidOCR loads the ONNX models, and Streamlit re-executes the
    whole script on every widget interaction. Rebuilding the engine each time
    would therefore reload the models on every slider move. The most recently
    built engine is kept in ``st.session_state`` and rebuilt only when the model
    selection changes.

    The cache is per-session rather than ``st.cache_resource`` (shared by all
    sessions) because the per-call thresholds may be stored on the engine
    instance. NOTE(review): confirm against the rapidocr_onnxruntime version
    in use before sharing engines across sessions.
    """
    key = (det_model_path, rec_model_path, tuple(rec_image_shape))
    cached = st.session_state.get("_rapid_ocr_engine")
    if cached is not None and cached[0] == key:
        return cached[1]
    engine = RapidOCR(
        det_model_path=det_model_path,
        rec_model_path=rec_model_path,
        rec_img_shape=rec_image_shape,
    )
    st.session_state["_rapid_ocr_engine"] = (key, engine)
    return engine


def inference(
    text_det=None,
    text_rec=None,
):
    """Run OCR on the current session image and build the displayable outputs.

    The image is read from ``st.session_state["img"]``. The thresholds are read
    from ``st.session_state["params"]``, which ``init_sidebar`` populates.

    Args:
        text_det: File name of the detection model under ``models/text_det``.
        text_rec: File name of the recognition model under ``models/text_rec``.
            Despite the ``None`` defaults, both must be given because they are
            joined onto a ``Path``.

    Returns:
        ``(vis_img, out_df, elapse)``:
            * the visualization with its channel order reversed for display
              (BGR -> RGB, assuming a BGR input image);
            * a DataFrame with "Rec"/"Score" columns;
            * a Markdown list of per-stage timings.
        Returns ``(None, None, None)`` when OCR yields no result or no timings.
    """
    img = st.session_state.get("img")
    params = st.session_state["params"]
    box_thresh = params.get("box_thresh")
    unclip_ratio = params.get("unclip_ratio")
    text_score = params.get("text_score")

    det_model_path = str(Path("models") / "text_det" / text_det)
    rec_model_path = str(Path("models") / "text_rec" / text_rec)

    rapid_ocr = _get_ocr_engine(
        det_model_path, rec_model_path, _rec_image_shape(rec_model_path)
    )
    lan_name = _lang_name(rec_model_path)

    # Thresholds are passed on every call, so a reused engine behaves exactly
    # like a freshly constructed one.
    ocr_result, infer_elapse = rapid_ocr(
        img, box_thresh=box_thresh, unclip_ratio=unclip_ratio, text_score=text_score
    )
    if not ocr_result or not infer_elapse:
        return None, None, None

    det_cost, cls_cost, rec_cost = infer_elapse
    elapse = f"- `det cost`: {det_cost:.5f}\n - `cls cost`: {cls_cost:.5f}\n - `rec cost`: {rec_cost:.5f}"

    # ocr_result is a sequence of (box, text, score) triples; transpose it.
    dt_boxes, rec_res, scores = list(zip(*ocr_result))
    font_path = Path("fonts") / font_dict.get(lan_name)
    vis_img = vis(img, dt_boxes, rec_res, scores, font_path=str(font_path))
    vis_img = vis_img[..., ::-1]

    out_df = pd.DataFrame(
        [[rec, score] for rec, score in zip(rec_res, scores)],
        columns=("Rec", "Score"),
    )
    return vis_img, out_df, elapse
def tips(txt: str, wait_time: int = 2, icon: str = "🎉"):
    """Pop up ``txt`` as a Streamlit toast, then block the script.

    Args:
        txt: Message shown in the toast.
        wait_time: Number of seconds to sleep after showing the toast.
        icon: Emoji displayed next to the message.
    """
    pause_seconds = wait_time
    st.toast(txt, icon=icon)
    time.sleep(pause_seconds)
if __name__ == "__main__":
    st.markdown(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidOCR' style='text-decoration: none'>Rapid⚡OCR</a></h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        """
<p align="center">
<a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.13-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
<a href="https://pepy.tech/project/rapidocr_onnxruntime"><img src="https://static.pepy.tech/personalized-badge/rapidocr_onnxruntime?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Ort"></a>
<a href="https://pypi.org/project/rapidocr-onnxruntime/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr-onnxruntime"></a>
</p>
""",
        unsafe_allow_html=True,
    )

    # Example images for the sidebar picker; init_sidebar() reads this global.
    examples = [
        "images/1.jpg",
        "images/ch_en_num.jpg",
        "images/air_ticket.jpg",
        "images/car_plate.jpeg",
        "images/train_ticket.jpeg",
        "images/japan_2.jpg",
        "images/korean_1.jpg",
    ]
    init_sidebar()

    menu_det, menu_rec = st.columns([1, 1])
    det_models = [
        "ch_PP-OCRv4_det_infer.onnx",
        "ch_PP-OCRv3_det_infer.onnx",
        "ch_PP-OCRv2_det_infer.onnx",
        "ch_ppocr_server_v2.0_det_infer.onnx",
        "ch_PP-OCRv4_det_server_infer.onnx",
    ]
    select_det = menu_det.selectbox("Det model:", det_models)
    rec_models = [
        "ch_PP-OCRv4_rec_infer.onnx",
        "ch_PP-OCRv3_rec_infer.onnx",
        "ch_PP-OCRv2_rec_infer.onnx",
        "ch_ppocr_server_v2.0_rec_infer.onnx",
        "en_PP-OCRv3_rec_infer.onnx",
        "en_number_mobile_v2.0_rec_infer.onnx",
        "korean_mobile_v2.0_rec_infer.onnx",
        "japan_rec_crnn_v2.onnx",
    ]
    select_rec = menu_rec.selectbox("Rec model:", rec_models)

    with st.form("my-form", clear_on_submit=True):
        img_file_buffer = st.file_uploader(
            "Upload an image",
            accept_multiple_files=False,
            label_visibility="visible",
            type=["png", "jpg", "jpeg", "bmp"],
        )
        submit = st.form_submit_button("Upload")
        if submit and img_file_buffer is not None:
            # PIL decodes to RGB, RGBA, grayscale or palette-index arrays
            # depending on the file. The rest of the app expects 3-channel BGR,
            # like the examples loaded with cv2.imread; inference() reverses the
            # channels for display. Normalize uploads to match.
            image = Image.open(img_file_buffer).convert("RGB")
            st.session_state["img"] = cv2.cvtColor(
                np.array(image), cv2.COLOR_RGB2BGR
            )

    if st.session_state.get("img") is not None:
        out_img, out_df, elapse = inference(select_det, select_rec)
        if all(v is not None for v in (out_img, out_df, elapse)):
            st.markdown("#### Visualize:")
            st.image(out_img)
            st.markdown("### Rec Result:")
            st.markdown(elapse)
            st.dataframe(out_df, use_container_width=True)
        else:
            # Toast text: "recognition result is empty".
            tips("识别结果为空", wait_time=5, icon="⚠️")