johnlockejrr committed
Commit 3b5cca1 · verified · 1 Parent(s): 781906d

Upload 2 files

Files changed (2)
  1. app.py +168 -169
  2. model.pt +3 -0
app.py CHANGED
@@ -1,170 +1,169 @@
- import streamlit as st
- import warnings
- warnings.simplefilter("ignore", UserWarning)
-
- from uuid import uuid4
- from laia.scripts.htr.decode_ctc import run as decode
- from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
- import sys
- from tempfile import NamedTemporaryFile, mkdtemp
- from pathlib import Path
- from contextlib import redirect_stdout
- import re
- from PIL import Image
- from bidi.algorithm import get_display
- import multiprocessing
- from ultralytics import YOLO
- import cv2
- import numpy as np
- import pandas as pd
- import logging
- from transformers import AutoModel
-
- # Configure logging
- logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
-
- # Load YOLOv8 model
- model = AutoModel.from_pretrained('johnlockejrr/yolov8-samaritan-segmentation')
- images = Path(mkdtemp())
- DEFAULT_HEIGHT = 128
- TEXT_DIRECTION = "RTL"
- NUM_WORKERS = multiprocessing.cpu_count()
-
- # Regex pattern for extracting results
- IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
- CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
- TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
- LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
- models_name = ["johnlockejrr/pylaia-heb_sam_v1"]
- MODELS = {}
-
- def get_width(image, height=DEFAULT_HEIGHT):
-     aspect_ratio = image.width / image.height
-     return height * aspect_ratio
-
- def load_model(model_name):
-     if model_name not in MODELS:
-         MODELS[model_name] = Path(snapshot_download(model_name))
-     return MODELS[model_name]
-
- def predict(model_name, input_img):
-     model_dir = load_model(model_name)
-
-     temperature = 2.0
-     batch_size = 1
-
-     weights_path = model_dir / "weights.ckpt"
-     syms_path = model_dir / "syms.txt"
-     language_model_params = {"language_model_weight": 1.0}
-     use_language_model = (model_dir / "tokens.txt").exists()
-     if use_language_model:
-         language_model_params.update(
-             {
-                 "language_model_path": str(model_dir / "language_model.binary"),
-                 "lexicon_path": str(model_dir / "lexicon.txt"),
-                 "tokens_path": str(model_dir / "tokens.txt"),
-             }
-         )
-
-     common_args = CommonArgs(
-         checkpoint=str(weights_path.relative_to(model_dir)),
-         train_path=str(model_dir),
-         experiment_dirname="",
-     )
-
-     data_args = DataArgs(batch_size=batch_size, color_mode="L", num_workers=NUM_WORKERS)
-     trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
-     decode_args = DecodeArgs(
-         include_img_ids=True,
-         join_string="",
-         convert_spaces=True,
-         print_line_confidence_scores=True,
-         print_word_confidence_scores=False,
-         temperature=temperature,
-         use_language_model=use_language_model,
-         **language_model_params,
-     )
-
-     with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
-         image_id = uuid4()
-         input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
-         input_img.save(f"{images}/{image_id}.jpg")
-         Path(img_list.name).write_text("\n".join([str(image_id)]))
-
-         with redirect_stdout(open(pred_stdout.name, mode="w")):
-             decode(
-                 syms=str(syms_path),
-                 img_list=img_list.name,
-                 img_dirs=[str(images)],
-                 common=common_args,
-                 data=data_args,
-                 trainer=trainer_args,
-                 decode=decode_args,
-                 num_workers=1,
-             )
-             sys.stdout.flush()
-         predictions = Path(pred_stdout.name).read_text().strip().splitlines()
-
-     _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
-     if TEXT_DIRECTION == "RTL":
-         return input_img, {"text": get_display(text), "score": score}
-     else:
-         return input_img, {"text": text, "score": score}
-
- def process_image(image):
-     # Perform inference on an image, select textline only
-     results = model(image, classes=1)
-
-     img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-     boxes = results[0].boxes.xyxy.tolist()
-     boxes.sort(key=lambda x: x[1])
-
-     bboxes = []
-     polygons = []
-     texts = []
-
-     for i, box in enumerate(boxes):
-         x1, y1, x2, y2 = map(int, box)
-         crop_img = img_cv2[y1:y2, x1:x2]
-         crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))
-
-         # Recognize text using PyLaia model
-         predicted = predict(models_name, crop_pil)
-         texts.append(predicted[1]["text"])
-
-         bboxes.append((x1, y1, x2, y2))
-         polygons.append(f"Line {i+1}: {[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]}")
-
-         # Draw bounding box
-         cv2.rectangle(img_cv2, (x1, y1), (x2, y2), (0, 255, 0), 2)
-
-     # Convert image back to RGB for display in Streamlit
-     img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
-
-     # Combine polygons and texts into a DataFrame for table display
-     table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
-     return Image.fromarray(img_result), table_data
-
- def segment_and_recognize(image):
-     segmented_image, table_data = process_image(image)
-     return segmented_image, table_data
-
- # Streamlit app layout
- st.title("YOLOv8 Text Line Segmentation & PyLaia Text Recognition")
-
- # File uploader
- uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
-
- # Process the image if uploaded
- if uploaded_image is not None:
-     image = Image.open(uploaded_image)
-
-     if st.button("Segment and Recognize"):
-         # Perform segmentation and recognition
-         segmented_image, table_data = segment_and_recognize(image)
-
-         # Display the segmented image
-         st.image(segmented_image, caption="Segmented Image with Bounding Boxes", use_column_width=True)
-
-         # Display the table with polygons and recognized text
+ import streamlit as st
+ import warnings
+ warnings.simplefilter("ignore", UserWarning)
+
+ from uuid import uuid4
+ from laia.scripts.htr.decode_ctc import run as decode
+ from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
+ import sys
+ from tempfile import NamedTemporaryFile, mkdtemp
+ from pathlib import Path
+ from contextlib import redirect_stdout
+ import re
+ from PIL import Image
+ from bidi.algorithm import get_display
+ import multiprocessing
+ from ultralytics import YOLO
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import logging
+
+ # Configure logging
+ logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
+
+ # Load YOLOv8 model
+ model = YOLO('model.pt')
+ images = Path(mkdtemp())
+ DEFAULT_HEIGHT = 128
+ TEXT_DIRECTION = "RTL"
+ NUM_WORKERS = multiprocessing.cpu_count()
+
+ # Regex pattern for extracting results
+ IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
+ CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
+ TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
+ LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
+ models_name = ["johnlockejrr/pylaia-heb_sam_v1"]
+ MODELS = {}
+
+ def get_width(image, height=DEFAULT_HEIGHT):
+     aspect_ratio = image.width / image.height
+     return height * aspect_ratio
+
+ def load_model(model_name):
+     if model_name not in MODELS:
+         MODELS[model_name] = Path(snapshot_download(model_name))
+     return MODELS[model_name]
+
+ def predict(model_name, input_img):
+     model_dir = load_model(model_name)
+
+     temperature = 2.0
+     batch_size = 1
+
+     weights_path = model_dir / "weights.ckpt"
+     syms_path = model_dir / "syms.txt"
+     language_model_params = {"language_model_weight": 1.0}
+     use_language_model = (model_dir / "tokens.txt").exists()
+     if use_language_model:
+         language_model_params.update(
+             {
+                 "language_model_path": str(model_dir / "language_model.binary"),
+                 "lexicon_path": str(model_dir / "lexicon.txt"),
+                 "tokens_path": str(model_dir / "tokens.txt"),
+             }
+         )
+
+     common_args = CommonArgs(
+         checkpoint=str(weights_path.relative_to(model_dir)),
+         train_path=str(model_dir),
+         experiment_dirname="",
+     )
+
+     data_args = DataArgs(batch_size=batch_size, color_mode="L", num_workers=NUM_WORKERS)
+     trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
+     decode_args = DecodeArgs(
+         include_img_ids=True,
+         join_string="",
+         convert_spaces=True,
+         print_line_confidence_scores=True,
+         print_word_confidence_scores=False,
+         temperature=temperature,
+         use_language_model=use_language_model,
+         **language_model_params,
+     )
+
+     with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
+         image_id = uuid4()
+         input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
+         input_img.save(f"{images}/{image_id}.jpg")
+         Path(img_list.name).write_text("\n".join([str(image_id)]))
+
+         with redirect_stdout(open(pred_stdout.name, mode="w")):
+             decode(
+                 syms=str(syms_path),
+                 img_list=img_list.name,
+                 img_dirs=[str(images)],
+                 common=common_args,
+                 data=data_args,
+                 trainer=trainer_args,
+                 decode=decode_args,
+                 num_workers=1,
+             )
+             sys.stdout.flush()
+         predictions = Path(pred_stdout.name).read_text().strip().splitlines()
+
+     _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
+     if TEXT_DIRECTION == "RTL":
+         return input_img, {"text": get_display(text), "score": score}
+     else:
+         return input_img, {"text": text, "score": score}
+
+ def process_image(image):
+     # Perform inference on an image, select textline only
+     results = model(image, classes=1)
+
+     img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+     boxes = results[0].boxes.xyxy.tolist()
+     boxes.sort(key=lambda x: x[1])
+
+     bboxes = []
+     polygons = []
+     texts = []
+
+     for i, box in enumerate(boxes):
+         x1, y1, x2, y2 = map(int, box)
+         crop_img = img_cv2[y1:y2, x1:x2]
+         crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))
+
+         # Recognize text using PyLaia model
+         predicted = predict(models_name, crop_pil)
+         texts.append(predicted[1]["text"])
+
+         bboxes.append((x1, y1, x2, y2))
+         polygons.append(f"Line {i+1}: {[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]}")
+
+         # Draw bounding box
+         cv2.rectangle(img_cv2, (x1, y1), (x2, y2), (0, 255, 0), 2)
+
+     # Convert image back to RGB for display in Streamlit
+     img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
+
+     # Combine polygons and texts into a DataFrame for table display
+     table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
+     return Image.fromarray(img_result), table_data
+
+ def segment_and_recognize(image):
+     segmented_image, table_data = process_image(image)
+     return segmented_image, table_data
+
+ # Streamlit app layout
+ st.title("YOLOv8 Text Line Segmentation & PyLaia Text Recognition")
+
+ # File uploader
+ uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
+
+ # Process the image if uploaded
+ if uploaded_image is not None:
+     image = Image.open(uploaded_image)
+
+     if st.button("Segment and Recognize"):
+         # Perform segmentation and recognition
+         segmented_image, table_data = segment_and_recognize(image)
+
+         # Display the segmented image
+         st.image(segmented_image, caption="Segmented Image with Bounding Boxes", use_column_width=True)
+
+         # Display the table with polygons and recognized text
          st.table(table_data)
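The substantive change in this commit is the segmentation-model loading: the previous revision pulled the weights through transformers' AutoModel.from_pretrained('johnlockejrr/yolov8-samaritan-segmentation'), which is not a loader for ultralytics YOLOv8 checkpoints, while the new revision bundles the weights in this repo as model.pt and loads them with YOLO('model.pt'). Two smaller issues survive the commit as shown: load_model() calls snapshot_download() without importing it, and process_image() passes the whole models_name list to predict(), whose load_model() expects a single Hub repo-id string (a list is not even hashable as a MODELS key). A minimal follow-up patch, sketched here as a suggestion rather than taken from the commit:

# Hypothetical fixes on top of the new app.py; assumes the PyLaia model is
# hosted on the Hub under the repo id already listed in models_name.
from huggingface_hub import snapshot_download  # used by load_model()

# ...and inside process_image(), pass the repo-id string, not the list:
predicted = predict(models_name[0], crop_pil)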
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02a10695831265b2821a267e5a239e78eeaae8e2865c57bf0c2c06cabe2e68be
+ size 54827221