# -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com import time from pathlib import Path import cv2 import numpy as np import pandas as pd import streamlit as st from PIL import Image from rapidocr_onnxruntime import RapidOCR, VisRes from streamlit_image_select import image_select vis = VisRes() font_dict = { "ch": "chinese_cht.ttf", "japan": "japan.ttc", "korean": "korean.ttf", "en": "chinese_cht.ttf", } def init_sidebar(): st.session_state["params"] = {} st.sidebar.markdown( "### [🛠️ Parameter Settings](https://github.com/RapidAI/RapidOCR/wiki/config_parameter)" ) box_thresh = st.sidebar.slider( "box_thresh", min_value=0.0, max_value=1.0, value=0.5, step=0.1, help="检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5", ) st.session_state["params"]["box_thresh"] = box_thresh unclip_ratio = st.sidebar.slider( "unclip_ratio", min_value=1.5, max_value=2.0, value=1.6, step=0.1, help="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6", ) st.session_state["params"]["unclip_ratio"] = unclip_ratio text_score = st.sidebar.slider( "text_score", min_value=0.0, max_value=1.0, value=0.5, step=0.1, help="文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5", ) st.session_state["params"]["text_score"] = text_score with st.sidebar.container(): img_path = image_select( label="Examples(click to select):", images=examples, key="equation_default", use_container_width=True, ) img = cv2.imread(img_path) st.session_state["img"] = img def inference( text_det=None, text_rec=None, ): img = st.session_state.get("img") box_thresh = st.session_state["params"].get("box_thresh") unclip_ratio = st.session_state["params"].get("unclip_ratio") text_score = st.session_state["params"].get("text_score") det_model_path = str(Path("models") / "text_det" / text_det) rec_model_path = str(Path("models") / "text_rec" / text_rec) if ( "v2" in rec_model_path or "korean" in rec_model_path or "japan" in rec_model_path ): rec_image_shape = [3, 32, 320] else: rec_image_shape = [3, 48, 320] rapid_ocr = RapidOCR( det_model_path=det_model_path, rec_model_path=rec_model_path, rec_img_shape=rec_image_shape, ) if "ch" in rec_model_path or "en" in rec_model_path: lan_name = "ch" elif "japan" in rec_model_path: lan_name = "japan" elif "korean" in rec_model_path: lan_name = "korean" else: lan_name = "ch" ocr_result, infer_elapse = rapid_ocr( img, box_thresh=box_thresh, unclip_ratio=unclip_ratio, text_score=text_score ) if not ocr_result or not infer_elapse: return None, None, None det_cost, cls_cost, rec_cost = infer_elapse elapse = f"- `det cost`: {det_cost:.5f}\n - `cls cost`: {cls_cost:.5f}\n - `rec cost`: {rec_cost:.5f}" dt_boxes, rec_res, scores = list(zip(*ocr_result)) font_path = Path("fonts") / font_dict.get(lan_name) vis_img = vis(img, dt_boxes, rec_res, scores, font_path=str(font_path)) vis_img = vis_img[..., ::-1] out_df = pd.DataFrame( [[rec, score] for rec, score in zip(rec_res, scores)], columns=("Rec", "Score"), ) return vis_img, out_df, elapse def tips(txt: str, wait_time: int = 2, icon: str = "🎉"): st.toast(txt, icon=icon) time.sleep(wait_time) if __name__ == "__main__": st.markdown( "