SWHL committed on
Commit 5d6a0bb
1 Parent(s): 9792e33

Update files

.gitignore ADDED
@@ -0,0 +1,3 @@
+ *.pyc
+
+ __pycache__/
app.py CHANGED
@@ -1,7 +1,115 @@
- def greet(name):
-     return "Hello " + name + "!!"
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ # -*- encoding: utf-8 -*-
+ import math
+ import random
+ from pathlib import Path
+ import time
+
+ import cv2
  import gradio as gr
+ from rapidocr_onnxruntime import TextSystem
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+
+ text_sys = TextSystem('config.yaml')
+
+
+ def draw_ocr_box_txt(image, boxes, txts, font_path,
+                      scores=None, text_score=0.5):
+     if not Path(font_path).exists():
+         raise FileNotFoundError(f'The {font_path} does not exist!\n'
+                                 f'Please download the file from https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing')
+
+     h, w = image.height, image.width
+     img_left = image.copy()
+     img_right = Image.new('RGB', (w, h), (255, 255, 255))
+
+     random.seed(0)
+     draw_left = ImageDraw.Draw(img_left)
+     draw_right = ImageDraw.Draw(img_right)
+     for idx, (box, txt) in enumerate(zip(boxes, txts)):
+         if scores is not None and scores[idx] < text_score:
+             continue
+
+         color = (random.randint(0, 255),
+                  random.randint(0, 255),
+                  random.randint(0, 255))
+         draw_left.polygon(box, fill=color)
+         draw_right.polygon([box[0][0], box[0][1],
+                             box[1][0], box[1][1],
+                             box[2][0], box[2][1],
+                             box[3][0], box[3][1]],
+                            outline=color)
+
+         box_height = math.sqrt((box[0][0] - box[3][0])**2
+                                + (box[0][1] - box[3][1])**2)
+
+         box_width = math.sqrt((box[0][0] - box[1][0])**2
+                               + (box[0][1] - box[1][1])**2)
+
+         if box_height > 2 * box_width:
+             font_size = max(int(box_width * 0.9), 10)
+             font = ImageFont.truetype(font_path, font_size,
+                                       encoding="utf-8")
+             cur_y = box[0][1]
+             for c in txt:
+                 char_size = font.getsize(c)
+                 draw_right.text((box[0][0] + 3, cur_y), c,
+                                 fill=(0, 0, 0), font=font)
+                 cur_y += char_size[1]
+         else:
+             font_size = max(int(box_height * 0.8), 10)
+             font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+             draw_right.text([box[0][0], box[0][1]], txt,
+                             fill=(0, 0, 0), font=font)
+
+     img_left = Image.blend(image, img_left, 0.5)
+     img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
+     img_show.paste(img_left, (0, 0, w, h))
+     img_show.paste(img_right, (w, 0, w * 2, h))
+     return np.array(img_show)
+
+
+ def visualize(image_path, boxes, rec_res, font_path="resources/fonts/FZYTK.TTF"):
+     image = Image.open(image_path)
+     txts = [rec_res[i][0] for i in range(len(rec_res))]
+     scores = [rec_res[i][1] for i in range(len(rec_res))]
+
+     draw_img = draw_ocr_box_txt(image, boxes,
+                                 txts, font_path,
+                                 scores,
+                                 text_score=0.5)
+
+     draw_img_save = Path("./inference_results/")
+     if not draw_img_save.exists():
+         draw_img_save.mkdir(parents=True, exist_ok=True)
+
+     time_stamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+     image_save = str(draw_img_save / f'{time_stamp}_{Path(image_path).name}')
+     cv2.imwrite(image_save, draw_img[:, :, ::-1])
+     return image_save
+
+
+ def inference(img):
+     img_path = img.name
+     img = cv2.imread(img_path)
+     dt_boxes, rec_res = text_sys(img)
+     img_save_path = visualize(img_path, dt_boxes, rec_res)
+     return img_save_path, rec_res
+
+
+ title = 'Rapid🗲OCR Demo (捷智OCR)'
+ description = 'Gradio demo for RapidOCR. Github Repo: https://github.com/RapidAI/RapidOCR'
+ article = "<p style='text-align: center'>A completely open-source, free OCR SDK that supports offline deployment on multiple platforms and in multiple languages. <a href='https://github.com/RapidAI/RapidOCR'>Github Repo</a></p>"
+ css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
+ gr.Interface(
+     inference,
+     inputs=gr.inputs.Image(type='file', label='Input'),
+     outputs=[
+         gr.outputs.Image(type='file', label='Output_image'),
+         gr.outputs.Textbox(type='text', label='Output_text')
+     ],
+     title=title,
+     description=description,
+     article=article,
+     css=css,
+     allow_flagging='never',
+ ).launch(debug=True, enable_queue=True)
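
The rewritten app.py feeds each uploaded image through `TextSystem` and renders the result with `visualize`. A minimal sketch of driving the same pipeline without the Gradio UI, assuming `config.yaml` and the `resources/` tree from this commit are in place (the test image path is the one used in rapid_ocr_api.py):

```python
# Minimal sketch: run the OCR pipeline directly, without Gradio.
import cv2

from rapidocr_onnxruntime import TextSystem

text_sys = TextSystem('config.yaml')
img = cv2.imread('resources/test_images/det_images/ch_en_num.jpg')
dt_boxes, rec_res = text_sys(img)  # results already filtered by Global.text_score
for box, (text, score) in zip(dt_boxes, rec_res):
    print(f'{text} ({score:.3f})')
```
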
config.yaml ADDED
@@ -0,0 +1,72 @@
+ Global:
+     text_score: 0.5
+     use_angle_cls: true
+     print_verbose: true
+     min_height: 30
+     width_height_ratio: 8
+
+ Det:
+     module_name: ch_ppocr_v3_det
+     class_name: TextDetector
+     model_path: resources/models/ch_PP-OCRv3_det_infer.onnx
+
+     use_cuda: false
+     # Details of the params: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
+     CUDAExecutionProvider:
+         device_id: 0
+         arena_extend_strategy: kNextPowerOfTwo
+         cudnn_conv_algo_search: EXHAUSTIVE
+         do_copy_in_default_stream: true
+
+     pre_process:
+         DetResizeForTest:
+             limit_side_len: 736
+             limit_type: min
+         NormalizeImage:
+             std: [0.229, 0.224, 0.225]
+             mean: [0.485, 0.456, 0.406]
+             scale: 1./255.
+             order: hwc
+         ToCHWImage:
+         KeepKeys:
+             keep_keys: ['image', 'shape']
+
+     post_process:
+         thresh: 0.3
+         box_thresh: 0.5
+         max_candidates: 1000
+         unclip_ratio: 1.6
+         use_dilation: true
+         score_mode: fast
+
+ Cls:
+     module_name: ch_ppocr_v2_cls
+     class_name: TextClassifier
+     model_path: resources/models/ch_ppocr_mobile_v2.0_cls_infer.onnx
+
+     use_cuda: false
+     CUDAExecutionProvider:
+         device_id: 0
+         arena_extend_strategy: kNextPowerOfTwo
+         cudnn_conv_algo_search: EXHAUSTIVE
+         do_copy_in_default_stream: true
+
+     cls_image_shape: [3, 48, 192]
+     cls_batch_num: 6
+     cls_thresh: 0.9
+     label_list: ['0', '180']
+
+ Rec:
+     module_name: ch_ppocr_v3_rec
+     class_name: TextRecognizer
+     model_path: resources/models/ch_PP-OCRv3_rec_infer.onnx
+
+     use_cuda: false
+     CUDAExecutionProvider:
+         device_id: 0
+         arena_extend_strategy: kNextPowerOfTwo
+         cudnn_conv_algo_search: EXHAUSTIVE
+         do_copy_in_default_stream: true
+
+     rec_img_shape: [3, 48, 320]
+     rec_batch_num: 6
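
The config splits into a `Global` block plus one block per pipeline stage (`Det`, `Cls`, `Rec`); `TextSystem` hands each sub-dict to the matching module. A sketch of inspecting it with PyYAML, the same loader the repo's `read_yaml` helpers use (printed values are the ones from this commit):

```python
# Sketch: the config is plain YAML; each top-level block configures one stage.
import yaml

with open('config.yaml', 'rb') as f:
    config = yaml.load(f, Loader=yaml.Loader)

print(config['Global']['text_score'])  # 0.5, used to filter recognition results
print(config['Det']['model_path'])     # resources/models/ch_PP-OCRv3_det_infer.onnx
print(config['Rec']['rec_img_shape'])  # [3, 48, 320]
```
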
rapidocr_onnxruntime/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ from .rapid_ocr_api import TextSystem
rapidocr_onnxruntime/ch_ppocr_v2_cls/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ from .text_cls import TextClassifier
rapidocr_onnxruntime/ch_ppocr_v2_cls/config.yaml ADDED
@@ -0,0 +1,14 @@
+ model_path: resources/models/ch_ppocr_mobile_v2.0_cls_infer.onnx
+
+ use_cuda: false
+ # Details of the params: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
+ CUDAExecutionProvider:
+     device_id: 0
+     arena_extend_strategy: kNextPowerOfTwo
+     cudnn_conv_algo_search: EXHAUSTIVE
+     do_copy_in_default_stream: true
+
+ cls_image_shape: [3, 48, 192]
+ cls_batch_num: 6
+ cls_thresh: 0.9
+ label_list: ['0', '180']
rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py ADDED
@@ -0,0 +1,117 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import argparse
+ import copy
+ import math
+ import time
+ from typing import List
+
+ import cv2
+ import numpy as np
+
+ try:
+     from .utils import ClsPostProcess, read_yaml, OrtInferSession
+ except ImportError:
+     from utils import ClsPostProcess, read_yaml, OrtInferSession
+
+
+ class TextClassifier(object):
+     def __init__(self, config):
+         self.cls_image_shape = config['cls_image_shape']
+         self.cls_batch_num = config['cls_batch_num']
+         self.cls_thresh = config['cls_thresh']
+         self.postprocess_op = ClsPostProcess(config['label_list'])
+
+         session_instance = OrtInferSession(config)
+         self.session = session_instance.session
+         self.input_name = session_instance.get_input_name()
+
+     def __call__(self, img_list: List[np.ndarray]):
+         if isinstance(img_list, np.ndarray):
+             img_list = [img_list]
+
+         img_list = copy.deepcopy(img_list)
+
+         # Calculate the aspect ratio of all text bars
+         width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
+
+         # Sorting can speed up the cls process
+         indices = np.argsort(np.array(width_list))
+
+         img_num = len(img_list)
+         cls_res = [['', 0.0]] * img_num
+         batch_num = self.cls_batch_num
+         elapse = 0
+         for beg_img_no in range(0, img_num, batch_num):
+             end_img_no = min(img_num, beg_img_no + batch_num)
+
+             norm_img_batch = []
+             for ino in range(beg_img_no, end_img_no):
+                 norm_img = self.resize_norm_img(img_list[indices[ino]])
+                 norm_img = norm_img[np.newaxis, :]
+                 norm_img_batch.append(norm_img)
+             norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)
+
+             starttime = time.time()
+             onnx_inputs = {self.input_name: norm_img_batch}
+             prob_out = self.session.run(None, onnx_inputs)[0]
+             cls_result = self.postprocess_op(prob_out)
+             elapse += time.time() - starttime
+
+             for rno in range(len(cls_result)):
+                 label, score = cls_result[rno]
+                 cls_res[indices[beg_img_no + rno]] = [label, score]
+                 if '180' in label and score > self.cls_thresh:
+                     img_list[indices[beg_img_no + rno]] = cv2.rotate(
+                         img_list[indices[beg_img_no + rno]], 1)
+         return img_list, cls_res, elapse
+
+     def resize_norm_img(self, img):
+         img_c, img_h, img_w = self.cls_image_shape
+         h, w = img.shape[:2]
+         ratio = w / float(h)
+         if math.ceil(img_h * ratio) > img_w:
+             resized_w = img_w
+         else:
+             resized_w = int(math.ceil(img_h * ratio))
+
+         resized_image = cv2.resize(img, (resized_w, img_h))
+         resized_image = resized_image.astype('float32')
+         if img_c == 1:
+             resized_image = resized_image / 255
+             resized_image = resized_image[np.newaxis, :]
+         else:
+             resized_image = resized_image.transpose((2, 0, 1)) / 255
+
+         resized_image -= 0.5
+         resized_image /= 0.5
+         padding_im = np.zeros((img_c, img_h, img_w), dtype=np.float32)
+         padding_im[:, :, :resized_w] = resized_image
+         return padding_im
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--image_path', type=str, help='image_dir|image_path')
+     parser.add_argument('--config_path', type=str, default='config.yaml')
+     args = parser.parse_args()
+
+     config = read_yaml(args.config_path)
+
+     text_classifier = TextClassifier(config)
+
+     img = cv2.imread(args.image_path)
+     img_list, cls_res, predict_time = text_classifier(img)
+     for ino in range(len(img_list)):
+         print(f"cls result:{cls_res[ino]}")
rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py ADDED
@@ -0,0 +1,80 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import warnings
+
+ import yaml
+ from onnxruntime import (get_available_providers, get_device,
+                          SessionOptions, InferenceSession,
+                          GraphOptimizationLevel)
+
+
+ class OrtInferSession(object):
+     def __init__(self, config):
+         sess_opt = SessionOptions()
+         sess_opt.log_severity_level = 4
+         sess_opt.enable_cpu_mem_arena = False
+         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         cuda_ep = 'CUDAExecutionProvider'
+         cpu_ep = 'CPUExecutionProvider'
+         cpu_provider_options = {
+             "arena_extend_strategy": "kSameAsRequested",
+         }
+
+         EP_list = []
+         if config['use_cuda'] and get_device() == 'GPU' \
+                 and cuda_ep in get_available_providers():
+             EP_list = [(cuda_ep, config[cuda_ep])]
+         EP_list.append((cpu_ep, cpu_provider_options))
+
+         self.session = InferenceSession(config['model_path'],
+                                         sess_options=sess_opt,
+                                         providers=EP_list)
+
+         if config['use_cuda'] and cuda_ep not in self.session.get_providers():
+             warnings.warn(f'{cuda_ep} is not available in the current environment, so inference automatically falls back to {cpu_ep}.\n'
+                           'Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; '
+                           'you can check their compatibility on the official website: '
+                           'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
+                           RuntimeWarning)
+
+     def get_input_name(self, input_idx=0):
+         return self.session.get_inputs()[input_idx].name
+
+     def get_output_name(self, output_idx=0):
+         return self.session.get_outputs()[output_idx].name
+
+
+ def read_yaml(yaml_path):
+     with open(yaml_path, 'rb') as f:
+         data = yaml.load(f, Loader=yaml.Loader)
+     return data
+
+
+ class ClsPostProcess(object):
+     """ Convert between text-label and text-index """
+
+     def __init__(self, label_list):
+         super(ClsPostProcess, self).__init__()
+         self.label_list = label_list
+
+     def __call__(self, preds, label=None):
+         pred_idxs = preds.argmax(axis=1)
+         decode_out = [(self.label_list[idx], preds[i, idx])
+                       for i, idx in enumerate(pred_idxs)]
+         if label is None:
+             return decode_out
+
+         label = [(self.label_list[idx], 1.0) for idx in label]
+         return decode_out, label
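
`ClsPostProcess` reduces the raw classifier output to a label and confidence per crop via an argmax. A self-contained sketch of the same computation on made-up probabilities:

```python
# Sketch of ClsPostProcess on made-up probabilities: argmax over the
# ['0', '180'] labels, keeping the winning probability as the score.
import numpy as np

label_list = ['0', '180']
preds = np.array([[0.9, 0.1],    # crop 0: upright
                  [0.2, 0.8]])   # crop 1: rotated by 180 degrees
pred_idxs = preds.argmax(axis=1)
decode_out = [(label_list[idx], preds[i, idx])
              for i, idx in enumerate(pred_idxs)]
print(decode_out)  # [('0', 0.9), ('180', 0.8)]
```
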
rapidocr_onnxruntime/ch_ppocr_v3_det/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ from .text_detect import TextDetector
rapidocr_onnxruntime/ch_ppocr_v3_det/config.yaml ADDED
@@ -0,0 +1,29 @@
+ model_path: resources/models/ch_PP-OCRv3_det_infer.onnx
+
+ use_cuda: false
+ CUDAExecutionProvider:
+     device_id: 0
+     arena_extend_strategy: kNextPowerOfTwo
+     cudnn_conv_algo_search: EXHAUSTIVE
+     do_copy_in_default_stream: true
+
+ pre_process:
+     DetResizeForTest:
+         limit_side_len: 736
+         limit_type: min
+     NormalizeImage:
+         std: [0.229, 0.224, 0.225]
+         mean: [0.485, 0.456, 0.406]
+         scale: 1./255.
+         order: hwc
+     ToCHWImage:
+     KeepKeys:
+         keep_keys: ['image', 'shape']
+
+ post_process:
+     thresh: 0.3
+     box_thresh: 0.5
+     max_candidates: 1000
+     unclip_ratio: 1.6
+     use_dilation: true
+     score_mode: "fast"
rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py ADDED
@@ -0,0 +1,127 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import argparse
+ import time
+
+ import cv2
+ import numpy as np
+
+ try:
+     from .utils import (DBPostProcess, create_operators,
+                         transform, read_yaml, OrtInferSession)
+ except ImportError:
+     from utils import (DBPostProcess, create_operators,
+                        transform, read_yaml, OrtInferSession)
+
+
+ class TextDetector(object):
+     def __init__(self, config):
+         self.preprocess_op = create_operators(config['pre_process'])
+         self.postprocess_op = DBPostProcess(**config['post_process'])
+
+         session_instance = OrtInferSession(config)
+         self.session = session_instance.session
+         self.input_name = session_instance.get_input_name()
+
+     def __call__(self, img):
+         if img is None:
+             raise ValueError('img is None')
+
+         ori_im_shape = img.shape[:2]
+
+         data = {'image': img}
+         data = transform(data, self.preprocess_op)
+         img, shape_list = data
+         if img is None:
+             return None, 0
+
+         img = np.expand_dims(img, axis=0).astype(np.float32)
+         shape_list = np.expand_dims(shape_list, axis=0)
+
+         starttime = time.time()
+         preds = self.session.run(None, {self.input_name: img})
+
+         post_result = self.postprocess_op(preds[0], shape_list)
+
+         dt_boxes = post_result[0]['points']
+         dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape)
+         elapse = time.time() - starttime
+         return dt_boxes, elapse
+
+     def order_points_clockwise(self, pts):
+         """
+         Reference:
+         https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
+         Sort the points based on their x-coordinates.
+         """
+         xSorted = pts[np.argsort(pts[:, 0]), :]
+
+         # grab the left-most and right-most points from the sorted
+         # x-coordinate points
+         leftMost = xSorted[:2, :]
+         rightMost = xSorted[2:, :]
+
+         # now, sort the left-most coordinates according to their
+         # y-coordinates so we can grab the top-left and bottom-left
+         # points, respectively
+         leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
+         (tl, bl) = leftMost
+
+         rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
+         (tr, br) = rightMost
+
+         rect = np.array([tl, tr, br, bl], dtype="float32")
+         return rect
+
+     def clip_det_res(self, points, img_height, img_width):
+         for pno in range(points.shape[0]):
+             points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+             points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+         return points
+
+     def filter_tag_det_res(self, dt_boxes, image_shape):
+         img_height, img_width = image_shape[:2]
+         dt_boxes_new = []
+         for box in dt_boxes:
+             box = self.order_points_clockwise(box)
+             box = self.clip_det_res(box, img_height, img_width)
+             rect_width = int(np.linalg.norm(box[0] - box[1]))
+             rect_height = int(np.linalg.norm(box[0] - box[3]))
+             if rect_width <= 3 or rect_height <= 3:
+                 continue
+             dt_boxes_new.append(box)
+         dt_boxes = np.array(dt_boxes_new)
+         return dt_boxes
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config_path', type=str, default='config.yaml')
+     parser.add_argument('--image_path', type=str, default=None)
+     args = parser.parse_args()
+
+     config = read_yaml(args.config_path)
+
+     text_detector = TextDetector(config)
+
+     img = cv2.imread(args.image_path)
+     dt_boxes, elapse = text_detector(img)
+
+     from utils import draw_text_det_res
+     src_im = draw_text_det_res(dt_boxes, args.image_path)
+     cv2.imwrite('det_results.jpg', src_im)
+     print('det_results.jpg has been saved in the current directory.')
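
`order_points_clockwise` normalizes each detected quadrilateral to top-left, top-right, bottom-right, bottom-left order. A worked example of the same logic on a toy box:

```python
# Worked example of the order_points_clockwise logic on a toy quadrilateral.
import numpy as np

pts = np.array([[10, 50], [10, 10], [50, 10], [50, 50]], dtype='float32')
xSorted = pts[np.argsort(pts[:, 0]), :]              # sort by x-coordinate
leftMost, rightMost = xSorted[:2, :], xSorted[2:, :]
tl, bl = leftMost[np.argsort(leftMost[:, 1]), :]     # split left pair by y
tr, br = rightMost[np.argsort(rightMost[:, 1]), :]   # split right pair by y
print(np.array([tl, tr, br, bl]))
# [[10. 10.] [50. 10.] [50. 50.] [10. 50.]]  ->  tl, tr, br, bl
```
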
rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py ADDED
@@ -0,0 +1,452 @@
+ """
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import sys
+ import warnings
+
+ import cv2
+ import numpy as np
+ import pyclipper
+ import six
+ import yaml
+ from shapely.geometry import Polygon
+ from onnxruntime import (get_available_providers, get_device,
+                          SessionOptions, InferenceSession,
+                          GraphOptimizationLevel)
+
+
+ class OrtInferSession(object):
+     def __init__(self, config):
+         sess_opt = SessionOptions()
+         sess_opt.log_severity_level = 4
+         sess_opt.enable_cpu_mem_arena = False
+         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         cuda_ep = 'CUDAExecutionProvider'
+         cpu_ep = 'CPUExecutionProvider'
+         cpu_provider_options = {
+             "arena_extend_strategy": "kSameAsRequested",
+         }
+
+         EP_list = []
+         if config['use_cuda'] and get_device() == 'GPU' \
+                 and cuda_ep in get_available_providers():
+             EP_list = [(cuda_ep, config[cuda_ep])]
+         EP_list.append((cpu_ep, cpu_provider_options))
+
+         self.session = InferenceSession(config['model_path'],
+                                         sess_options=sess_opt,
+                                         providers=EP_list)
+
+         if config['use_cuda'] and cuda_ep not in self.session.get_providers():
+             warnings.warn(f'{cuda_ep} is not available in the current environment, so inference automatically falls back to {cpu_ep}.\n'
+                           'Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; '
+                           'you can check their compatibility on the official website: '
+                           'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
+                           RuntimeWarning)
+
+     def get_input_name(self, input_idx=0):
+         return self.session.get_inputs()[input_idx].name
+
+     def get_output_name(self, output_idx=0):
+         return self.session.get_outputs()[output_idx].name
+
+
+ def read_yaml(yaml_path):
+     with open(yaml_path, 'rb') as f:
+         data = yaml.load(f, Loader=yaml.Loader)
+     return data
+
+
+ class DecodeImage(object):
+     """ decode image """
+
+     def __init__(self, img_mode='RGB', channel_first=False):
+         self.img_mode = img_mode
+         self.channel_first = channel_first
+
+     def __call__(self, data):
+         img = data['image']
+         if six.PY2:
+             assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage"
+         else:
+             assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
+
+         img = np.frombuffer(img, dtype='uint8')
+         img = cv2.imdecode(img, 1)
+         if img is None:
+             return None
+
+         if self.img_mode == 'GRAY':
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         elif self.img_mode == 'RGB':
+             assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]'
+             img = img[:, :, ::-1]
+
+         if self.channel_first:
+             img = img.transpose((2, 0, 1))
+         data['image'] = img
+         return data
+
+
+ class NormalizeImage(object):
+     """ normalize image, e.g. subtract the mean and divide by the std """
+
+     def __init__(self, scale=None, mean=None, std=None, order='chw'):
+         if isinstance(scale, str):
+             scale = eval(scale)
+         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+         mean = mean if mean is not None else [0.485, 0.456, 0.406]
+         std = std if std is not None else [0.229, 0.224, 0.225]
+
+         shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+         self.mean = np.array(mean).reshape(shape).astype('float32')
+         self.std = np.array(std).reshape(shape).astype('float32')
+
+     def __call__(self, data):
+         img = np.array(data['image']).astype(np.float32)
+         data['image'] = (img * self.scale - self.mean) / self.std
+         return data
+
+
+ class ToCHWImage(object):
+     """ convert an hwc image to a chw image """
+     def __init__(self):
+         pass
+
+     def __call__(self, data):
+         img = data['image']
+         from PIL import Image
+         if isinstance(img, Image.Image):
+             img = np.array(img)
+         data['image'] = img.transpose((2, 0, 1))
+         return data
+
+
+ class KeepKeys(object):
+     def __init__(self, keep_keys):
+         self.keep_keys = keep_keys
+
+     def __call__(self, data):
+         data_list = []
+         for key in self.keep_keys:
+             data_list.append(data[key])
+         return data_list
+
+
+ class DetResizeForTest(object):
+     def __init__(self, **kwargs):
+         super(DetResizeForTest, self).__init__()
+         self.resize_type = 0
+         if 'image_shape' in kwargs:
+             self.image_shape = kwargs['image_shape']
+             self.resize_type = 1
+         elif 'limit_side_len' in kwargs:
+             self.limit_side_len = kwargs.get('limit_side_len', 736)
+             self.limit_type = kwargs.get('limit_type', 'min')
+
+         if 'resize_long' in kwargs:
+             self.resize_type = 2
+             self.resize_long = kwargs.get('resize_long', 960)
+         else:
+             self.limit_side_len = kwargs.get('limit_side_len', 736)
+             self.limit_type = kwargs.get('limit_type', 'min')
+
+     def __call__(self, data):
+         img = data['image']
+         src_h, src_w = img.shape[:2]
+
+         if self.resize_type == 0:
+             img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+         elif self.resize_type == 2:
+             img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+         else:
+             img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+         data['image'] = img
+         data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+         return data
+
+     def resize_image_type1(self, img):
+         resize_h, resize_w = self.image_shape
+         ori_h, ori_w = img.shape[:2]  # (h, w, c)
+         ratio_h = float(resize_h) / ori_h
+         ratio_w = float(resize_w) / ori_w
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type0(self, img):
+         """
+         Resize the image to a size that is a multiple of 32,
+         which is required by the network.
+         args:
+             img(array): array with shape [h, w, c]
+         return(tuple):
+             img, (ratio_h, ratio_w)
+         """
+         limit_side_len = self.limit_side_len
+         h, w = img.shape[:2]
+
+         # limit the max side
+         if self.limit_type == 'max':
+             if max(h, w) > limit_side_len:
+                 if h > w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.
+         else:
+             if min(h, w) < limit_side_len:
+                 if h < w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.
+         resize_h = int(h * ratio)
+         resize_w = int(w * ratio)
+
+         resize_h = int(round(resize_h / 32) * 32)
+         resize_w = int(round(resize_w / 32) * 32)
+
+         try:
+             if int(resize_w) <= 0 or int(resize_h) <= 0:
+                 return None, (None, None)
+             img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         except Exception:
+             print(img.shape, resize_w, resize_h)
+             sys.exit(0)
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type2(self, img):
+         h, w = img.shape[:2]
+
+         resize_w = w
+         resize_h = h
+
+         # Fix the longer side
+         if resize_h > resize_w:
+             ratio = float(self.resize_long) / resize_h
+         else:
+             ratio = float(self.resize_long) / resize_w
+
+         resize_h = int(resize_h * ratio)
+         resize_w = int(resize_w * ratio)
+
+         max_stride = 128
+         resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+         resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+
+         return img, [ratio_h, ratio_w]
+
+
+ def transform(data, ops=None):
+     """ transform """
+     if ops is None:
+         ops = []
+
+     for op in ops:
+         data = op(data)
+         if data is None:
+             return None
+     return data
+
+
+ def create_operators(op_param_dict):
+     """ create operators based on the config """
+     ops = []
+     for op_name, param in op_param_dict.items():
+         if param is None:
+             param = {}
+         op = eval(op_name)(**param)
+         ops.append(op)
+     return ops
+
+
+ def draw_text_det_res(dt_boxes, img_path):
+     src_im = cv2.imread(img_path)
+     for box in dt_boxes:
+         box = np.array(box).astype(np.int32).reshape(-1, 2)
+         cv2.polylines(src_im, [box], True,
+                       color=(255, 255, 0), thickness=2)
+     return src_im
+
+
+ class DBPostProcess(object):
+     """The post process for Differentiable Binarization (DB)."""
+
+     def __init__(self,
+                  thresh=0.3,
+                  box_thresh=0.7,
+                  max_candidates=1000,
+                  unclip_ratio=2.0,
+                  score_mode="fast",
+                  use_dilation=False):
+         self.thresh = thresh
+         self.box_thresh = box_thresh
+         self.max_candidates = max_candidates
+         self.unclip_ratio = unclip_ratio
+         self.min_size = 3
+         self.score_mode = score_mode
+
+         if use_dilation:
+             self.dilation_kernel = np.array([[1, 1], [1, 1]])
+         else:
+             self.dilation_kernel = None
+
+     def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+         '''
+         _bitmap: single map with shape (1, H, W),
+             whose values are binarized as {0, 1}
+         '''
+         bitmap = _bitmap
+         height, width = bitmap.shape
+
+         outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
+                                 cv2.CHAIN_APPROX_SIMPLE)
+         if len(outs) == 3:
+             img, contours, _ = outs[0], outs[1], outs[2]
+         elif len(outs) == 2:
+             contours, _ = outs[0], outs[1]
+
+         num_contours = min(len(contours), self.max_candidates)
+
+         boxes = []
+         scores = []
+         for index in range(num_contours):
+             contour = contours[index]
+             points, sside = self.get_mini_boxes(contour)
+             if sside < self.min_size:
+                 continue
+             points = np.array(points)
+             if self.score_mode == "fast":
+                 score = self.box_score_fast(pred, points.reshape(-1, 2))
+             else:
+                 score = self.box_score_slow(pred, contour)
+             if self.box_thresh > score:
+                 continue
+
+             box = self.unclip(points).reshape(-1, 1, 2)
+             box, sside = self.get_mini_boxes(box)
+             if sside < self.min_size + 2:
+                 continue
+             box = np.array(box)
+
+             box[:, 0] = np.clip(
+                 np.round(box[:, 0] / width * dest_width), 0, dest_width)
+             box[:, 1] = np.clip(
+                 np.round(box[:, 1] / height * dest_height), 0, dest_height)
+             boxes.append(box.astype(np.int16))
+             scores.append(score)
+         return np.array(boxes, dtype=np.int16), scores
+
+     def unclip(self, box):
+         unclip_ratio = self.unclip_ratio
+         poly = Polygon(box)
+         distance = poly.area * unclip_ratio / poly.length
+         offset = pyclipper.PyclipperOffset()
+         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+         expanded = np.array(offset.Execute(distance))
+         return expanded
+
+     def get_mini_boxes(self, contour):
+         bounding_box = cv2.minAreaRect(contour)
+         points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+         index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+         if points[1][1] > points[0][1]:
+             index_1 = 0
+             index_4 = 1
+         else:
+             index_1 = 1
+             index_4 = 0
+         if points[3][1] > points[2][1]:
+             index_2 = 2
+             index_3 = 3
+         else:
+             index_2 = 3
+             index_3 = 2
+
+         box = [
+             points[index_1], points[index_2], points[index_3], points[index_4]
+         ]
+         return box, min(bounding_box[1])
+
+     def box_score_fast(self, bitmap, _box):
+         h, w = bitmap.shape[:2]
+         box = _box.copy()
+         xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
+         xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
+         ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
+         ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
+
+         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+         box[:, 0] = box[:, 0] - xmin
+         box[:, 1] = box[:, 1] - ymin
+         cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+     def box_score_slow(self, bitmap, contour):
+         ''' box_score_slow: use the polygon mean score as the box score '''
+         h, w = bitmap.shape[:2]
+         contour = contour.copy()
+         contour = np.reshape(contour, (-1, 2))
+
+         xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+         xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+         ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+         ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+
+         contour[:, 0] = contour[:, 0] - xmin
+         contour[:, 1] = contour[:, 1] - ymin
+
+         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+     def __call__(self, pred, shape_list):
+         pred = pred[:, 0, :, :]
+         segmentation = pred > self.thresh
+
+         boxes_batch = []
+         for batch_index in range(pred.shape[0]):
+             src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+             if self.dilation_kernel is not None:
+                 mask = cv2.dilate(
+                     np.array(segmentation[batch_index]).astype(np.uint8),
+                     self.dilation_kernel)
+             else:
+                 mask = segmentation[batch_index]
+             boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+                                                    src_w, src_h)
+
+             boxes_batch.append({'points': boxes})
+         return boxes_batch
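
`DetResizeForTest.resize_image_type0` scales the short side up (or the long side down) to `limit_side_len` and snaps both dimensions to multiples of 32, which the DB network requires. The arithmetic under this commit's config (`limit_type: min`, `limit_side_len: 736`), on a made-up 500x1200 image:

```python
# Arithmetic of resize_image_type0 for limit_type 'min': scale so the
# short side reaches limit_side_len, then round both sides to x32.
limit_side_len = 736
h, w = 500, 1200  # made-up input size
ratio = float(limit_side_len) / min(h, w) if min(h, w) < limit_side_len else 1.
resize_h = int(round(int(h * ratio) / 32) * 32)
resize_w = int(round(int(w * ratio) / 32) * 32)
print(resize_h, resize_w)  # 736 1760
```
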
rapidocr_onnxruntime/ch_ppocr_v3_rec/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ from .text_recognize import TextRecognizer
rapidocr_onnxruntime/ch_ppocr_v3_rec/config.yaml ADDED
@@ -0,0 +1,12 @@
+ model_path: resources/models/ch_PP-OCRv3_rec_infer.onnx
+
+ use_cuda: false
+ # Details of the params: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
+ CUDAExecutionProvider:
+     device_id: 0
+     arena_extend_strategy: kNextPowerOfTwo
+     cudnn_conv_algo_search: EXHAUSTIVE
+     do_copy_in_default_stream: true
+
+ rec_img_shape: [3, 48, 320]
+ rec_batch_num: 6
rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import argparse
+ import math
+ import time
+ from typing import List
+
+ import cv2
+ import numpy as np
+
+ try:
+     from .utils import CTCLabelDecode, read_yaml, OrtInferSession
+ except ImportError:
+     from utils import CTCLabelDecode, read_yaml, OrtInferSession
+
+
+ class TextRecognizer(object):
+     def __init__(self, config):
+         session_instance = OrtInferSession(config)
+         self.session = session_instance.session
+         self.input_name = session_instance.get_input_name()
+         meta_dict = session_instance.get_metadata()
+
+         if 'character' in meta_dict.keys():
+             self.character_dict_path = meta_dict['character'].splitlines()
+         else:
+             self.character_dict_path = config.get('keys_path', None)
+         self.postprocess_op = CTCLabelDecode(self.character_dict_path)
+
+         self.rec_batch_num = config['rec_batch_num']
+         self.rec_image_shape = config['rec_img_shape']
+
+     def __call__(self, img_list: List[np.ndarray]):
+         if isinstance(img_list, np.ndarray):
+             img_list = [img_list]
+
+         # Calculate the aspect ratio of all text bars
+         width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
+
+         # Sorting can speed up the recognition process
+         indices = np.argsort(np.array(width_list))
+
+         img_num = len(img_list)
+         rec_res = [['', 0.0]] * img_num
+
+         batch_num = self.rec_batch_num
+         elapse = 0
+         for beg_img_no in range(0, img_num, batch_num):
+             end_img_no = min(img_num, beg_img_no + batch_num)
+             max_wh_ratio = 0
+             for ino in range(beg_img_no, end_img_no):
+                 h, w = img_list[indices[ino]].shape[0:2]
+                 wh_ratio = w * 1.0 / h
+                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
+
+             norm_img_batch = []
+             for ino in range(beg_img_no, end_img_no):
+                 norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                 max_wh_ratio)
+                 norm_img_batch.append(norm_img[np.newaxis, :])
+             norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)
+
+             starttime = time.time()
+             onnx_inputs = {self.input_name: norm_img_batch}
+             preds = self.session.run(None, onnx_inputs)[0]
+             rec_result = self.postprocess_op(preds)
+
+             for rno in range(len(rec_result)):
+                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+             elapse += time.time() - starttime
+         return rec_res, elapse
+
+     def resize_norm_img(self, img, max_wh_ratio):
+         img_channel, img_height, img_width = self.rec_image_shape
+         assert img_channel == img.shape[2]
+
+         img_width = int(img_height * max_wh_ratio)
+
+         h, w = img.shape[:2]
+         ratio = w / float(h)
+         if math.ceil(img_height * ratio) > img_width:
+             resized_w = img_width
+         else:
+             resized_w = int(math.ceil(img_height * ratio))
+
+         resized_image = cv2.resize(img, (resized_w, img_height))
+         resized_image = resized_image.astype('float32')
+         resized_image = resized_image.transpose((2, 0, 1)) / 255
+         resized_image -= 0.5
+         resized_image /= 0.5
+
+         padding_im = np.zeros((img_channel, img_height, img_width),
+                               dtype=np.float32)
+         padding_im[:, :, 0:resized_w] = resized_image
+         return padding_im
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--image_path', type=str, help='image_dir|image_path')
+     parser.add_argument('--config_path', type=str, default='config.yaml')
+     args = parser.parse_args()
+
+     config = read_yaml(args.config_path)
+     text_recognizer = TextRecognizer(config)
+
+     img = cv2.imread(args.image_path)
+     rec_res, predict_time = text_recognizer(img)
+     print(f'rec result: {rec_res}\t cost: {predict_time}s')
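
Within each batch, `resize_norm_img` pads every crop to a shared width derived from the widest crop, so narrower crops simply get zero padding on the right. A sketch of that width computation for the 48-pixel-high recognition input, on made-up crop sizes:

```python
# Sketch of the per-batch target width in TextRecognizer: the padded
# width follows the widest crop in the batch (rec_img_shape height is 48).
img_height = 48
crops_wh = [(320, 32), (600, 40)]               # made-up (w, h) of two crops
max_wh_ratio = max(w / h for w, h in crops_wh)  # 15.0
img_width = int(img_height * max_wh_ratio)
print(img_width)  # 720: every crop in this batch is resized/padded to 48x720
```
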
rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py ADDED
@@ -0,0 +1,128 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import warnings
+
+ import numpy as np
+ import yaml
+ from onnxruntime import (get_available_providers, get_device,
+                          SessionOptions, InferenceSession,
+                          GraphOptimizationLevel)
+
+
+ class OrtInferSession(object):
+     def __init__(self, config):
+         sess_opt = SessionOptions()
+         sess_opt.log_severity_level = 4
+         sess_opt.enable_cpu_mem_arena = False
+         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         cuda_ep = 'CUDAExecutionProvider'
+         cpu_ep = 'CPUExecutionProvider'
+         cpu_provider_options = {
+             "arena_extend_strategy": "kSameAsRequested",
+         }
+
+         EP_list = []
+         if config['use_cuda'] and get_device() == 'GPU' \
+                 and cuda_ep in get_available_providers():
+             EP_list = [(cuda_ep, config[cuda_ep])]
+         EP_list.append((cpu_ep, cpu_provider_options))
+
+         self.session = InferenceSession(config['model_path'],
+                                         sess_options=sess_opt,
+                                         providers=EP_list)
+
+         if config['use_cuda'] and cuda_ep not in self.session.get_providers():
+             warnings.warn(f'{cuda_ep} is not available in the current environment, so inference automatically falls back to {cpu_ep}.\n'
+                           'Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; '
+                           'you can check their compatibility on the official website: '
+                           'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
+                           RuntimeWarning)
+
+     def get_input_name(self, input_idx=0):
+         return self.session.get_inputs()[input_idx].name
+
+     def get_output_name(self, output_idx=0):
+         return self.session.get_outputs()[output_idx].name
+
+     def get_metadata(self):
+         meta_dict = self.session.get_modelmeta().custom_metadata_map
+         return meta_dict
+
+
+ def read_yaml(yaml_path):
+     with open(yaml_path, 'rb') as f:
+         data = yaml.load(f, Loader=yaml.Loader)
+     return data
+
+
+ class CTCLabelDecode(object):
+     """ Convert between text-label and text-index """
+
+     def __init__(self, character_dict_path):
+         super(CTCLabelDecode, self).__init__()
+
+         self.character_str = []
+         assert character_dict_path is not None, "character_dict_path should not be None"
+
+         if isinstance(character_dict_path, str):
+             with open(character_dict_path, "rb") as fin:
+                 lines = fin.readlines()
+                 for line in lines:
+                     line = line.decode('utf-8').strip("\n").strip("\r\n")
+                     self.character_str.append(line)
+         else:
+             self.character_str = character_dict_path
+         self.character_str.append(' ')
+
+         dict_character = self.add_special_char(self.character_str)
+         self.character = dict_character
+
+         self.dict = {}
+         for i, char in enumerate(dict_character):
+             self.dict[char] = i
+
+     def __call__(self, preds, label=None):
+         preds_idx = preds.argmax(axis=2)
+         preds_prob = preds.max(axis=2)
+         text = self.decode(preds_idx, preds_prob,
+                            is_remove_duplicate=True)
+         if label is None:
+             return text
+         label = self.decode(label)
+         return text, label
+
+     def add_special_char(self, dict_character):
+         dict_character = ['blank'] + dict_character
+         return dict_character
+
+     def get_ignored_tokens(self):
+         return [0]  # for ctc blank
+
+     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+         """ Convert text-index into text-label. """
+         result_list = []
+         ignored_tokens = self.get_ignored_tokens()
+         batch_size = len(text_index)
+         for batch_idx in range(batch_size):
+             char_list = []
+             conf_list = []
+             for idx in range(len(text_index[batch_idx])):
+                 if text_index[batch_idx][idx] in ignored_tokens:
+                     continue
+                 if is_remove_duplicate:
+                     # only for predict
+                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]:
+                         continue
+                 char_list.append(self.character[int(text_index[batch_idx][idx])])
+                 if text_prob is not None:
+                     conf_list.append(text_prob[batch_idx][idx])
+                 else:
+                     conf_list.append(1)
+             text = ''.join(char_list)
+             result_list.append((text, np.mean(conf_list + [1e-50])))
+         return result_list
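
`CTCLabelDecode.decode` implements the standard CTC rule: drop the blank token (index 0) and merge consecutive duplicate indices, so blank-separated repeats survive. A self-contained sketch on a toy alphabet:

```python
# Sketch of the CTC decoding rule on a toy alphabet: skip blanks (index 0)
# and merge consecutive duplicates, as decode() does for predictions.
character = ['blank', 'a', 'b', 'c']   # 'blank' prepended by add_special_char
text_index = [1, 1, 0, 2, 2, 0, 3]     # made-up per-timestep argmax output
chars, prev = [], None
for idx in text_index:
    if idx != 0 and idx != prev:       # blank-separated repeats are kept
        chars.append(character[idx])
    prev = idx
print(''.join(chars))  # 'abc'
```
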
rapidocr_onnxruntime/rapid_ocr_api.py ADDED
@@ -0,0 +1,164 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import copy
+ import importlib
+ import sys
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ import yaml
+
+ root_dir = Path(__file__).resolve().parent
+ sys.path.append(str(root_dir))
+
+
+ class TextSystem(object):
+     def __init__(self, config_path):
+         super(TextSystem, self).__init__()
+         if not Path(config_path).exists():
+             raise FileNotFoundError(f'{config_path} does not exist!')
+
+         config = self.read_yaml(config_path)
+
+         global_config = config['Global']
+         self.print_verbose = global_config['print_verbose']
+         self.text_score = global_config['text_score']
+         self.min_height = global_config['min_height']
+         self.width_height_ratio = global_config['width_height_ratio']
+
+         TextDetector = self.init_module(config['Det']['module_name'],
+                                         config['Det']['class_name'])
+         self.text_detector = TextDetector(config['Det'])
+
+         TextRecognizer = self.init_module(config['Rec']['module_name'],
+                                           config['Rec']['class_name'])
+         self.text_recognizer = TextRecognizer(config['Rec'])
+
+         self.use_angle_cls = config['Global']['use_angle_cls']
+         if self.use_angle_cls:
+             TextClassifier = self.init_module(config['Cls']['module_name'],
+                                               config['Cls']['class_name'])
+             self.text_cls = TextClassifier(config['Cls'])
+
+     def __call__(self, img: np.ndarray):
+         h, w = img.shape[:2]
+         if self.width_height_ratio == -1:
+             use_limit_ratio = False
+         else:
+             use_limit_ratio = w / h > self.width_height_ratio
+
+         if h <= self.min_height or use_limit_ratio:
+             dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w)
+         else:
+             dt_boxes, elapse = self.text_detector(img)
+             if dt_boxes is None or len(dt_boxes) < 1:
+                 return None, None
+             if self.print_verbose:
+                 print(f'dt_boxes num: {len(dt_boxes)}, elapse: {elapse}')
+
+             dt_boxes = self.sorted_boxes(dt_boxes)
+             img_crop_list = self.get_crop_img_list(img, dt_boxes)
+
+         if self.use_angle_cls:
+             img_crop_list, _, elapse = self.text_cls(img_crop_list)
+             if self.print_verbose:
+                 print(f'cls num: {len(img_crop_list)}, elapse: {elapse}')
+
+         rec_res, elapse = self.text_recognizer(img_crop_list)
+         if self.print_verbose:
+             print(f'rec_res num: {len(rec_res)}, elapse: {elapse}')
+
+         filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes,
+                                                                       rec_res)
+         return filter_boxes, filter_rec_res
+
+     @staticmethod
+     def read_yaml(yaml_path):
+         with open(yaml_path, 'rb') as f:
+             data = yaml.load(f, Loader=yaml.Loader)
+         return data
+
+     @staticmethod
+     def init_module(module_name, class_name):
+         module_part = importlib.import_module(module_name)
+         return getattr(module_part, class_name)
+
+     def get_boxes_img_without_det(self, img, h, w):
+         x0, y0, x1, y1 = 0, 0, w, h
+         dt_boxes = np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
+         dt_boxes = dt_boxes[np.newaxis, ...]
+         img_crop_list = [img]
+         return dt_boxes, img_crop_list
+
+     def get_crop_img_list(self, img, dt_boxes):
+         def get_rotate_crop_image(img, points):
+             img_crop_width = int(
+                 max(
+                     np.linalg.norm(points[0] - points[1]),
+                     np.linalg.norm(points[2] - points[3])))
+             img_crop_height = int(
+                 max(
+                     np.linalg.norm(points[0] - points[3]),
+                     np.linalg.norm(points[1] - points[2])))
+             pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                                   [img_crop_width, img_crop_height],
+                                   [0, img_crop_height]])
+             M = cv2.getPerspectiveTransform(points, pts_std)
+             dst_img = cv2.warpPerspective(
+                 img,
+                 M, (img_crop_width, img_crop_height),
+                 borderMode=cv2.BORDER_REPLICATE,
+                 flags=cv2.INTER_CUBIC)
+             dst_img_height, dst_img_width = dst_img.shape[0:2]
+             if dst_img_height * 1.0 / dst_img_width >= 1.5:
+                 dst_img = np.rot90(dst_img)
+             return dst_img
+
+         img_crop_list = []
+         for box in dt_boxes:
+             tmp_box = copy.deepcopy(box)
+             img_crop = get_rotate_crop_image(img, tmp_box)
+             img_crop_list.append(img_crop)
+         return img_crop_list
+
+     @staticmethod
+     def sorted_boxes(dt_boxes):
+         """
+         Sort text boxes in order from top to bottom, left to right
+         args:
+             dt_boxes(array): detected text boxes with shape [4, 2]
+         return:
+             sorted boxes(array) with shape [4, 2]
+         """
+         num_boxes = dt_boxes.shape[0]
+         sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+         _boxes = list(sorted_boxes)
+
+         for i in range(num_boxes - 1):
+             if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
+                     _boxes[i + 1][0][0] < _boxes[i][0][0]:
+                 _boxes[i], _boxes[i + 1] = _boxes[i + 1], _boxes[i]
+         return _boxes
+
+     def filter_boxes_rec_by_score(self, dt_boxes, rec_res):
+         filter_boxes, filter_rec_res = [], []
+         for box, rec_result in zip(dt_boxes, rec_res):
+             text, score = rec_result
+             if score >= self.text_score:
+                 filter_boxes.append(box)
+                 filter_rec_res.append(rec_result)
+         return filter_boxes, filter_rec_res
+
+
+ if __name__ == '__main__':
+     text_sys = TextSystem('config.yaml')
+
+     img = cv2.imread('resources/test_images/det_images/ch_en_num.jpg')
+
+     result = text_sys(img)
+     print(result)
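
`sorted_boxes` first sorts by the top-left corner's (y, x) and then swaps neighbors whose top edges sit within 10 pixels of each other, so boxes on one visual line come out left to right. A toy run with two made-up boxes on the same line:

```python
# Toy run of the sorted_boxes fix-up: two boxes on the same visual line
# whose y-coordinates differ by less than 10 px are re-ordered by x.
import numpy as np

boxes = np.array([
    [[200, 10], [300, 10], [300, 40], [200, 40]],  # right box, slightly higher
    [[10, 12], [110, 12], [110, 42], [10, 42]],    # left box, slightly lower
])
_boxes = sorted(boxes, key=lambda x: (x[0][1], x[0][0]))
for i in range(len(_boxes) - 1):
    if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
            _boxes[i + 1][0][0] < _boxes[i][0][0]:
        _boxes[i], _boxes[i + 1] = _boxes[i + 1], _boxes[i]
print([b[0].tolist() for b in _boxes])  # [[10, 12], [200, 10]]
```
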
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ Gradio
+ pyclipper>=1.2.0
+ Shapely>=1.7.1
+ opencv_python==4.5.1.48
+ six>=1.15.0
+ numpy>=1.19.5
+ Pillow
+ PyYAML
+ pytest