Clement Vachet commited on
Commit
ffdfdcd
·
1 Parent(s): 9b658e7

Add detection python files

Browse files
Files changed (3) hide show
  1. detect_pipeline.py +6 -0
  2. detect_torch.py +115 -0
  3. detect_transformers.py +26 -0
detect_pipeline.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+ detector = pipeline(model="facebook/detr-resnet-50", revision="no_timm")
4
+ result = detector("http://images.cocodataset.org/val2017/000000039769.jpg")
5
+ print(result)
6
+ # x, y are expressed relative to the top left hand corner.
detect_torch.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Main file
2
+
3
+ from PIL import Image
4
+ import requests
5
+ import matplotlib.pyplot as plt
6
+
7
+
8
+ import torch
9
+ # from torch import nn
10
+ # from torchvision.models import resnet50
11
+ import torchvision.transforms as T
12
+ torch.set_grad_enabled(False);
13
+
14
+ # COCO classes
15
+ CLASSES = [
16
+ 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
17
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
18
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
19
+ 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
20
+ 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
21
+ 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
22
+ 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
23
+ 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
24
+ 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
25
+ 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
26
+ 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
27
+ 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
28
+ 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
29
+ 'toothbrush'
30
+ ]
31
+
32
+ # colors for visualization
33
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
34
+ [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
35
+
36
+ # standard PyTorch mean-std input image normalization
37
+ transform = T.Compose([
38
+ T.Resize(800),
39
+ T.ToTensor(),
40
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
41
+ ])
42
+
43
+ # for output bounding box post-processing
44
+ # Convert center of bounding box to relative image coordinates
45
+ # from (cx, cy, w, h) to (x0, y0, x1, y1)
46
+ def box_cxcywh_to_xyxy(x):
47
+ x_c, y_c, w, h = x.unbind(1)
48
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
49
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
50
+ return torch.stack(b, dim=1)
51
+
52
+ # convert predictions to absolute image coordinates
53
+ def rescale_bboxes(out_bbox, size):
54
+ img_w, img_h = size
55
+ b = box_cxcywh_to_xyxy(out_bbox)
56
+ b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
57
+ return b
58
+
59
+ def plot_results(pil_img, prob, boxes):
60
+ plt.figure(figsize=(8,5))
61
+ plt.imshow(pil_img)
62
+ ax = plt.gca()
63
+ colors = COLORS * 100
64
+ for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
65
+ ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
66
+ fill=False, color=c, linewidth=3))
67
+ cl = p.argmax()
68
+ text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
69
+ ax.text(xmin, ymin, text, fontsize=15,
70
+ bbox=dict(facecolor='yellow', alpha=0.5))
71
+ plt.axis('off')
72
+ plt.show()
73
+
74
+
75
+ def detect(im, model, transform):
76
+ # mean-std normalize the input image (batch-size: 1)
77
+ img = transform(im).unsqueeze(0)
78
+
79
+ # demo model only support by default images with aspect ratio between 0.5 and 2
80
+ # if you want to use images with an aspect ratio outside this range
81
+ # rescale your image so that the maximum size is at most 1333 for best results
82
+ assert img.shape[-2] <= 1600 and img.shape[
83
+ -1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'
84
+
85
+ # propagate through the model
86
+ outputs = model(img)
87
+
88
+ # keep only predictions with 0.9+ confidence
89
+ probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
90
+ keep = probas.max(-1).values > 0.9
91
+
92
+ # convert boxes from [0; 1] to image scales
93
+ bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
94
+ return probas[keep], bboxes_scaled
95
+
96
+ def load_model():
97
+ model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
98
+ model.eval();
99
+ return model
100
+
101
+ def main():
102
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
103
+ im = Image.open(requests.get(url, stream=True).raw)
104
+ model = load_model()
105
+ scores, boxes = detect(im, model, transform)
106
+ print('len(scores)',len(scores))
107
+ print('scores[0].shape', scores[0].shape)
108
+ print('scores', scores)
109
+ print('len(boxes)',len(boxes))
110
+ print('boxes',boxes)
111
+ plot_results(im, scores, boxes)
112
+
113
+ if __name__ == "__main__":
114
+ main()
115
+
detect_transformers.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
+ from PIL import Image
4
+ import requests
5
+
6
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
7
+ image = Image.open(requests.get(url, stream=True).raw)
8
+
9
+ # you can specify the revision tag if you don't want the timm dependency
10
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
11
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
12
+
13
+ inputs = processor(images=image, return_tensors="pt")
14
+ outputs = model(**inputs)
15
+
16
+ # convert outputs (bounding boxes and class logits) to COCO API
17
+ # let's only keep detections with score > 0.9
18
+ target_sizes = torch.tensor([image.size[::-1]])
19
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
20
+
21
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
22
+ box = [round(i, 2) for i in box.tolist()]
23
+ print(
24
+ f"Detected {model.config.id2label[label.item()]} with confidence "
25
+ f"{round(score.item(), 3)} at location {box}"
26
+ )