nkanungo committed on
Commit e3f3c34 · 1 Parent(s): cd099ce

Upload 15 files

PASCAL_VOC/1.txt ADDED
File without changes
PASCAL_VOC/100examples.csv ADDED
@@ -0,0 +1,104 @@
+ image,text
+ 000007.jpg,000007.txt
+ 000026.jpg,000026.txt
+ 000032.jpg,000032.txt
+ 000033.jpg,000033.txt
+ 000034.jpg,000034.txt
+ 000035.jpg,000035.txt
+ 000036.jpg,000036.txt
+ 000042.jpg,000042.txt
+ 000089.jpg,000089.txt
+ 000091.jpg,000091.txt
+ 000104.jpg,000104.txt
+ 000112.jpg,000112.txt
+ 000122.jpg,000122.txt
+ 000129.jpg,000129.txt
+ 000133.jpg,000133.txt
+ 000134.jpg,000134.txt
+ 000187.jpg,000187.txt
+ 000189.jpg,000189.txt
+ 000192.jpg,000192.txt
+ 000193.jpg,000193.txt
+ 000194.jpg,000194.txt
+ 000198.jpg,000198.txt
+ 000200.jpg,000200.txt
+ 000207.jpg,000207.txt
+ 000209.jpg,000209.txt
+ 000164.jpg,000164.txt
+ 000171.jpg,000171.txt
+ 000173.jpg,000173.txt
+ 000174.jpg,000174.txt
+ 000187.jpg,000187.txt
+ 000189.jpg,000189.txt
+ 000192.jpg,000192.txt
+ 000193.jpg,000193.txt
+ 000194.jpg,000194.txt
+ 000198.jpg,000198.txt
+ 000200.jpg,000200.txt
+ 000207.jpg,000207.txt
+ 000209.jpg,000209.txt
+ 000219.jpg,000219.txt
+ 000220.jpg,000220.txt
+ 000222.jpg,000222.txt
+ 000225.jpg,000225.txt
+ 000228.jpg,000228.txt
+ 000235.jpg,000235.txt
+ 000242.jpg,000242.txt
+ 000250.jpg,000250.txt
+ 000256.jpg,000256.txt
+ 000259.jpg,000259.txt
+ 000262.jpg,000262.txt
+ 000263.jpg,000263.txt
+ 000276.jpg,000276.txt
+ 000278.jpg,000278.txt
+ 000282.jpg,000282.txt
+ 000391.jpg,000391.txt
+ 000394.jpg,000394.txt
+ 000395.jpg,000395.txt
+ 000400.jpg,000400.txt
+ 000404.jpg,000404.txt
+ 000406.jpg,000406.txt
+ 000407.jpg,000407.txt
+ 000411.jpg,000411.txt
+ 000416.jpg,000416.txt
+ 000430.jpg,000430.txt
+ 000431.jpg,000431.txt
+ 000438.jpg,000438.txt
+ 000446.jpg,000446.txt
+ 000450.jpg,000450.txt
+ 000454.jpg,000454.txt
+ 000463.jpg,000463.txt
+ 000468.jpg,000468.txt
+ 000469.jpg,000469.txt
+ 000470.jpg,000470.txt
+ 000474.jpg,000474.txt
+ 000476.jpg,000476.txt
+ 000477.jpg,000477.txt
+ 000484.jpg,000484.txt
+ 000489.jpg,000489.txt
+ 000496.jpg,000496.txt
+ 000503.jpg,000503.txt
+ 000508.jpg,000508.txt
+ 000516.jpg,000516.txt
+ 000518.jpg,000518.txt
+ 000519.jpg,000519.txt
+ 000522.jpg,000522.txt
+ 000524.jpg,000524.txt
+ 000525.jpg,000525.txt
+ 000552.jpg,000552.txt
+ 000554.jpg,000554.txt
+ 000555.jpg,000555.txt
+ 000559.jpg,000559.txt
+ 000565.jpg,000565.txt
+ 000577.jpg,000577.txt
+ 000583.jpg,000583.txt
+ 000589.jpg,000589.txt
+ 000590.jpg,000590.txt
+ 000592.jpg,000592.txt
+ 000597.jpg,000597.txt
+ 000605.jpg,000605.txt
+ 000609.jpg,000609.txt
+ 000612.jpg,000612.txt
+ 000620.jpg,000620.txt
+ 000622.jpg,000622.txt
+ 000625.jpg,000625.txt
PASCAL_VOC/1examples.csv ADDED
@@ -0,0 +1,2 @@
+ img,label
+ 000007.jpg,000007.txt
PASCAL_VOC/2examples.csv ADDED
@@ -0,0 +1,3 @@
+ img,label
+ 000007.jpg,000007.txt
+ 000009.jpg,000009.txt
PASCAL_VOC/8examples.csv ADDED
@@ -0,0 +1,9 @@
+ img,label
+ 000007.jpg,000007.txt
+ 000009.jpg,000009.txt
+ 000016.jpg,000016.txt
+ 000019.jpg,000019.txt
+ 000020.jpg,000020.txt
+ 000021.jpg,000021.txt
+ 000122.jpg,000122.txt
+ 000129.jpg,000129.txt
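The split files above only pair an image name with its label file; src/dataset.py joins each column positionally against config.IMG_DIR and config.LABEL_DIR (note that 100examples.csv uses the header image,text while the smaller splits use img,label — the dataset indexes columns by position, so both work). A minimal sketch of reading a split this way, assuming the PASCAL_VOC/images/ and PASCAL_VOC/labels/ layout from src/config.py:

import os
import pandas as pd

IMG_DIR = "PASCAL_VOC/images/"
LABEL_DIR = "PASCAL_VOC/labels/"

annotations = pd.read_csv("PASCAL_VOC/8examples.csv")
for i in range(len(annotations)):
    img_path = os.path.join(IMG_DIR, annotations.iloc[i, 0])      # e.g. PASCAL_VOC/images/000007.jpg
    label_path = os.path.join(LABEL_DIR, annotations.iloc[i, 1])  # e.g. PASCAL_VOC/labels/000007.txt
    print(img_path, label_path)
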
PASCAL_VOC/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
PASCAL_VOC/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/config.py ADDED
@@ -0,0 +1,120 @@
+ import albumentations as A
+ import cv2
+ import torch
+
+ from albumentations.pytorch import ToTensorV2
+
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ DATASET = 'PASCAL_VOC'
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ # seed_everything()  # If you want deterministic behavior
+ NUM_WORKERS = 0
+ BATCH_SIZE = 32
+ IMAGE_SIZE = 416
+ NUM_CLASSES = 20
+ LEARNING_RATE = 1e-5
+ WEIGHT_DECAY = 1e-4
+ NUM_EPOCHS = 100
+ CONF_THRESHOLD = 0.05
+ MAP_IOU_THRESH = 0.5
+ NMS_IOU_THRESH = 0.45
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
+ PIN_MEMORY = True
+ LOAD_MODEL = False
+ SAVE_MODEL = True
+ CHECKPOINT_FILE = "checkpoint.pth.tar"
+ IMG_DIR = DATASET + "/images/"
+ LABEL_DIR = DATASET + "/labels/"
+
+ means = [0.485, 0.456, 0.406]
+
+ scale = 1.1
+
+
+ train_transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
+         A.PadIfNeeded(
+             min_height=int(IMAGE_SIZE * scale),
+             min_width=int(IMAGE_SIZE * scale),
+             border_mode=cv2.BORDER_CONSTANT,
+         ),
+         A.Rotate(limit=10, interpolation=1, border_mode=4),
+         A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
+         A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
+         A.OneOf(
+             [
+                 A.ShiftScaleRotate(
+                     rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
+                 ),
+                 # A.Affine(shear=15, p=0.5, mode="constant"),
+             ],
+             p=1.0,
+         ),
+         A.HorizontalFlip(p=0.5),
+         A.Blur(p=0.1),
+         A.CLAHE(p=0.1),
+         A.Posterize(p=0.1),
+         A.ToGray(p=0.1),
+         A.ChannelShuffle(p=0.05),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+     bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
+ )
+ test_transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=IMAGE_SIZE),
+         A.PadIfNeeded(
+             min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+         ),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+     bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
+ )
+
+
+
+ IMAGE_SIZE = 416
+ transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=IMAGE_SIZE),
+         A.PadIfNeeded(
+             min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+         ),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+ )
+ ANCHORS = [
+     [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
+     [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
+     [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
+ ]  # Note these have been rescaled to be between [0, 1]
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
+
+ PASCAL_CLASSES = [
+     "aeroplane",
+     "bicycle",
+     "bird",
+     "boat",
+     "bottle",
+     "bus",
+     "car",
+     "cat",
+     "chair",
+     "cow",
+     "diningtable",
+     "dog",
+     "horse",
+     "motorbike",
+     "person",
+     "pottedplant",
+     "sheep",
+     "sofa",
+     "train",
+     "tvmonitor"
+ ]
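The ANCHORS above are normalized to [0, 1] relative to the whole image, while the training targets and the decoding code work in grid-cell units. A small sketch (not part of the commit) of the rescaling expression used in src/model_obj.py and src/detect.py, with S = [13, 26, 52]:

import torch

ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]
S = [416 // 32, 416 // 16, 416 // 8]  # [13, 26, 52]

# Each (w, h) pair is multiplied by its scale's grid size, giving anchors in cell units.
scaled_anchors = torch.tensor(ANCHORS) * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
print(scaled_anchors.shape)  # torch.Size([3, 3, 2])
print(scaled_anchors[0, 0])  # tensor([3.6400, 2.8600]) == (0.28 * 13, 0.22 * 13)
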
src/dataset.py ADDED
@@ -0,0 +1,181 @@
1
+ """
2
+ Creates a Pytorch dataset to load the Pascal VOC & MS COCO datasets
3
+ """
4
+
5
+ import src.config as config
6
+ import numpy as np
7
+ import os
8
+ import pandas as pd
9
+ import torch
10
+ from src.utils_rh import xywhn2xyxy, xyxy2xywhn
11
+ import random
12
+
13
+ from PIL import Image, ImageFile
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from src.utils_rh import (
16
+ cells_to_bboxes,
17
+ iou_width_height as iou,
18
+ non_max_suppression as nms,
19
+ plot_image
20
+ )
21
+
22
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
23
+
24
+ class YOLODataset(Dataset):
25
+ def __init__(
26
+ self,
27
+ csv_file,
28
+ img_dir,
29
+ label_dir,
30
+ anchors,
31
+ image_size=416,
32
+ S=[13, 26, 52],
33
+ C=20,
34
+ transform=None,
35
+ ):
36
+ self.annotations = pd.read_csv(csv_file)
37
+ self.img_dir = img_dir
38
+ self.label_dir = label_dir
39
+ self.image_size = image_size
40
+ self.mosaic_border = [image_size // 2, image_size // 2]
41
+ self.transform = transform
42
+ self.S = S
43
+ self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2]) # for all 3 scales
44
+ self.num_anchors = self.anchors.shape[0]
45
+ self.num_anchors_per_scale = self.num_anchors // 3
46
+ self.C = C
47
+ self.ignore_iou_thresh = 0.5
48
+
49
+ def __len__(self):
50
+ return len(self.annotations)
51
+
52
+ def load_mosaic(self, index):
53
+ # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
54
+ labels4 = []
55
+ s = self.image_size
56
+ yc, xc = (int(random.uniform(x, 2 * s - x)) for x in self.mosaic_border) # mosaic center x, y
57
+ indices = [index] + random.choices(range(len(self)), k=3) # 3 additional image indices
58
+ random.shuffle(indices)
59
+ for i, index in enumerate(indices):
60
+ # Load image
61
+ label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
62
+ bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
63
+ img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
64
+ img = np.array(Image.open(img_path).convert("RGB"))
65
+
66
+
67
+ h, w = img.shape[0], img.shape[1]
68
+ labels = np.array(bboxes)
69
+
70
+ # place img in img4
71
+ if i == 0: # top left
72
+ img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
73
+ x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
74
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
75
+ elif i == 1: # top right
76
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
77
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
78
+ elif i == 2: # bottom left
79
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
80
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
81
+ elif i == 3: # bottom right
82
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
83
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
84
+
85
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
86
+ padw = x1a - x1b
87
+ padh = y1a - y1b
88
+
89
+ # Labels
90
+ if labels.size:
91
+ labels[:, :-1] = xywhn2xyxy(labels[:, :-1], w, h, padw, padh) # normalized xywh to pixel xyxy format
92
+ labels4.append(labels)
93
+
94
+ # Concat/clip labels
95
+ labels4 = np.concatenate(labels4, 0)
96
+ for x in (labels4[:, :-1],):
97
+ np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
98
+ # img4, labels4 = replicate(img4, labels4) # replicate
99
+ labels4[:, :-1] = xyxy2xywhn(labels4[:, :-1], 2 * s, 2 * s)
100
+ labels4[:, :-1] = np.clip(labels4[:, :-1], 0, 1)
101
+ labels4 = labels4[labels4[:, 2] > 0]
102
+ labels4 = labels4[labels4[:, 3] > 0]
103
+ return img4, labels4
104
+
105
+ def __getitem__(self, index):
106
+
107
+ image, bboxes = self.load_mosaic(index)
108
+
109
+ if self.transform:
110
+ augmentations = self.transform(image=image, bboxes=bboxes)
111
+ image = augmentations["image"]
112
+ bboxes = augmentations["bboxes"]
113
+
114
+ # Below assumes 3 scale predictions (as paper) and same num of anchors per scale
115
+ targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
116
+ for box in bboxes:
117
+ iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)
118
+ anchor_indices = iou_anchors.argsort(descending=True, dim=0)
119
+ x, y, width, height, class_label = box
120
+ has_anchor = [False] * 3 # each scale should have one anchor
121
+ for anchor_idx in anchor_indices:
122
+ scale_idx = anchor_idx // self.num_anchors_per_scale
123
+ anchor_on_scale = anchor_idx % self.num_anchors_per_scale
124
+ S = self.S[scale_idx]
125
+ i, j = int(S * y), int(S * x) # which cell
126
+ anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
127
+ if not anchor_taken and not has_anchor[scale_idx]:
128
+ targets[scale_idx][anchor_on_scale, i, j, 0] = 1
129
+ x_cell, y_cell = S * x - j, S * y - i # both between [0,1]
130
+ width_cell, height_cell = (
131
+ width * S,
132
+ height * S,
133
+ ) # can be greater than 1 since it's relative to cell
134
+ box_coordinates = torch.tensor(
135
+ [x_cell, y_cell, width_cell, height_cell]
136
+ )
137
+ targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
138
+ targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
139
+ has_anchor[scale_idx] = True
140
+
141
+ elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
142
+ targets[scale_idx][anchor_on_scale, i, j, 0] = -1 # ignore prediction
143
+
144
+ return image, tuple(targets)
145
+
146
+
147
+ def test():
148
+ anchors = config.ANCHORS
149
+
150
+ transform = config.test_transforms
151
+
152
+ dataset = YOLODataset(
153
+ "COCO/train.csv",
154
+ "COCO/images/images/",
155
+ "COCO/labels/labels_new/",
156
+ S=[13, 26, 52],
157
+ anchors=anchors,
158
+ transform=transform,
159
+ )
160
+ S = [13, 26, 52]
161
+ scaled_anchors = torch.tensor(anchors) / (
162
+ 1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
163
+ )
164
+ loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
165
+ for x, y in loader:
166
+ boxes = []
167
+
168
+ for i in range(y[0].shape[1]):
169
+ anchor = scaled_anchors[i]
170
+ print(anchor.shape)
171
+ print(y[i].shape)
172
+ boxes += cells_to_bboxes(
173
+ y[i], is_preds=False, S=y[i].shape[2], anchors=anchor
174
+ )[0]
175
+ boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
176
+ print(boxes)
177
+ plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)
178
+
179
+
180
+ if __name__ == "__main__":
181
+ test()
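A minimal usage sketch for the mosaic-based YOLODataset above (illustrative only — it assumes the PASCAL_VOC images and labels referenced by the CSV are present locally, and reuses the paths and transforms from src/config.py):

from torch.utils.data import DataLoader

import src.config as config
from src.dataset import YOLODataset

dataset = YOLODataset(
    "PASCAL_VOC/8examples.csv",
    img_dir=config.IMG_DIR,
    label_dir=config.LABEL_DIR,
    anchors=config.ANCHORS,
    transform=config.test_transforms,
)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

image, targets = next(iter(loader))
print(image.shape)                 # torch.Size([2, 3, 416, 416])
print([t.shape for t in targets])  # one (2, 3, S, S, 6) tensor per scale, S in {13, 26, 52}
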
src/detect.py ADDED
@@ -0,0 +1,57 @@
+ from typing import List
+ import cv2
+ import torch
+ import numpy as np
+ import src.config as config
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+ from src.model_obj import Assignment13
+ from src.utils import cells_to_bboxes, non_max_suppression, draw_predictions, YoloCAM
+
+
+
+
+ weights_path = "/home/user/app/model_ass_13_up.ckpt"
+ model = Assignment13().load_from_checkpoint(weights_path, map_location=torch.device("cpu"))
+ model = model.model
+ # ckpt = torch.load(weights_path, map_location="cpu")
+ # model.load_state_dict(ckpt)
+ model.eval()
+ print("[x] Model Loaded..")
+
+ scaled_anchors = (
+     torch.tensor(config.ANCHORS)
+     * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
+ ).to(config.DEVICE)
+
+ cam = YoloCAM(model=model, target_layers=[model.layers[-2]], use_cuda=False)
+
+ def predict(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.4, show_cam: bool = False, transparency: float = 0.5) -> List[np.ndarray]:
+     with torch.no_grad():
+         transformed_image = config.transforms(image=image)["image"].unsqueeze(0)
+         output = model(transformed_image)
+
+         bboxes = [[] for _ in range(1)]
+         for i in range(3):
+             batch_size, A, S, _, _ = output[i].shape
+             anchor = scaled_anchors[i]
+             boxes_scale_i = cells_to_bboxes(
+                 output[i], anchor, S=S, is_preds=True
+             )
+             for idx, box in enumerate(boxes_scale_i):
+                 bboxes[idx] += box
+
+     nms_boxes = non_max_suppression(
+         bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
+     )
+     plot_img = draw_predictions(image, nms_boxes, class_labels=config.PASCAL_CLASSES)
+     if not show_cam:
+         return [plot_img]
+
+     grayscale_cam = cam(transformed_image, scaled_anchors)[0, :, :]
+     img = cv2.resize(image, (416, 416))
+     img = np.float32(img) / 255
+     cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
+     return [plot_img, cam_image]
+
+
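detect.py exposes a single predict() entry point that returns the annotated image (and optionally a CAM overlay). A sketch of calling it, with the caveats that importing the module loads the Lightning checkpoint from the hard-coded weights_path above and that the input is expected as an RGB numpy array (the image path here is just an example):

import cv2

from src.detect import predict

image = cv2.cvtColor(cv2.imread("PASCAL_VOC/images/000007.jpg"), cv2.COLOR_BGR2RGB)

# Detections drawn on the image only.
annotated, = predict(image, iou_thresh=0.5, thresh=0.4)

# Detections plus a class-activation overlay from YoloCAM.
annotated, cam_overlay = predict(image, show_cam=True, transparency=0.5)
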
src/loss.py ADDED
@@ -0,0 +1,79 @@
+ """
+ Implementation of the YOLO loss function, similar to the one in the YOLOv3 paper;
+ the main difference is that CrossEntropy is used for the classes
+ instead of BinaryCrossEntropy.
+ """
+ import random
+ import torch
+ import torch.nn as nn
+
+ from src.utils_rh import intersection_over_union
+
+
+ class YoloLoss(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.mse = nn.MSELoss()
+         self.bce = nn.BCEWithLogitsLoss()
+         self.entropy = nn.CrossEntropyLoss()
+         self.sigmoid = nn.Sigmoid()
+
+         # Constants signifying how much to pay for each respective part of the loss
+         self.lambda_class = 1
+         self.lambda_noobj = 10
+         self.lambda_obj = 1
+         self.lambda_box = 10
+
+     def forward(self, predictions, target, anchors):
+         # Check where obj and noobj (we ignore if target == -1)
+         obj = target[..., 0] == 1  # in paper this is Iobj_i
+         noobj = target[..., 0] == 0  # in paper this is Inoobj_i
+
+         # ======================= #
+         #   FOR NO OBJECT LOSS    #
+         # ======================= #
+
+         no_object_loss = self.bce(
+             (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),
+         )
+
+         # ==================== #
+         #   FOR OBJECT LOSS    #
+         # ==================== #
+
+         anchors = anchors.reshape(1, 3, 1, 1, 2)
+         box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
+         ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
+         object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])
+
+         # ======================== #
+         #   FOR BOX COORDINATES    #
+         # ======================== #
+
+         predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3])  # x,y coordinates
+         target[..., 3:5] = torch.log(
+             (1e-16 + target[..., 3:5] / anchors)
+         )  # width, height coordinates
+         box_loss = self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj])
+
+         # ================== #
+         #   FOR CLASS LOSS   #
+         # ================== #
+
+         class_loss = self.entropy(
+             (predictions[..., 5:][obj]), (target[..., 5][obj].long()),
+         )
+
+         # print("__________________________________")
+         # print(self.lambda_box * box_loss)
+         # print(self.lambda_obj * object_loss)
+         # print(self.lambda_noobj * no_object_loss)
+         # print(self.lambda_class * class_loss)
+         # print("\n")
+
+         return (
+             self.lambda_box * box_loss
+             + self.lambda_obj * object_loss
+             + self.lambda_noobj * no_object_loss
+             + self.lambda_class * class_loss
+         )
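YoloLoss is applied once per prediction scale with that scale's anchors, exactly as training_step in src/model_obj.py does. A self-contained sketch with made-up tensors (shapes follow the target layout built in src/dataset.py; the single labelled cell keeps the object/box/class terms finite):

import torch

import src.config as config
from src.loss import YoloLoss

loss_fn = YoloLoss()
scaled_anchors = torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)

batch = 2
total = 0.0
for i, S in enumerate(config.S):  # S in [13, 26, 52]
    preds = torch.randn(batch, 3, S, S, config.NUM_CLASSES + 5)  # raw head output for this scale
    target = torch.zeros(batch, 3, S, S, 6)                      # [obj, x, y, w, h, class] per anchor/cell
    target[0, 0, S // 2, S // 2] = torch.tensor([1.0, 0.5, 0.5, 1.0, 1.0, 3.0])  # one labelled box
    total = total + loss_fn(preds, target, scaled_anchors[i])
print(total)
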
src/model_obj.py ADDED
@@ -0,0 +1,282 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Training_with_lr_A4000.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1gtGnp5dp_W4rB-Kz7uMsbPGYUuNPmVtu
8
+ """
9
+
10
+
11
+
12
+
13
+ import torch
14
+ import torch.optim as optim
15
+
16
+ from src.model_yolov3 import YOLOv3
17
+ from tqdm import tqdm
18
+ import src.config as config
19
+ from src.utils_rh import (
20
+ mean_average_precision,
21
+ cells_to_bboxes,
22
+ get_evaluation_bboxes,
23
+ save_checkpoint,
24
+ load_checkpoint,
25
+ check_class_accuracy,
26
+ get_loaders,
27
+ plot_couple_examples
28
+ )
29
+ from src.loss import YoloLoss
30
+ import warnings
31
+ warnings.filterwarnings("ignore")
32
+
33
+ import torch
34
+ from pytorch_lightning import LightningModule, Trainer
35
+ from torch import nn
36
+ from torch.nn import functional as F
37
+ from torch.utils.data import DataLoader, random_split
38
+ import torchvision
39
+ from pytorch_lightning.callbacks import LearningRateMonitor
40
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
41
+ from pytorch_lightning.loggers import CSVLogger
42
+ from pytorch_lightning.callbacks import ModelCheckpoint
43
+ import pandas as pd
44
+ from torch.optim.lr_scheduler import OneCycleLR
45
+
46
+ import seaborn as sn
47
+
48
+ class Assignment13(LightningModule):
49
+ def __init__(self):
50
+ super().__init__()
51
+ self.save_hyperparameters()
52
+ self.epoch_number = 0
53
+ self.config = config
54
+ self.train_csv_path = self.config.DATASET + "/train.csv"
55
+ self.test_csv_path = self.config.DATASET + "/test.csv"
56
+ self.train_loader, self.test_loader, self.train_eval_loader = get_loaders(
57
+ train_csv_path=self.train_csv_path, test_csv_path=self.test_csv_path)
58
+ self.check_class_accuracy = check_class_accuracy
59
+ self.model = YOLOv3(num_classes=self.config.NUM_CLASSES)
60
+ self.loss_fn = YoloLoss()
61
+ self.check_class_accuracy = check_class_accuracy
62
+ self.get_evaluation_bboxes = get_evaluation_bboxes
63
+ self.scaled_anchors = (torch.tensor(self.config.ANCHORS) * torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
64
+ self.losses = []
65
+ self.plot_couple_examples = plot_couple_examples
66
+ self.mean_average_precision = mean_average_precision
67
+ self.EPOCHS = self.config.NUM_EPOCHS * 2 // 5
68
+ def forward(self, x):
69
+ out = self.model(x)
70
+ return out
71
+ def training_step(self, batch, batch_idx):
72
+ x, y = batch
73
+ out = self(x)
74
+ y0, y1, y2 = (y[0],y[1],y[2])
75
+ loss = (
76
+ self.loss_fn(out[0], y0, self.scaled_anchors[0].to(y0))
77
+ + self.loss_fn(out[1], y1, self.scaled_anchors[1].to(y1))
78
+ + self.loss_fn(out[2], y2, self.scaled_anchors[2].to(y2))
79
+ )
80
+ self.losses.append(loss.item())
81
+ mean_loss = sum(self.losses) / len(self.losses)
82
+ self.log("train_loss", mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
83
+ #self.log("train_loss", mean_loss)
84
+ return loss
85
+
86
+
87
+ def on_train_epoch_start(self):
88
+ self.epoch_number += 1
89
+ self.losses = []
90
+ #self.plot_couple_examples(self.model,self.test_loader,0.6,0.5,self.scaled_anchors)
91
+ if self.epoch_number > 1 and self.epoch_number % 10 == 0:
92
+ self.plot_couple_examples(self.model,self.test_loader,0.6,0.5,self.scaled_anchors)
93
+
94
+ def on_train_epoch_end(self):
95
+ print(f"Currently epoch {self.epoch_number}")
96
+ print("On Train Eval loader:")
97
+ print("On Train loader:")
98
+ self.check_class_accuracy(self.model, self.train_loader, threshold=self.config.CONF_THRESHOLD)
99
+ if self.epoch_number == self.EPOCHS:
100
+ #if self.epoch_number > 1 and self.epoch_number % 3 == 0:
101
+ self.check_class_accuracy(self.model, self.test_loader, threshold=self.config.CONF_THRESHOLD)
102
+ pred_boxes, true_boxes = self.get_evaluation_bboxes( self.test_loader,self.model,iou_threshold=self.config.NMS_IOU_THRESH,
103
+ anchors=self.config.ANCHORS,
104
+ threshold=self.config.CONF_THRESHOLD,)
105
+ mapval = self.mean_average_precision(
106
+ pred_boxes,
107
+ true_boxes,
108
+ iou_threshold=self.config.MAP_IOU_THRESH,
109
+ box_format="midpoint",
110
+ num_classes=self.config.NUM_CLASSES,
111
+ )
112
+ print(f"MAP: {mapval.item()}")
113
+ self.model.train()
114
+ pass
115
+
116
+
117
+ def configure_optimizers(self):
118
+ optimizer = optimizer = optim.Adam(
119
+ model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
120
+ #EPOCHS = config.NUM_EPOCHS * 2 // 5
121
+ scheduler = OneCycleLR(
122
+ optimizer,
123
+ max_lr=1E-3,
124
+ steps_per_epoch=len(self.train_loader),
125
+ epochs=self.EPOCHS,
126
+ pct_start=5/self.EPOCHS,
127
+ div_factor=100,
128
+ three_phase=False,
129
+ final_div_factor=100,
130
+ anneal_strategy='linear'
131
+ )
132
+
133
+ return {"optimizer": optimizer, "lr_scheduler":scheduler}
134
+
135
+ ####################
136
+ # DATA RELATED HOOKS
137
+ ####################
138
+
139
+ def train_dataloader(self):
140
+ return self.train_loader
141
+
142
+ def test_dataloader(self):
143
+ return self.test_loader
144
+
145
+ #finding maximum learning rate
146
+ model = Assignment13()
147
+ #trainer = Trainer(precision=16,accelerator='cpu',callbacks=[TQDMProgressBar(refresh_rate=0)])
148
+
149
+ # Run learning rate finder
150
+ #lr_finder = trainer.tuner.lr_find(model,max_lr=2, num_training=200,mode="exponential")
151
+
152
+ # Inspect results
153
+ #fig = lr_finder.plot(); fig.show()
154
+ #suggested_lr = lr_finder.suggestion()
155
+ #print(suggested_lr)
156
+ # Overwrite lr and create new model
157
+ #hparams.lr = suggested_lr
158
+ #model = MyModelClass(hparams)
159
+
160
+ class Assignment13(LightningModule):
161
+ def __init__(self):
162
+ super().__init__()
163
+ self.save_hyperparameters()
164
+ self.epoch_number = 0
165
+ self.config = config
166
+ self.train_csv_path = self.config.DATASET + "/train.csv"
167
+ self.test_csv_path = self.config.DATASET + "/test.csv"
168
+ self.train_loader, self.test_loader, self.train_eval_loader = get_loaders(
169
+ train_csv_path=self.train_csv_path, test_csv_path=self.test_csv_path)
170
+ self.check_class_accuracy = check_class_accuracy
171
+ self.model = YOLOv3(num_classes=self.config.NUM_CLASSES)
172
+ self.loss_fn = YoloLoss()
173
+ self.check_class_accuracy = check_class_accuracy
174
+ self.get_evaluation_bboxes = get_evaluation_bboxes
175
+ self.scaled_anchors = (torch.tensor(self.config.ANCHORS) * torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
176
+ self.losses = []
177
+ self.plot_couple_examples = plot_couple_examples
178
+ self.mean_average_precision = mean_average_precision
179
+ self.EPOCHS = self.config.NUM_EPOCHS * 2 // 5
180
+ def forward(self, x):
181
+ out = self.model(x)
182
+ return out
183
+ def training_step(self, batch, batch_idx):
184
+ x, y = batch
185
+ out = self(x)
186
+ y0, y1, y2 = (y[0],y[1],y[2])
187
+ loss = (
188
+ self.loss_fn(out[0], y0, self.scaled_anchors[0].to(y0))
189
+ + self.loss_fn(out[1], y1, self.scaled_anchors[1].to(y1))
190
+ + self.loss_fn(out[2], y2, self.scaled_anchors[2].to(y2))
191
+ )
192
+ self.losses.append(loss.item())
193
+ mean_loss = sum(self.losses) / len(self.losses)
194
+ self.log("train_loss", mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
195
+ #self.log("train_loss", mean_loss)
196
+ return loss
197
+
198
+
199
+ def on_train_epoch_start(self):
200
+ self.epoch_number += 1
201
+ self.losses = []
202
+ #self.plot_couple_examples(self.model,self.test_loader,0.6,0.5,self.scaled_anchors)
203
+ if self.epoch_number > 1 and self.epoch_number % 10 == 0:
204
+ self.plot_couple_examples(self.model,self.test_loader,0.6,0.5,self.scaled_anchors)
205
+
206
+ def on_train_epoch_end(self):
207
+ print(f"Currently epoch {self.epoch_number}")
208
+ print("On Train Eval loader:")
209
+ print("On Train loader:")
210
+ self.check_class_accuracy(self.model, self.train_loader, threshold=self.config.CONF_THRESHOLD)
211
+ if self.epoch_number == self.EPOCHS:
212
+ #if self.epoch_number > 1 and self.epoch_number % 3 == 0:
213
+ self.check_class_accuracy(self.model, self.test_loader, threshold=self.config.CONF_THRESHOLD)
214
+ pred_boxes, true_boxes = self.get_evaluation_bboxes( self.test_loader,self.model,iou_threshold=self.config.NMS_IOU_THRESH,
215
+ anchors=self.config.ANCHORS,
216
+ threshold=self.config.CONF_THRESHOLD,)
217
+ mapval = self.mean_average_precision(
218
+ pred_boxes,
219
+ true_boxes,
220
+ iou_threshold=self.config.MAP_IOU_THRESH,
221
+ box_format="midpoint",
222
+ num_classes=self.config.NUM_CLASSES,
223
+ )
224
+ print(f"MAP: {mapval.item()}")
225
+ self.model.train()
226
+ pass
227
+
228
+
229
+ def configure_optimizers(self):
230
+ optimizer = optimizer = optim.Adam(
231
+ model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
232
+ #EPOCHS = config.NUM_EPOCHS * 2 // 5
233
+ scheduler = OneCycleLR(
234
+ optimizer,
235
+ max_lr=8E-4,
236
+ steps_per_epoch=len(self.train_loader),
237
+ epochs=self.EPOCHS,
238
+ pct_start=5/self.EPOCHS,
239
+ div_factor=100,
240
+ three_phase=False,
241
+ final_div_factor=100,
242
+ anneal_strategy='linear'
243
+ )
244
+
245
+ return {"optimizer": optimizer, "lr_scheduler":scheduler}
246
+
247
+ ####################
248
+ # DATA RELATED HOOKS
249
+ ####################
250
+
251
+ def train_dataloader(self):
252
+ return self.train_loader
253
+
254
+ def test_dataloader(self):
255
+ return self.test_loader
256
+
257
+ model = Assignment13()
258
+ checkpoint_callback = ModelCheckpoint(
259
+ monitor='train_loss', # Metric to monitor for saving the best model
260
+ mode='min', # 'min' to save the model with the lowest value of the monitored metric
261
+ dirpath='/storage/',
262
+ filename='assignment13_final{epoch:02d}-train_loss_min_A400{train_loss:.2f}',
263
+ save_top_k=1 # Save only the best model
264
+ )
265
+
266
+ #trainer = Trainer(
267
+ #max_epochs=config.NUM_EPOCHS * 2 // 5,
268
+ #accelerator="cpu",
269
+ #precision=16, # limiting for iPython runs
270
+ #logger=CSVLogger(save_dir="logs/"),
271
+ #callbacks=[LearningRateMonitor(logging_interval="step"),TQDMProgressBar(refresh_rate=0),checkpoint_callback],
272
+ #)
273
+
274
+
275
+ #trainer.fit(model)
276
+
277
+ #metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
278
+ #del metrics["step"]
279
+ #metrics.set_index("epoch", inplace=True)
280
+ #display(metrics.dropna(axis=1, how="all").head())
281
+ #sn.relplot(data=metrics, kind="line")
282
+
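One note on the scheduler arithmetic in configure_optimizers: with NUM_EPOCHS = 100 in src/config.py, self.EPOCHS works out to 40, so pct_start = 5 / 40 gives a five-epoch linear warm-up before the linear anneal; the second class definition above (the one that is instantiated and checkpointed) uses max_lr = 8e-4. A quick check of the resulting numbers (illustrative only):

NUM_EPOCHS = 100              # from src/config.py
EPOCHS = NUM_EPOCHS * 2 // 5  # 40 epochs scheduled
pct_start = 5 / EPOCHS        # 0.125 -> the first 5 epochs ramp the LR up
max_lr = 8e-4
initial_lr = max_lr / 100     # div_factor=100 -> 8e-6 at the start
final_lr = initial_lr / 100   # final_div_factor=100 -> 8e-8 at the end
print(EPOCHS, pct_start, initial_lr, final_lr)
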
src/model_yolov3.py ADDED
@@ -0,0 +1,291 @@
1
+ """Implementation of YOLOv3 architecture."""
2
+ from typing import Any, List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch
7
+ from pytorch_lightning import LightningModule, Trainer
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.utils.data import DataLoader, random_split
11
+ import torchvision
12
+ from pytorch_lightning.callbacks import LearningRateMonitor
13
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
14
+ from pytorch_lightning.loggers import CSVLogger
15
+ from pytorch_lightning.callbacks import ModelCheckpoint
16
+ import pandas as pd
17
+ from torch.optim.lr_scheduler import OneCycleLR
18
+
19
+
20
+ """
21
+ Information about architecture config:
22
+ Tuple is structured by (filters, kernel_size, stride)
23
+ Every conv is a same convolution.
24
+ List is structured by "B" indicating a residual block followed by the number of repeats
25
+ "S" is for scale prediction block and computing the yolo loss
26
+ "U" is for upsampling the feature map and concatenating with a previous layer
27
+ """
28
+ config = [
29
+ (32, 3, 1),
30
+ (64, 3, 2),
31
+ ["B", 1],
32
+ (128, 3, 2),
33
+ ["B", 2],
34
+ (256, 3, 2),
35
+ ["B", 8],
36
+ (512, 3, 2),
37
+ ["B", 8],
38
+ (1024, 3, 2),
39
+ ["B", 4], # To this point is Darknet-53
40
+ (512, 1, 1),
41
+ (1024, 3, 1),
42
+ "S",
43
+ (256, 1, 1),
44
+ "U",
45
+ (256, 1, 1),
46
+ (512, 3, 1),
47
+ "S",
48
+ (128, 1, 1),
49
+ "U",
50
+ (128, 1, 1),
51
+ (256, 3, 1),
52
+ "S",
53
+ ]
54
+
55
+
56
+ class CNNBlock(nn.Module):
57
+ def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
58
+ super().__init__()
59
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
60
+ self.bn = nn.BatchNorm2d(out_channels)
61
+ self.leaky = nn.LeakyReLU(0.1)
62
+ self.use_bn_act = bn_act
63
+
64
+ def forward(self, x):
65
+ if self.use_bn_act:
66
+ return self.leaky(self.bn(self.conv(x)))
67
+ else:
68
+ return self.conv(x)
69
+
70
+
71
+ class ResidualBlock(nn.Module):
72
+ def __init__(self, channels, use_residual=True, num_repeats=1):
73
+ super().__init__()
74
+ self.layers = nn.ModuleList()
75
+ for repeat in range(num_repeats):
76
+ self.layers += [
77
+ nn.Sequential(
78
+ CNNBlock(channels, channels // 2, kernel_size=1),
79
+ CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
80
+ )
81
+ ]
82
+
83
+ self.use_residual = use_residual
84
+ self.num_repeats = num_repeats
85
+
86
+ def forward(self, x):
87
+ for layer in self.layers:
88
+ if self.use_residual:
89
+ x = x + layer(x)
90
+ else:
91
+ x = layer(x)
92
+
93
+ return x
94
+
95
+
96
+ class ScalePrediction(nn.Module):
97
+ def __init__(self, in_channels, num_classes):
98
+ super().__init__()
99
+ self.pred = nn.Sequential(
100
+ CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
101
+ CNNBlock(2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1),
102
+ )
103
+ self.num_classes = num_classes
104
+
105
+ def forward(self, x):
106
+ return (
107
+ self.pred(x)
108
+ .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
109
+ .permute(0, 1, 3, 4, 2)
110
+ )
111
+
112
+
113
+ class YOLOv3(nn.Module):
114
+ def __init__(self, load_config: List[Any] = config, in_channels=3, num_classes=80):
115
+ super().__init__()
116
+ self.load_config = load_config
117
+ self.num_classes = num_classes
118
+ self.in_channels = in_channels
119
+ self.layers = self._create_conv_layers()
120
+
121
+ def forward(self, x):
122
+ outputs = [] # for each scale
123
+ route_connections = []
124
+ for layer in self.layers:
125
+ if isinstance(layer, ScalePrediction):
126
+ outputs.append(layer(x))
127
+ continue
128
+
129
+ x = layer(x)
130
+
131
+ if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
132
+ route_connections.append(x)
133
+
134
+ elif isinstance(layer, nn.Upsample):
135
+ x = torch.cat([x, route_connections[-1]], dim=1)
136
+ route_connections.pop()
137
+
138
+ return outputs
139
+
140
+ def _create_conv_layers(self):
141
+ layers = nn.ModuleList()
142
+ in_channels = self.in_channels
143
+
144
+ for module in self.load_config:
145
+ if isinstance(module, tuple):
146
+ out_channels, kernel_size, stride = module
147
+ layers.append(
148
+ CNNBlock(
149
+ in_channels,
150
+ out_channels,
151
+ kernel_size=kernel_size,
152
+ stride=stride,
153
+ padding=1 if kernel_size == 3 else 0,
154
+ )
155
+ )
156
+ in_channels = out_channels
157
+
158
+ elif isinstance(module, list):
159
+ num_repeats = module[1]
160
+ layers.append(
161
+ ResidualBlock(
162
+ in_channels,
163
+ num_repeats=num_repeats,
164
+ )
165
+ )
166
+
167
+ elif isinstance(module, str):
168
+ if module == "S":
169
+ layers += [
170
+ ResidualBlock(in_channels, use_residual=False, num_repeats=1),
171
+ CNNBlock(in_channels, in_channels // 2, kernel_size=1),
172
+ ScalePrediction(in_channels // 2, num_classes=self.num_classes),
173
+ ]
174
+ in_channels = in_channels // 2
175
+
176
+ elif module == "U":
177
+ layers.append(
178
+ nn.Upsample(scale_factor=2),
179
+ )
180
+ in_channels = in_channels * 3
181
+
182
+ return layers
183
+
184
+
185
+ class Assignment13(LightningModule):
186
+ def __init__(self):
187
+ super().__init__()
188
+ self.save_hyperparameters()
189
+ self.epoch_number = 0
190
+ self.config = config
191
+ self.train_csv_path = self.config.DATASET + "/train.csv"
192
+ self.test_csv_path = self.config.DATASET + "/test.csv"
193
+ self.train_loader, self.test_loader, self.train_eval_loader = get_loaders(
194
+ train_csv_path=self.train_csv_path, test_csv_path=self.test_csv_path)
195
+ self.check_class_accuracy = check_class_accuracy
196
+ self.model = YOLOv3(num_classes=self.config.NUM_CLASSES)
197
+ self.loss_fn = YoloLoss()
198
+ self.check_class_accuracy = check_class_accuracy
199
+ self.get_evaluation_bboxes = get_evaluation_bboxes
200
+ self.scaled_anchors = (torch.tensor(self.config.ANCHORS) * torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
201
+ self.losses = []
202
+ self.plot_couple_examples = plot_couple_examples
203
+ self.mean_average_precision = mean_average_precision
204
+ self.EPOCHS = self.config.NUM_EPOCHS * 2 // 5
205
+ def forward(self, x):
206
+ out = self.model(x)
207
+ return out
208
+ def training_step(self, batch, batch_idx):
209
+ x, y = batch
210
+ out = self(x)
211
+ y0, y1, y2 = (y[0],y[1],y[2])
212
+ loss = (
213
+ self.loss_fn(out[0], y0, self.scaled_anchors[0].to(y0))
214
+ + self.loss_fn(out[1], y1, self.scaled_anchors[1].to(y1))
215
+ + self.loss_fn(out[2], y2, self.scaled_anchors[2].to(y2))
216
+ )
217
+ self.losses.append(loss.item())
218
+ mean_loss = sum(self.losses) / len(self.losses)
219
+ self.log("train_loss", mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
220
+ #self.log("train_loss", mean_loss)
221
+ return loss
222
+
223
+
224
+ def on_train_epoch_start(self):
225
+ self.epoch_number += 1
226
+ self.losses = []
227
+ #self.plot_couple_examples(self.model,self.test_loader,0.6,0.5,self.scaled_anchors)
228
+ if self.epoch_number > 1 and self.epoch_number % 10 == 0:
229
+ self.plot_couple_examples(self.model,self.test_loader,0.6,0.5,self.scaled_anchors)
230
+
231
+ def on_train_epoch_end(self):
232
+ print(f"Currently epoch {self.epoch_number}")
233
+ print("On Train Eval loader:")
234
+ print("On Train loader:")
235
+ self.check_class_accuracy(self.model, self.train_loader, threshold=self.config.CONF_THRESHOLD)
236
+ if self.epoch_number == self.EPOCHS:
237
+ #if self.epoch_number > 1 and self.epoch_number % 3 == 0:
238
+ self.check_class_accuracy(self.model, self.test_loader, threshold=self.config.CONF_THRESHOLD)
239
+ pred_boxes, true_boxes = self.get_evaluation_bboxes( self.test_loader,self.model,iou_threshold=self.config.NMS_IOU_THRESH,
240
+ anchors=self.config.ANCHORS,
241
+ threshold=self.config.CONF_THRESHOLD,)
242
+ mapval = self.mean_average_precision(
243
+ pred_boxes,
244
+ true_boxes,
245
+ iou_threshold=self.config.MAP_IOU_THRESH,
246
+ box_format="midpoint",
247
+ num_classes=self.config.NUM_CLASSES,
248
+ )
249
+ print(f"MAP: {mapval.item()}")
250
+ self.model.train()
251
+ pass
252
+
253
+
254
+ def configure_optimizers(self):
255
+ optimizer = optimizer = optim.Adam(
256
+ model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
257
+ #EPOCHS = config.NUM_EPOCHS * 2 // 5
258
+ scheduler = OneCycleLR(
259
+ optimizer,
260
+ max_lr=1E-3,
261
+ steps_per_epoch=len(self.train_loader),
262
+ epochs=self.EPOCHS,
263
+ pct_start=5/self.EPOCHS,
264
+ div_factor=100,
265
+ three_phase=False,
266
+ final_div_factor=100,
267
+ anneal_strategy='linear'
268
+ )
269
+
270
+ return {"optimizer": optimizer, "lr_scheduler":scheduler}
271
+
272
+ ####################
273
+ # DATA RELATED HOOKS
274
+ ####################
275
+
276
+ def train_dataloader(self):
277
+ return self.train_loader
278
+
279
+ def test_dataloader(self):
280
+ return self.test_loader
281
+
282
+ if __name__ == "__main__":
283
+ num_classes = 20
284
+ IMAGE_SIZE = 416
285
+ model = YOLOv3(load_config=config, num_classes=num_classes)
286
+ x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
287
+ out = model(x)
288
+ assert out[0].shape == (2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5)
289
+ assert out[1].shape == (2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5)
290
+ assert out[2].shape == (2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5)
291
+ print("Success!")
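The architecture list at the top of this file is expanded by _create_conv_layers into CNNBlock, ResidualBlock, ScalePrediction and Upsample modules, following the tuple / ["B", n] / "S" / "U" convention described in the docstring. A short sketch (mirroring the shape check in the __main__ block) that shows the expansion and the three output grids:

import torch

from src.model_yolov3 import YOLOv3

model = YOLOv3(num_classes=20)

for layer in list(model.layers)[:5]:
    print(type(layer).__name__)  # CNNBlock, CNNBlock, ResidualBlock, CNNBlock, ResidualBlock

x = torch.randn(1, 3, 416, 416)
print([tuple(o.shape) for o in model(x)])  # (1, 3, 13, 13, 25), (1, 3, 26, 26, 25), (1, 3, 52, 52, 25)
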
src/utils.py ADDED
@@ -0,0 +1,278 @@
1
+ from typing import List
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ import random
6
+
7
+ from pytorch_grad_cam.base_cam import BaseCAM
8
+ from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
9
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
10
+
11
+ def seed_everything(seed=42):
12
+ os.environ['PYTHONHASHSEED'] = str(seed)
13
+ random.seed(seed)
14
+ np.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed)
18
+ torch.backends.cudnn.deterministic = True
19
+ torch.backends.cudnn.benchmark = False
20
+
21
+
22
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
23
+ """
24
+ Scales the predictions coming from the model to
+ be relative to the entire image so that they can, for example,
+ later be plotted or passed on to non-max suppression.
27
+ INPUT:
28
+ predictions: tensor of size (N, 3, S, S, num_classes+5)
29
+ anchors: the anchors used for the predictions
30
+ S: the number of cells the image is divided in on the width (and height)
31
+ is_preds: whether the input is predictions or the true bounding boxes
32
+ OUTPUT:
33
+ converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
34
+ object score, bounding box coordinates
35
+ """
36
+ BATCH_SIZE = predictions.shape[0]
37
+ num_anchors = len(anchors)
38
+ box_predictions = predictions[..., 1:5]
39
+ if is_preds:
40
+ anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
41
+ box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
42
+ box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
43
+ scores = torch.sigmoid(predictions[..., 0:1])
44
+ best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
45
+ else:
46
+ scores = predictions[..., 0:1]
47
+ best_class = predictions[..., 5:6]
48
+
49
+ cell_indices = (
50
+ torch.arange(S)
51
+ .repeat(predictions.shape[0], 3, S, 1)
52
+ .unsqueeze(-1)
53
+ .to(predictions.device)
54
+ )
55
+ x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
56
+ y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
57
+ w_h = 1 / S * box_predictions[..., 2:4]
58
+ converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
59
+ return converted_bboxes.tolist()
60
+
61
+
62
+
63
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
64
+ """
65
+ Video explanation of this function:
66
+ https://youtu.be/XXYG5ZWtjj0
67
+
68
+ This function calculates intersection over union (iou) given pred boxes
69
+ and target boxes.
70
+
71
+ Parameters:
72
+ boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
73
+ boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
74
+ box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
75
+
76
+ Returns:
77
+ tensor: Intersection over union for all examples
78
+ """
79
+
80
+ if box_format == "midpoint":
81
+ box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
82
+ box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
83
+ box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
84
+ box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
85
+ box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
86
+ box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
87
+ box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
88
+ box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
89
+
90
+ if box_format == "corners":
91
+ box1_x1 = boxes_preds[..., 0:1]
92
+ box1_y1 = boxes_preds[..., 1:2]
93
+ box1_x2 = boxes_preds[..., 2:3]
94
+ box1_y2 = boxes_preds[..., 3:4]
95
+ box2_x1 = boxes_labels[..., 0:1]
96
+ box2_y1 = boxes_labels[..., 1:2]
97
+ box2_x2 = boxes_labels[..., 2:3]
98
+ box2_y2 = boxes_labels[..., 3:4]
99
+
100
+ x1 = torch.max(box1_x1, box2_x1)
101
+ y1 = torch.max(box1_y1, box2_y1)
102
+ x2 = torch.min(box1_x2, box2_x2)
103
+ y2 = torch.min(box1_y2, box2_y2)
104
+
105
+ intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
106
+ box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
107
+ box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
108
+
109
+ return intersection / (box1_area + box2_area - intersection + 1e-6)
110
+
111
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
112
+ """
113
+ Video explanation of this function:
114
+ https://youtu.be/YDkjWEN8jNA
115
+
116
+ Does Non Max Suppression given bboxes
117
+
118
+ Parameters:
119
+ bboxes (list): list of lists containing all bboxes with each bboxes
120
+ specified as [class_pred, prob_score, x1, y1, x2, y2]
121
+ iou_threshold (float): threshold where predicted bboxes is correct
122
+ threshold (float): threshold to remove predicted bboxes (independent of IoU)
123
+ box_format (str): "midpoint" or "corners" used to specify bboxes
124
+
125
+ Returns:
126
+ list: bboxes after performing NMS given a specific IoU threshold
127
+ """
128
+
129
+ assert type(bboxes) == list
130
+
131
+ bboxes = [box for box in bboxes if box[1] > threshold]
132
+ bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
133
+ bboxes_after_nms = []
134
+
135
+ while bboxes:
136
+ chosen_box = bboxes.pop(0)
137
+
138
+ bboxes = [
139
+ box
140
+ for box in bboxes
141
+ if box[0] != chosen_box[0]
142
+ or intersection_over_union(
143
+ torch.tensor(chosen_box[2:]),
144
+ torch.tensor(box[2:]),
145
+ box_format=box_format,
146
+ )
147
+ < iou_threshold
148
+ ]
149
+
150
+ bboxes_after_nms.append(chosen_box)
151
+
152
+ return bboxes_after_nms
153
+
154
+
155
+
156
+
157
+ def draw_predictions(image: np.ndarray, boxes: List[List], class_labels: List[str]) -> np.ndarray:
158
+ """Plots predicted bounding boxes on the image"""
159
+
160
+ colors = [[random.randint(0, 255) for _ in range(3)] for name in class_labels]
161
+
162
+ im = np.array(image)
163
+ height, width, _ = im.shape
164
+ bbox_thick = int(0.6 * (height + width) / 600)
165
+
166
+ # Create a Rectangle patch
167
+ for box in boxes:
168
+ assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
169
+ class_pred = box[0]
170
+ conf = box[1]
171
+ box = box[2:]
172
+ upper_left_x = box[0] - box[2] / 2
173
+ upper_left_y = box[1] - box[3] / 2
174
+
175
+ x1 = int(upper_left_x * width)
176
+ y1 = int(upper_left_y * height)
177
+
178
+ x2 = x1 + int(box[2] * width)
179
+ y2 = y1 + int(box[3] * height)
180
+
181
+ cv2.rectangle(
182
+ image,
183
+ (x1, y1), (x2, y2),
184
+ color=colors[int(class_pred)],
185
+ thickness=bbox_thick
186
+ )
187
+ text = f"{class_labels[int(class_pred)]}: {conf:.2f}"
188
+ t_size = cv2.getTextSize(text, 0, 0.7, thickness=bbox_thick // 2)[0]
189
+ c3 = (x1 + t_size[0], y1 - t_size[1] - 3)
190
+
191
+ cv2.rectangle(image, (x1, y1), c3, colors[int(class_pred)], -1)
192
+ cv2.putText(
193
+ image,
194
+ text,
195
+ (x1, y1 - 2),
196
+ cv2.FONT_HERSHEY_SIMPLEX,
197
+ 0.7,
198
+ (0, 0, 0),
199
+ bbox_thick // 2,
200
+ lineType=cv2.LINE_AA,
201
+ )
202
+
203
+ return image
204
+
205
+
206
+ class YoloCAM(BaseCAM):
207
+ def __init__(self, model, target_layers, use_cuda=False,
208
+ reshape_transform=None):
209
+ super(YoloCAM, self).__init__(model,
210
+ target_layers,
211
+ use_cuda,
212
+ reshape_transform,
213
+ uses_gradients=False)
214
+
215
+ def forward(self,
216
+ input_tensor: torch.Tensor,
217
+ scaled_anchors: torch.Tensor,
218
+ targets: List[torch.nn.Module],
219
+ eigen_smooth: bool = False) -> np.ndarray:
220
+
221
+ if self.cuda:
222
+ input_tensor = input_tensor.cuda()
223
+
224
+ if self.compute_input_gradient:
225
+ input_tensor = torch.autograd.Variable(input_tensor,
226
+ requires_grad=True)
227
+
228
+ outputs = self.activations_and_grads(input_tensor)
229
+ if targets is None:
230
+ bboxes = [[] for _ in range(1)]
231
+ for i in range(3):
232
+ batch_size, A, S, _, _ = outputs[i].shape
233
+ anchor = scaled_anchors[i]
234
+ boxes_scale_i = cells_to_bboxes(
235
+ outputs[i], anchor, S=S, is_preds=True
236
+ )
237
+ for idx, (box) in enumerate(boxes_scale_i):
238
+ bboxes[idx] += box
239
+
240
+ nms_boxes = non_max_suppression(
241
+ bboxes[0], iou_threshold=0.5, threshold=0.4, box_format="midpoint",
242
+ )
243
+ # target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
244
+ target_categories = [box[0] for box in nms_boxes]
245
+ targets = [ClassifierOutputTarget(
246
+ category) for category in target_categories]
247
+
248
+ if self.uses_gradients:
249
+ self.model.zero_grad()
250
+ loss = sum([target(output)
251
+ for target, output in zip(targets, outputs)])
252
+ loss.backward(retain_graph=True)
253
+
254
+ # In most of the saliency attribution papers, the saliency is
255
+ # computed with a single target layer.
256
+ # Commonly it is the last convolutional layer.
257
+ # Here we support passing a list with multiple target layers.
258
+ # It will compute the saliency image for every image,
259
+ # and then aggregate them (with a default mean aggregation).
260
+ # This gives you more flexibility in case you just want to
261
+ # use all conv layers for example, all Batchnorm layers,
262
+ # or something else.
263
+ cam_per_layer = self.compute_cam_per_layer(input_tensor,
264
+ targets,
265
+ eigen_smooth)
266
+ return self.aggregate_multi_layers(cam_per_layer)
267
+
268
+ def get_cam_image(self,
269
+ input_tensor,
270
+ target_layer,
271
+ target_category,
272
+ activations,
273
+ grads,
274
+ eigen_smooth):
275
+ return get_2d_projection(activations)
276
+
277
+
278
+
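non_max_suppression above works on plain Python lists in [class_pred, score, x, y, w, h] (midpoint) order. A tiny hand-made example (values invented purely for illustration) of what it drops and keeps:

from src.utils import non_max_suppression

boxes = [
    [14, 0.90, 0.50, 0.50, 0.20, 0.20],  # person, high confidence
    [14, 0.60, 0.51, 0.50, 0.20, 0.20],  # near-duplicate of the box above -> suppressed by IoU
    [7, 0.80, 0.20, 0.30, 0.10, 0.10],   # cat, different class -> kept regardless of overlap
    [14, 0.20, 0.80, 0.80, 0.10, 0.10],  # below the score threshold -> dropped up front
]
kept = non_max_suppression(boxes, iou_threshold=0.5, threshold=0.4, box_format="midpoint")
print(kept)  # the 0.90 person box and the 0.80 cat box remain
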
src/utils_rh.py ADDED
@@ -0,0 +1,582 @@
1
+ import src.config as config
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.patches as patches
4
+ import numpy as np
5
+ import os
6
+ import random
7
+ import torch
8
+
9
+ from collections import Counter
10
+ from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
+
13
+
14
+ def iou_width_height(boxes1, boxes2):
15
+ """
16
+ Parameters:
17
+ boxes1 (tensor): width and height of the first bounding boxes
18
+ boxes2 (tensor): width and height of the second bounding boxes
19
+ Returns:
20
+ tensor: Intersection over union of the corresponding boxes
21
+ """
22
+ intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
23
+ boxes1[..., 1], boxes2[..., 1]
24
+ )
25
+ union = (
26
+ boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
27
+ )
28
+ return intersection / union
29
+
30
+
31
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
32
+ """
33
+ Video explanation of this function:
34
+ https://youtu.be/XXYG5ZWtjj0
35
+
36
+ This function calculates intersection over union (iou) given pred boxes
37
+ and target boxes.
38
+
39
+ Parameters:
40
+ boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
41
+ boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
42
+ box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
43
+
44
+ Returns:
45
+ tensor: Intersection over union for all examples
46
+ """
47
+
48
+ if box_format == "midpoint":
49
+ box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
50
+ box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
51
+ box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
52
+ box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
53
+ box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
54
+ box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
55
+ box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
56
+ box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
57
+
58
+ if box_format == "corners":
59
+ box1_x1 = boxes_preds[..., 0:1]
60
+ box1_y1 = boxes_preds[..., 1:2]
61
+ box1_x2 = boxes_preds[..., 2:3]
62
+ box1_y2 = boxes_preds[..., 3:4]
63
+ box2_x1 = boxes_labels[..., 0:1]
64
+ box2_y1 = boxes_labels[..., 1:2]
65
+ box2_x2 = boxes_labels[..., 2:3]
66
+ box2_y2 = boxes_labels[..., 3:4]
67
+
68
+ x1 = torch.max(box1_x1, box2_x1)
69
+ y1 = torch.max(box1_y1, box2_y1)
70
+ x2 = torch.min(box1_x2, box2_x2)
71
+ y2 = torch.min(box1_y2, box2_y2)
72
+
73
+ intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
74
+ box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
75
+ box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
76
+
77
+ return intersection / (box1_area + box2_area - intersection + 1e-6)
78
+
79
+
80
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
81
+ """
82
+ Video explanation of this function:
83
+ https://youtu.be/YDkjWEN8jNA
84
+
85
+ Does Non Max Suppression given bboxes
86
+
87
+ Parameters:
88
+ bboxes (list): list of lists containing all bboxes with each bboxes
89
+ specified as [class_pred, prob_score, x1, y1, x2, y2]
90
+ iou_threshold (float): threshold where predicted bboxes is correct
91
+ threshold (float): threshold to remove predicted bboxes (independent of IoU)
92
+ box_format (str): "midpoint" or "corners" used to specify bboxes
93
+
94
+ Returns:
95
+ list: bboxes after performing NMS given a specific IoU threshold
96
+ """
97
+
98
+ assert type(bboxes) == list
99
+
100
+ bboxes = [box for box in bboxes if box[1] > threshold]
101
+ bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
102
+ bboxes_after_nms = []
103
+
104
+ while bboxes:
105
+ chosen_box = bboxes.pop(0)
106
+
107
+ bboxes = [
108
+ box
109
+ for box in bboxes
110
+ if box[0] != chosen_box[0]
111
+ or intersection_over_union(
112
+ torch.tensor(chosen_box[2:]),
113
+ torch.tensor(box[2:]),
114
+ box_format=box_format,
115
+ )
116
+ < iou_threshold
117
+ ]
118
+
119
+ bboxes_after_nms.append(chosen_box)
120
+
121
+ return bboxes_after_nms
122
+
123
+
124
+ def mean_average_precision(
125
+ pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
126
+ ):
127
+ """
128
+ Video explanation of this function:
129
+ https://youtu.be/FppOzcDvaDI
130
+
131
+ This function calculates mean average precision (mAP)
132
+
133
+ Parameters:
134
+ pred_boxes (list): list of lists containing all bboxes with each bboxes
135
+ specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
136
+ true_boxes (list): Similar as pred_boxes except all the correct ones
137
+ iou_threshold (float): threshold where predicted bboxes is correct
138
+ box_format (str): "midpoint" or "corners" used to specify bboxes
139
+ num_classes (int): number of classes
140
+
141
+ Returns:
142
+ float: mAP value across all classes given a specific IoU threshold
143
+ """
144
+
145
+ # list storing all AP for respective classes
146
+ average_precisions = []
147
+
148
+ # used for numerical stability later on
149
+ epsilon = 1e-6
150
+
151
+ for c in range(num_classes):
152
+ detections = []
153
+ ground_truths = []
154
+
155
+ # Go through all predictions and targets,
156
+ # and only add the ones that belong to the
157
+ # current class c
158
+ for detection in pred_boxes:
159
+ if detection[1] == c:
160
+ detections.append(detection)
161
+
162
+ for true_box in true_boxes:
163
+ if true_box[1] == c:
164
+ ground_truths.append(true_box)
165
+
166
+ # find the amount of bboxes for each training example
167
+ # Counter here finds how many ground truth bboxes we get
168
+ # for each training example, so let's say img 0 has 3,
169
+ # img 1 has 5 then we will obtain a dictionary with:
170
+ # amount_bboxes = {0:3, 1:5}
171
+ amount_bboxes = Counter([gt[0] for gt in ground_truths])
172
+
173
+ # We then go through each key, val in this dictionary
174
+ # and convert to the following (w.r.t same example):
175
+ # amount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
176
+ for key, val in amount_bboxes.items():
177
+ amount_bboxes[key] = torch.zeros(val)
178
+
179
+ # sort by box probabilities which is index 2
180
+ detections.sort(key=lambda x: x[2], reverse=True)
181
+ TP = torch.zeros((len(detections)))
182
+ FP = torch.zeros((len(detections)))
183
+ total_true_bboxes = len(ground_truths)
184
+
185
+ # If none exists for this class then we can safely skip
186
+ if total_true_bboxes == 0:
187
+ continue
188
+
189
+ for detection_idx, detection in enumerate(detections):
190
+ # Only take out the ground_truths that have the same
191
+ # training idx as detection
192
+ ground_truth_img = [
193
+ bbox for bbox in ground_truths if bbox[0] == detection[0]
194
+ ]
195
+
196
+ num_gts = len(ground_truth_img)
197
+ best_iou = 0
198
+
199
+ for idx, gt in enumerate(ground_truth_img):
200
+ iou = intersection_over_union(
201
+ torch.tensor(detection[3:]),
202
+ torch.tensor(gt[3:]),
203
+ box_format=box_format,
204
+ )
205
+
206
+ if iou > best_iou:
207
+ best_iou = iou
208
+ best_gt_idx = idx
209
+
210
+ if best_iou > iou_threshold:
211
+ # only detect ground truth detection once
212
+ if amount_bboxes[detection[0]][best_gt_idx] == 0:
213
+ # true positive and add this bounding box to seen
214
+ TP[detection_idx] = 1
215
+ amount_bboxes[detection[0]][best_gt_idx] = 1
216
+ else:
217
+ FP[detection_idx] = 1
218
+
219
+ # if IoU is lower, the detection is a false positive
220
+ else:
221
+ FP[detection_idx] = 1
222
+
223
+ TP_cumsum = torch.cumsum(TP, dim=0)
224
+ FP_cumsum = torch.cumsum(FP, dim=0)
225
+ recalls = TP_cumsum / (total_true_bboxes + epsilon)
226
+ precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
227
+ precisions = torch.cat((torch.tensor([1]), precisions))
228
+ recalls = torch.cat((torch.tensor([0]), recalls))
229
+ # torch.trapz for numerical integration
230
+ average_precisions.append(torch.trapz(precisions, recalls))
231
+
232
+ return sum(average_precisions) / len(average_precisions)
233
+
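+ # Illustrative usage sketch (not part of the original upload): a single image (train_idx 0)
+ # with one class-0 ground truth and one perfectly matching prediction yields mAP ~1.0.
+ #
+ #   preds = [[0, 0, 0.9, 10, 10, 50, 50]]
+ #   gts   = [[0, 0, 1.0, 10, 10, 50, 50]]
+ #   mean_average_precision(preds, gts, iou_threshold=0.5, box_format="corners", num_classes=1)
+ #   # -> tensor close to 1.0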
234
+
235
+ def plot_image(image, boxes):
236
+ """Plots predicted bounding boxes on the image"""
237
+ cmap = plt.get_cmap("tab20b")
238
+ class_labels = config.COCO_LABELS if config.DATASET=='COCO' else config.PASCAL_CLASSES
239
+ colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
240
+ im = np.array(image)
241
+ height, width, _ = im.shape
242
+
243
+ # Create figure and axes
244
+ fig, ax = plt.subplots(1)
245
+ # Display the image
246
+ ax.imshow(im)
247
+
248
+ # box[0] is x midpoint, box[2] is width
249
+ # box[1] is y midpoint, box[3] is height
250
+
251
+ # Create a Rectangle patch
252
+ for box in boxes:
253
+ assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
254
+ class_pred = box[0]
255
+ box = box[2:]
256
+ upper_left_x = box[0] - box[2] / 2
257
+ upper_left_y = box[1] - box[3] / 2
258
+ rect = patches.Rectangle(
259
+ (upper_left_x * width, upper_left_y * height),
260
+ box[2] * width,
261
+ box[3] * height,
262
+ linewidth=2,
263
+ edgecolor=colors[int(class_pred)],
264
+ facecolor="none",
265
+ )
266
+ # Add the patch to the Axes
267
+ ax.add_patch(rect)
268
+ plt.text(
269
+ upper_left_x * width,
270
+ upper_left_y * height,
271
+ s=class_labels[int(class_pred)],
272
+ color="white",
273
+ verticalalignment="top",
274
+ bbox={"color": colors[int(class_pred)], "pad": 0},
275
+ )
276
+
277
+ plt.show()
278
+
279
+
280
+ def get_evaluation_bboxes(
281
+ loader,
282
+ model,
283
+ iou_threshold,
284
+ anchors,
285
+ threshold,
286
+ box_format="midpoint",
287
+ device="cuda",
288
+ ):
289
+ # make sure model is in eval mode before collecting bboxes
290
+ model.eval()
291
+ train_idx = 0
292
+ all_pred_boxes = []
293
+ all_true_boxes = []
294
+ for batch_idx, (x, labels) in enumerate(tqdm(loader)):
295
+ x = x.to(device)
296
+
297
+ with torch.no_grad():
298
+ predictions = model(x)
299
+
300
+ batch_size = x.shape[0]
301
+ bboxes = [[] for _ in range(batch_size)]
302
+ for i in range(3):
303
+ S = predictions[i].shape[2]
304
+ anchor = torch.tensor([*anchors[i]]).to(device) * S
305
+ boxes_scale_i = cells_to_bboxes(
306
+ predictions[i], anchor, S=S, is_preds=True
307
+ )
308
+ for idx, box in enumerate(boxes_scale_i):
309
+ bboxes[idx] += box
310
+
311
+ # we just want one bbox for each label, not one for each scale
312
+ true_bboxes = cells_to_bboxes(
313
+ labels[2], anchor, S=S, is_preds=False
314
+ )
315
+
316
+ for idx in range(batch_size):
317
+ nms_boxes = non_max_suppression(
318
+ bboxes[idx],
319
+ iou_threshold=iou_threshold,
320
+ threshold=threshold,
321
+ box_format=box_format,
322
+ )
323
+
324
+ for nms_box in nms_boxes:
325
+ all_pred_boxes.append([train_idx] + nms_box)
326
+
327
+ for box in true_bboxes[idx]:
328
+ if box[1] > threshold:
329
+ all_true_boxes.append([train_idx] + box)
330
+
331
+ train_idx += 1
332
+
333
+ model.train()
334
+ return all_pred_boxes, all_true_boxes
335
+
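+ # Typical evaluation flow (sketch; the config names below are assumptions based on the
+ # accompanying config module and may differ in this repo):
+ #
+ #   pred_boxes, true_boxes = get_evaluation_bboxes(
+ #       test_loader, model, iou_threshold=config.NMS_IOU_THRESH,
+ #       anchors=config.ANCHORS, threshold=config.CONF_THRESHOLD,
+ #   )
+ #   mapval = mean_average_precision(
+ #       pred_boxes, true_boxes, iou_threshold=config.MAP_IOU_THRESH,
+ #       box_format="midpoint", num_classes=config.NUM_CLASSES,
+ #   )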
336
+
337
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
338
+ """
339
+ Scales the predictions coming from the model to
340
+ be relative to the entire image, so that they can later
341
+ be plotted or evaluated, for example.
342
+ INPUT:
343
+ predictions: tensor of size (N, 3, S, S, num_classes+5)
344
+ anchors: the anchors used for the predictions
345
+ S: the number of cells the image is divided into along the width (and height)
346
+ is_preds: whether the input is predictions or the true bounding boxes
347
+ OUTPUT:
348
+ converted_bboxes: the converted boxes, returned as a list of shape (N, num_anchors * S * S, 6) with class index,
349
+ object score, and bounding box coordinates for each entry
350
+ """
351
+ BATCH_SIZE = predictions.shape[0]
352
+ num_anchors = len(anchors)
353
+ box_predictions = predictions[..., 1:5]
354
+ if is_preds:
355
+ anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
356
+ box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
357
+ box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
358
+ scores = torch.sigmoid(predictions[..., 0:1])
359
+ best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
360
+ else:
361
+ scores = predictions[..., 0:1]
362
+ best_class = predictions[..., 5:6]
363
+
364
+ cell_indices = (
365
+ torch.arange(S)
366
+ .repeat(predictions.shape[0], 3, S, 1)
367
+ .unsqueeze(-1)
368
+ .to(predictions.device)
369
+ )
370
+ x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
371
+ y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
372
+ w_h = 1 / S * box_predictions[..., 2:4]
373
+ converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
374
+ return converted_bboxes.tolist()
375
+
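+ # Shape sketch (illustrative, not part of the original upload): for predictions of shape
+ # (N, 3, S, S, num_classes + 5) with 3 anchors per scale, this returns a list of N lists,
+ # each with 3 * S * S entries of the form [class_index, object_score, x, y, w, h],
+ # with coordinates scaled by 1/S so they are relative to the whole image.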
376
+ def check_class_accuracy(model, loader, threshold):
377
+ model.eval()
378
+ tot_class_preds, correct_class = 0, 0
379
+ tot_noobj, correct_noobj = 0, 0
380
+ tot_obj, correct_obj = 0, 0
381
+
382
+ for idx, (x, y) in enumerate(tqdm(loader)):
383
+ x = x.to(config.DEVICE)
384
+ with torch.no_grad():
385
+ out = model(x)
386
+
387
+ for i in range(3):
388
+ y[i] = y[i].to(config.DEVICE)
389
+ obj = y[i][..., 0] == 1 # in paper this is Iobj_i
390
+ noobj = y[i][..., 0] == 0 # in paper this is Inoobj_i
391
+
392
+ correct_class += torch.sum(
393
+ torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
394
+ )
395
+ tot_class_preds += torch.sum(obj)
396
+
397
+ obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
398
+ correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
399
+ tot_obj += torch.sum(obj)
400
+ correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
401
+ tot_noobj += torch.sum(noobj)
402
+
403
+ print(f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:2f}%")
404
+ print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:2f}%")
405
+ print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:2f}%")
406
+ model.train()
407
+
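+ # Usage sketch (config.CONF_THRESHOLD is an assumed config name):
+ #   check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)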
408
+
409
+ def get_mean_std(loader):
410
+ # var[X] = E[X**2] - E[X]**2
411
+ channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
412
+
413
+ for data, _ in tqdm(loader):
414
+ channels_sum += torch.mean(data, dim=[0, 2, 3])
415
+ channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
416
+ num_batches += 1
417
+
418
+ mean = channels_sum / num_batches
419
+ std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
420
+
421
+ return mean, std
422
+
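+ # Usage sketch: per-channel statistics for a normalization transform. Note that this
+ # averages per-batch means, so it is only approximate when the last batch is smaller.
+ #   mean, std = get_mean_std(train_loader)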
423
+
424
+ def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
425
+ print("=> Saving checkpoint")
426
+ checkpoint = {
427
+ "state_dict": model.state_dict(),
428
+ "optimizer": optimizer.state_dict(),
429
+ }
430
+ torch.save(checkpoint, filename)
431
+
432
+
433
+ def load_checkpoint(checkpoint_file, model, optimizer, lr):
434
+ print("=> Loading checkpoint")
435
+ checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
436
+ model.load_state_dict(checkpoint["state_dict"])
437
+ optimizer.load_state_dict(checkpoint["optimizer"])
438
+
439
+ # If we don't do this then it will just have learning rate of old checkpoint
440
+ # and it will lead to many hours of debugging \:
441
+ for param_group in optimizer.param_groups:
442
+ param_group["lr"] = lr
443
+
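+ # Usage sketch (config.CHECKPOINT_FILE and config.LEARNING_RATE are assumed config names):
+ #   save_checkpoint(model, optimizer, filename=config.CHECKPOINT_FILE)
+ #   load_checkpoint(config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE)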
444
+
445
+ def get_loaders(train_csv_path, test_csv_path):
446
+ from src.dataset import YOLODataset
447
+
448
+ IMAGE_SIZE = config.IMAGE_SIZE
449
+ train_dataset = YOLODataset(
450
+ train_csv_path,
451
+ transform=config.train_transforms,
452
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
453
+ img_dir=config.IMG_DIR,
454
+ label_dir=config.LABEL_DIR,
455
+ anchors=config.ANCHORS,
456
+ )
457
+ test_dataset = YOLODataset(
458
+ test_csv_path,
459
+ transform=config.test_transforms,
460
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
461
+ img_dir=config.IMG_DIR,
462
+ label_dir=config.LABEL_DIR,
463
+ anchors=config.ANCHORS,
464
+ )
465
+ train_loader = DataLoader(
466
+ dataset=train_dataset,
467
+ batch_size=config.BATCH_SIZE,
468
+ num_workers=config.NUM_WORKERS,
469
+ pin_memory=config.PIN_MEMORY,
470
+ shuffle=True,
471
+ drop_last=False,
472
+ )
473
+ test_loader = DataLoader(
474
+ dataset=test_dataset,
475
+ batch_size=config.BATCH_SIZE,
476
+ num_workers=config.NUM_WORKERS,
477
+ pin_memory=config.PIN_MEMORY,
478
+ shuffle=False,
479
+ drop_last=False,
480
+ )
481
+
482
+ train_eval_dataset = YOLODataset(
483
+ train_csv_path,
484
+ transform=config.test_transforms,
485
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
486
+ img_dir=config.IMG_DIR,
487
+ label_dir=config.LABEL_DIR,
488
+ anchors=config.ANCHORS,
489
+ )
490
+ train_eval_loader = DataLoader(
491
+ dataset=train_eval_dataset,
492
+ batch_size=config.BATCH_SIZE,
493
+ num_workers=config.NUM_WORKERS,
494
+ pin_memory=config.PIN_MEMORY,
495
+ shuffle=False,
496
+ drop_last=False,
497
+ )
498
+
499
+ return train_loader, test_loader, train_eval_loader
500
+
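+ # Usage sketch (the csv path construction is an assumption and may differ in this repo):
+ #   train_loader, test_loader, train_eval_loader = get_loaders(
+ #       train_csv_path=config.DATASET + "/train.csv", test_csv_path=config.DATASET + "/test.csv"
+ #   )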
501
+ def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
502
+ model.eval()
503
+ x, y = next(iter(loader))
504
+ x = x.to("cuda")
505
+ with torch.no_grad():
506
+ out = model(x)
507
+ bboxes = [[] for _ in range(x.shape[0])]
508
+ for i in range(3):
509
+ batch_size, A, S, _, _ = out[i].shape
510
+ anchor = anchors[i]
511
+ boxes_scale_i = cells_to_bboxes(
512
+ out[i], anchor, S=S, is_preds=True
513
+ )
514
+ for idx, box in enumerate(boxes_scale_i):
515
+ bboxes[idx] += box
516
+
517
+ model.train()
518
+
519
+ for i in range(batch_size//4):
520
+ nms_boxes = non_max_suppression(
521
+ bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
522
+ )
523
+ plot_image(x[i].permute(1,2,0).detach().cpu(), nms_boxes)
524
+
525
+
526
+
527
+ def seed_everything(seed=42):
528
+ os.environ['PYTHONHASHSEED'] = str(seed)
529
+ random.seed(seed)
530
+ np.random.seed(seed)
531
+ torch.manual_seed(seed)
532
+ torch.cuda.manual_seed(seed)
533
+ torch.cuda.manual_seed_all(seed)
534
+ torch.backends.cudnn.deterministic = True
535
+ torch.backends.cudnn.benchmark = False
536
+
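+ # Usage sketch: call once at startup, before building datasets and models, to make runs reproducible.
+ #   seed_everything(42)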
537
+
538
+ def clip_coords(boxes, img_shape):
539
+ # Clip xyxy bounding boxes to image shape (height, width)
540
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
541
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
542
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
543
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
544
+
545
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
546
+ # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
547
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
548
+ y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
549
+ y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
550
+ y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
551
+ y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
552
+ return y
553
+
554
+
555
+ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
556
+ # Convert normalized segments into pixel segments, shape (n,2)
557
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
558
+ y[..., 0] = w * x[..., 0] + padw # x in pixels
559
+ y[..., 1] = h * x[..., 1] + padh # y in pixels
560
+ return y
561
+
562
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
563
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
564
+ if clip:
565
+ clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
566
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
567
+ y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
568
+ y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
569
+ y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
570
+ y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
571
+ return y
572
+
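+ # Round-trip sketch (illustrative, assumes a 640x640 image):
+ #   b = np.array([[0.5, 0.5, 0.25, 0.25]])     # normalized (x, y, w, h)
+ #   xywhn2xyxy(b)                               # -> [[240., 240., 400., 400.]]
+ #   xyxy2xywhn(xywhn2xyxy(b))                   # -> back to [[0.5, 0.5, 0.25, 0.25]]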
573
+ def clip_boxes(boxes, shape):
574
+ # Clip boxes (xyxy) to image shape (height, width)
575
+ if isinstance(boxes, torch.Tensor): # faster individually
576
+ boxes[..., 0].clamp_(0, shape[1]) # x1
577
+ boxes[..., 1].clamp_(0, shape[0]) # y1
578
+ boxes[..., 2].clamp_(0, shape[1]) # x2
579
+ boxes[..., 3].clamp_(0, shape[0]) # y2
580
+ else: # np.array (faster grouped)
581
+ boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
582
+ boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2