# RapidOCRDemo / app.py (author: SWHL)
# Last commit 8095db7: "fix(app): Remove the single visualize function" (6.88 kB)
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import time
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from rapidocr_onnxruntime import RapidOCR, VisRes
from streamlit_image_select import image_select
# Module-level visualizer; ``inference`` calls it to draw the detected boxes,
# recognized text and scores onto the input image.
vis = VisRes()
# Font file (looked up under ``fonts/``) used to render recognized text, keyed
# by language. English deliberately reuses the Chinese font file.
font_dict = dict(
    ch="chinese_cht.ttf",
    japan="japan.ttc",
    korean="korean.ttf",
    en="chinese_cht.ttf",
)
def init_sidebar():
    """Build the sidebar: three OCR threshold sliders and an example picker.

    Each slider's value is stored in ``st.session_state["params"]`` under the
    slider's label. The picked example image is read with ``cv2.imread`` and
    stored in ``st.session_state["img"]``. This reads the module-level
    ``examples`` list of image paths, which the ``__main__`` section defines
    before calling this function.
    """
    st.session_state["params"] = {}
    st.sidebar.markdown(
        "### [🛠️ Parameter Settings](https://github.com/RapidAI/RapidOCR/wiki/config_parameter)"
    )

    # (label, min_value, max_value, default value, help text); all step by 0.1.
    # The help texts are user-facing and stay in Chinese.
    slider_specs = (
        # Probability that a detected box is text; lower it if text is missed.
        (
            "box_thresh",
            0.0,
            1.0,
            0.5,
            "检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        ),
        # Size of the detection boxes; raise it if boxes cut characters off.
        (
            "unclip_ratio",
            1.5,
            2.0,
            1.6,
            "控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
        ),
        # Confidence that a recognition result is correct; lower it if text is missed.
        (
            "text_score",
            0.0,
            1.0,
            0.5,
            "文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        ),
    )
    for label, low, high, default, help_text in slider_specs:
        st.session_state["params"][label] = st.sidebar.slider(
            label,
            min_value=low,
            max_value=high,
            value=default,
            step=0.1,
            help=help_text,
        )

    with st.sidebar.container():
        chosen_path = image_select(
            label="Examples(click to select):",
            images=examples,
            key="equation_default",
            use_container_width=True,
        )
    st.session_state["img"] = cv2.imread(chosen_path)
def _rec_image_shape(rec_model_name):
    """Return the ``[C, H, W]`` recognizer input shape for a rec model file name.

    Rec models whose name contains ``v2``, ``korean`` or ``japan`` use a
    32-pixel input height in this demo; every other rec model uses 48.
    """
    if any(tag in rec_model_name for tag in ("v2", "korean", "japan")):
        return [3, 32, 320]
    return [3, 48, 320]


def _font_language(rec_model_name):
    """Return the ``font_dict`` key used to render this rec model's output.

    The language is taken from the file-name prefix (``ch_``, ``en_``,
    ``japan_``, ``korean_``). Everything except Japanese and Korean maps to
    ``"ch"``, which uses the same font file that ``font_dict`` assigns to English.
    """
    prefix = rec_model_name.split("_", 1)[0]
    if prefix in ("japan", "korean"):
        return prefix
    return "ch"


def inference(
    text_det=None,
    text_rec=None,
):
    """Run RapidOCR on ``st.session_state["img"]`` with the selected models.

    Thresholds are read from ``st.session_state["params"]``, which
    ``init_sidebar`` fills.

    Args:
        text_det: File name of the detection model inside ``models/text_det``.
        text_rec: File name of the recognition model inside ``models/text_rec``.

    Returns:
        ``(vis_img, out_df, elapse)``:
        - ``vis_img`` is the annotated image with its channel axis reversed for
          ``st.image``.
        - ``out_df`` is a DataFrame with ``Rec`` and ``Score`` columns.
        - ``elapse`` is a Markdown list of the det/cls/rec timings.
        Returns ``(None, None, None)`` when OCR yields no result.
    """
    img = st.session_state.get("img")
    params = st.session_state["params"]
    box_thresh = params.get("box_thresh")
    unclip_ratio = params.get("unclip_ratio")
    text_score = params.get("text_score")

    det_model_path = str(Path("models") / "text_det" / text_det)
    rec_model_path = str(Path("models") / "text_rec" / text_rec)
    # Classify the model by its file name only, so that directory names in the
    # path cannot accidentally match the substring tests.
    rec_model_name = Path(text_rec).name

    rapid_ocr = RapidOCR(
        det_model_path=det_model_path,
        rec_model_path=rec_model_path,
        rec_img_shape=_rec_image_shape(rec_model_name),
    )
    ocr_result, infer_elapse = rapid_ocr(
        img, box_thresh=box_thresh, unclip_ratio=unclip_ratio, text_score=text_score
    )
    if not ocr_result or not infer_elapse:
        return None, None, None

    det_cost, cls_cost, rec_cost = infer_elapse
    elapse = f"- `det cost`: {det_cost:.5f}\n - `cls cost`: {cls_cost:.5f}\n - `rec cost`: {rec_cost:.5f}"

    dt_boxes, rec_res, scores = list(zip(*ocr_result))
    font_path = Path("fonts") / font_dict[_font_language(rec_model_name)]
    vis_img = vis(img, dt_boxes, rec_res, scores, font_path=str(font_path))
    # Reverse the channel axis (presumably BGR -> RGB) before display.
    vis_img = vis_img[..., ::-1]

    out_df = pd.DataFrame(
        [[rec, score] for rec, score in zip(rec_res, scores)],
        columns=("Rec", "Score"),
    )
    return vis_img, out_df, elapse
def tips(txt: str, wait_time: int = 2, icon: str = "🎉"):
    """Pop up a Streamlit toast with ``txt``, then pause the script.

    Args:
        txt: Message shown in the toast.
        wait_time: Seconds to block after showing the toast.
        icon: Emoji displayed next to the message.
    """
    st.toast(body=txt, icon=icon)
    # Blocking sleep: keeps the script from rerunning before the toast is read.
    time.sleep(wait_time)
if __name__ == "__main__":
    # Page header and badges.
    st.markdown(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidOCR' style='text-decoration: none'>Rapid⚡OCR</a></h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        """
<p align="center">
<a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.13-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
<a href="https://pepy.tech/project/rapidocr_onnxruntime"><img src="https://static.pepy.tech/personalized-badge/rapidocr_onnxruntime?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Ort"></a>
<a href="https://pypi.org/project/rapidocr-onnxruntime/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr-onnxruntime"></a>
</p>
""",
        unsafe_allow_html=True,
    )

    # Example images offered in the sidebar; init_sidebar reads this name.
    examples = [
        "images/1.jpg",
        "images/ch_en_num.jpg",
        "images/air_ticket.jpg",
        "images/car_plate.jpeg",
        "images/train_ticket.jpeg",
        "images/japan_2.jpg",
        "images/korean_1.jpg",
    ]
    init_sidebar()

    menu_det, menu_rec = st.columns([1, 1])
    det_models = [
        "ch_PP-OCRv4_det_infer.onnx",
        "ch_PP-OCRv3_det_infer.onnx",
        "ch_PP-OCRv2_det_infer.onnx",
        "ch_ppocr_server_v2.0_det_infer.onnx",
        "ch_PP-OCRv4_det_server_infer.onnx",
    ]
    select_det = menu_det.selectbox("Det model:", det_models)
    rec_models = [
        "ch_PP-OCRv4_rec_infer.onnx",
        "ch_PP-OCRv3_rec_infer.onnx",
        "ch_PP-OCRv2_rec_infer.onnx",
        "ch_ppocr_server_v2.0_rec_infer.onnx",
        "en_PP-OCRv3_rec_infer.onnx",
        "en_number_mobile_v2.0_rec_infer.onnx",
        "korean_mobile_v2.0_rec_infer.onnx",
        "japan_rec_crnn_v2.onnx",
    ]
    select_rec = menu_rec.selectbox("Rec model:", rec_models)

    with st.form("my-form", clear_on_submit=True):
        img_file_buffer = st.file_uploader(
            "Upload an image",
            accept_multiple_files=False,
            label_visibility="visible",
            type=["png", "jpg", "jpeg", "bmp"],
        )
        submit = st.form_submit_button("Upload")
        if submit and img_file_buffer is not None:
            image = Image.open(img_file_buffer)
            # The example images come from cv2.imread, which gives 3-channel BGR.
            # A PIL image can be RGB, RGBA, L, P, CMYK, etc. Normalise uploads to
            # 3-channel BGR so OCR and the channel flip in inference() treat them
            # exactly like the examples.
            img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
            st.session_state["img"] = img

    if st.session_state["img"] is not None:
        out_img, out_df, elapse = inference(select_det, select_rec)
        if all(v is not None for v in (out_img, out_df, elapse)):
            st.markdown("#### Visualize:")
            st.image(out_img)
            st.markdown("### Rec Result:")
            st.markdown(elapse)
            st.dataframe(out_df, use_container_width=True)
        else:
            tips("识别结果为空", wait_time=5, icon="⚠️")