# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
"""Streamlit demo for RapidOCR.

The page lets the user pick a bundled example image (or upload one), choose a
text-detection and a text-recognition ONNX model, tune three thresholds in the
sidebar, and then shows the visualised OCR result, per-stage timings and a
table of recognised text with confidence scores.
"""
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from rapidocr_onnxruntime import RapidOCR, VisRes
from streamlit_image_select import image_select

vis = VisRes()

# Font used to draw recognised text, keyed by the language prefix of the
# recognition model file name (e.g. "japan_rec_crnn_v2.onnx" -> "japan").
font_dict = {
    "ch": "simfang.ttf",
    "japan": "japan.ttc",
    "korean": "korean.ttf",
    "en": "simfang.ttf",
}

# Example images offered by the sidebar picker. Defined at module level (not
# only under ``__main__``) so ``init_sidebar`` does not depend on a name that
# exists only when the file is run as a script.
EXAMPLES = (
    "images/1.jpg",
    "images/ch_en_num.jpg",
    "images/air_ticket.jpg",
    "images/car_plate.jpeg",
    "images/train_ticket.jpeg",
    "images/japan_2.jpg",
    "images/korean_1.jpg",
)

# Model file names; resolved under models/text_det and models/text_rec.
DET_MODELS = (
    "ch_PP-OCRv4_det_infer.onnx",
    "ch_PP-OCRv3_det_infer.onnx",
    "ch_PP-OCRv2_det_infer.onnx",
    "ch_ppocr_server_v2.0_det_infer.onnx",
    "ch_PP-OCRv4_det_server_infer.onnx",
)
REC_MODELS = (
    "ch_PP-OCRv4_rec_infer.onnx",
    "ch_PP-OCRv3_rec_infer.onnx",
    "ch_PP-OCRv2_rec_infer.onnx",
    "ch_ppocr_server_v2.0_rec_infer.onnx",
    "en_PP-OCRv3_rec_infer.onnx",
    "en_number_mobile_v2.0_rec_infer.onnx",
    "korean_mobile_v2.0_rec_infer.onnx",
    "japan_rec_crnn_v2.onnx",
)


def init_sidebar(examples=EXAMPLES):
    """Render the sidebar controls and load the selected example image.

    Writes the slider values to ``st.session_state["params"]`` and the image
    returned by ``cv2.imread`` (BGR order, or ``None`` if the file cannot be
    read) to ``st.session_state["img"]``.

    Args:
        examples: Image paths offered by the example picker.
    """
    st.sidebar.markdown(
        "### [🛠️ Parameter Settings](https://github.com/RapidAI/RapidOCR/wiki/config_parameter)"
    )
    box_thresh = st.sidebar.slider(
        "box_thresh",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
        help="检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
    )
    unclip_ratio = st.sidebar.slider(
        "unclip_ratio",
        min_value=1.5,
        max_value=2.0,
        value=1.6,
        step=0.1,
        help="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
    )
    text_score = st.sidebar.slider(
        "text_score",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
        help="文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
    )
    st.session_state["params"] = {
        "box_thresh": box_thresh,
        "unclip_ratio": unclip_ratio,
        "text_score": text_score,
    }

    with st.sidebar.container():
        img_path = image_select(
            label="Examples(click to select):",
            images=list(examples),
            key="equation_default",
            use_container_width=True,
        )
    # NOTE(review): this runs on every rerun, so the selected example replaces
    # st.session_state["img"] each time; since the upload form clears on
    # submit, an uploaded image is only used in the run where it is submitted.
    st.session_state["img"] = cv2.imread(img_path)


def _rec_img_shape(rec_model_name: str) -> list:
    """Return the recognition input shape passed to RapidOCR for a model file.

    v2-era, Korean and Japanese models get [3, 32, 320]; all others get
    [3, 48, 320].
    """
    if any(tag in rec_model_name for tag in ("v2", "korean", "japan")):
        return [3, 32, 320]
    return [3, 48, 320]


def _font_language(rec_model_name: str) -> str:
    """Return the ``font_dict`` key for a recognition model file name.

    Model files are named ``<lang>_...``; unknown prefixes fall back to "ch".
    """
    prefix = rec_model_name.split("_", 1)[0]
    return prefix if prefix in font_dict else "ch"


def inference(
    text_det=None,
    text_rec=None,
):
    """Run RapidOCR on ``st.session_state["img"]`` with the chosen models.

    Args:
        text_det: Detection model file name under ``models/text_det``.
        text_rec: Recognition model file name under ``models/text_rec``.

    Returns:
        ``(vis_img, out_df, elapse)`` -- the visualised image with its channel
        order reversed for ``st.image``, a DataFrame with "Rec"/"Score"
        columns, and a Markdown list of per-stage timings; or
        ``(None, None, None)`` when OCR produced no result.
    """
    img = st.session_state.get("img")
    params = st.session_state["params"]

    det_model_path = str(Path("models") / "text_det" / text_det)
    rec_model_path = str(Path("models") / "text_rec" / text_rec)
    # Decide by the file name only, so directory names cannot affect it.
    rec_name = Path(rec_model_path).name

    rapid_ocr = RapidOCR(
        det_model_path=det_model_path,
        rec_model_path=rec_model_path,
        rec_img_shape=_rec_img_shape(rec_name),
    )

    ocr_result, infer_elapse = rapid_ocr(
        img,
        box_thresh=params.get("box_thresh"),
        unclip_ratio=params.get("unclip_ratio"),
        text_score=params.get("text_score"),
    )
    if not ocr_result or not infer_elapse:
        return None, None, None

    det_cost, cls_cost, rec_cost = infer_elapse
    elapse = (
        f"- `det cost`: {det_cost:.5f}\n"
        f" - `cls cost`: {cls_cost:.5f}\n"
        f" - `rec cost`: {rec_cost:.5f}"
    )

    dt_boxes, rec_res, scores = zip(*ocr_result)
    font_path = Path("fonts") / font_dict[_font_language(rec_name)]
    vis_img = vis(img, dt_boxes, rec_res, scores, font_path=str(font_path))
    # The app keeps images in BGR order; st.image expects RGB by default.
    vis_img = vis_img[..., ::-1]

    out_df = pd.DataFrame(
        [[rec, score] for rec, score in zip(rec_res, scores)],
        columns=("Rec", "Score"),
    )
    return vis_img, out_df, elapse


def tips(txt: str, wait_time: int = 2, icon: str = "🎉"):
    """Show a toast message, then block the script for ``wait_time`` seconds."""
    st.toast(txt, icon=icon)
    time.sleep(wait_time)


def _load_uploaded_image(file_buffer):
    """Decode an uploaded file into a 3-channel BGR array, or None on failure.

    Example images come from ``cv2.imread`` (BGR) and ``inference`` reverses
    the channels for display, so uploads are normalised to BGR as well.
    Converting through PIL's "RGB" mode also flattens RGBA, greyscale and
    palette images, which ``np.array`` would otherwise return as 4-channel
    or 2-D arrays.
    """
    try:
        image = Image.open(file_buffer)
        rgb = np.asarray(image.convert("RGB"))
    except OSError:
        # PIL.UnidentifiedImageError (unknown format) subclasses OSError, as
        # do decode errors for truncated files.
        return None
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def main():
    """Render the page: header, model pickers, upload form and OCR output."""
    # NOTE(review): both header strings contain only visible text although
    # they are rendered with unsafe_allow_html=True; any HTML markup they
    # once carried may have been lost -- confirm against upstream.
    st.markdown("\n\nRapid⚡OCR\n\n", unsafe_allow_html=True)
    st.markdown("\n\nPyPI\n\n", unsafe_allow_html=True)

    init_sidebar()

    menu_det, menu_rec = st.columns([1, 1])
    select_det = menu_det.selectbox("Det model:", DET_MODELS)
    select_rec = menu_rec.selectbox("Rec model:", REC_MODELS)

    with st.form("my-form", clear_on_submit=True):
        img_file_buffer = st.file_uploader(
            "Upload an image",
            accept_multiple_files=False,
            label_visibility="visible",
            type=["png", "jpg", "jpeg", "bmp"],
        )
        submit = st.form_submit_button("Upload")

        if submit and img_file_buffer is not None:
            uploaded = _load_uploaded_image(img_file_buffer)
            if uploaded is None:
                tips("Failed to read the uploaded image", icon="⚠️")
            else:
                st.session_state["img"] = uploaded

    if st.session_state.get("img") is None:
        return

    out_img, out_df, elapse = inference(select_det, select_rec)
    if any(v is None for v in (out_img, out_df, elapse)):
        tips("识别结果为空", wait_time=5, icon="⚠️")
        return

    st.markdown("#### Visualize:")
    st.image(out_img)

    st.markdown("### Rec Result:")
    st.markdown(elapse)
    st.dataframe(out_df, use_container_width=True)


if __name__ == "__main__":
    main()