ketanmore commited on
Commit
dd31650
1 Parent(s): 2720487

Delete surya_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. surya_yolo_pipeline.py +0 -169
surya_yolo_pipeline.py DELETED
@@ -1,169 +0,0 @@
1
- import cv2
2
- import supervision as sv # pip install supervision
3
- from ultralytics import YOLO
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
-
7
- yolo_model = YOLO('yolov10x_best.pt')
8
-
9
-
10
- from surya.model.detection.segformer import load_processor , load_model
11
- import torch
12
- import os
13
-
14
-
15
- from surya.model.detection.segformer import load_processor , load_model
16
- import torch
17
- import os
18
- # os.environ['HF_HOME'] = '/share/data/drive_3/ketan/orc/HF_Cache'
19
-
20
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
- model = load_model("vikp/surya_layout2").to(device)
22
-
23
-
24
- from PIL import Image
25
- from surya.input.processing import prepare_image_detection
26
-
27
-
28
- def predicted_mask_function(image_path) :
29
-
30
- img = Image.open(image_path)
31
- img = [prepare_image_detection(img=img, processor=load_processor())]
32
- img = torch.stack(img, dim=0).to(model.dtype).to(model.device)
33
- logits = model(img).logits
34
-
35
- predicted_mask = torch.argmax(logits[0], dim=0).cpu().numpy()
36
-
37
- return predicted_mask
38
-
39
-
40
-
41
- def predict_boxes_labels(image_path):
42
- results = yolo_model(source=image_path, conf=0.2, iou=0.8)[0]
43
- detections = sv.Detections.from_ultralytics(results)
44
- labels = detections.data["class_name"].tolist()
45
- bboxes = detections.xyxy.tolist()
46
- return bboxes,labels
47
-
48
-
49
-
50
- def resize_segment(mask, class_id, target_size, method=cv2.INTER_AREA):
51
- # Create a binary mask for the current class
52
- class_mask = np.where(mask == class_id, 1, 0).astype(np.uint8)
53
-
54
- # Resize the class mask to the target size
55
- resized_class_mask = cv2.resize(class_mask, (target_size[1], target_size[0]), interpolation=method)
56
-
57
- return resized_class_mask
58
-
59
- def resize_and_combine_classes(mask, target_size, method=cv2.INTER_AREA):
60
- unique_classes = np.unique(mask)
61
-
62
- # Initialize a zero-filled mask for the combined result with the correct target size
63
- resized_masks = np.zeros((target_size[0], target_size[1]), dtype=np.uint8)
64
-
65
- # Process each class found in the mask
66
- for class_id in unique_classes:
67
- resized_class_mask = resize_segment(mask, class_id, target_size, method)
68
-
69
- # Assign the class ID to the resized output mask where the resized class mask is 1
70
- resized_masks[resized_class_mask == 1] = class_id
71
-
72
- return resized_masks
73
-
74
-
75
- class_labels = {
76
- 0: 'Blank',
77
- 1: 'Caption',
78
- 2: 'Footnote',
79
- 3: 'Formula',
80
- 4: 'List-item',
81
- 5: 'Page-footer',
82
- 6: 'Page-header',
83
- 7: 'Picture',
84
- 8: 'Section-header',
85
- 9: 'Table',
86
- 10: 'Text',
87
- 11: 'Title'
88
- }
89
-
90
- colors = plt.cm.get_cmap('tab20', len(class_labels))
91
-
92
- def colormap_to_rgb(cmap, index):
93
- color = cmap(index)[:3] # Extract RGB, ignore alpha
94
- return tuple(int(c * 255) for c in color)
95
-
96
- def mask_to_bboxes(colored_mask, class_labels):
97
- bboxes = []
98
-
99
- # Loop through each class in the class_labels
100
- for label, class_name in class_labels.items():
101
- # Get the RGB color for the current label
102
- color = colormap_to_rgb(colors, label)
103
-
104
- # Create a binary mask for the current label by checking where the colored mask matches the class color
105
- class_mask = np.all(colored_mask == color, axis=-1).astype(np.uint8)
106
-
107
- # Find contours of the class region in the binary mask
108
- contours, _ = cv2.findContours(class_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
109
-
110
- # Loop through all contours and extract bounding boxes
111
- for contour in contours:
112
- # Get the bounding box for the contour (in xywh format)
113
- x, y, w, h = cv2.boundingRect(contour)
114
-
115
- # Convert to xyxy format: (xmin, ymin, xmax, ymax)
116
- xmin, ymin, xmax, ymax = x, y, x + w, y + h
117
-
118
- # Append the bounding box with the corresponding class label
119
- bboxes.append((xmin, ymin, xmax, ymax))
120
- # bboxes.append((xmin, ymin, xmax, ymax, class_name))
121
-
122
- return bboxes
123
-
124
-
125
-
126
- import matplotlib.pyplot as plt
127
- # from matplotlib import colors
128
-
129
- def suryolo(image_path) :
130
-
131
- image = Image.open(image_path)
132
- L, W = image.size
133
-
134
-
135
- predicted_mask = predicted_mask_function(image_path)
136
-
137
- colored_mask = np.zeros((W, L, 3), dtype=np.uint8) # 3 channels for RGB
138
-
139
- label_name_to_int = {v: k for k, v in class_labels.items()}
140
-
141
- colors = plt.cm.get_cmap('tab20', len(class_labels))
142
-
143
- bboxes,labels = predict_boxes_labels(image_path)
144
-
145
- for box, label in zip(bboxes, labels): # Assuming labels list corresponds to bboxes
146
- xmin, ymin, xmax, ymax = box
147
- xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
148
-
149
- # Resize predicted mask to match the image dimensions (W = width, L = height)
150
- predicted_mask = resize_and_combine_classes(predicted_mask, (W, L))
151
-
152
- # Extract the mask region within the bounding box
153
- mask_region = predicted_mask[ymin:ymax, xmin:xmax]
154
-
155
- # Get the corresponding integer index for the label
156
- label_index = label_name_to_int[label]
157
-
158
- # Get the corresponding color for the label using the colormap
159
- color = colormap_to_rgb(colors, label_index)
160
-
161
- # Apply the color to the regions where mask_region > 0.5
162
- colored_mask[ymin:ymax, xmin:xmax][mask_region > 0.5] = color
163
-
164
- blank_color = colormap_to_rgb(colors, 0)
165
- colored_mask[(colored_mask == 0).all(axis=-1)] = blank_color
166
-
167
- return mask_to_bboxes(colored_mask,class_labels)
168
-
169
-