vishnun committed on
Commit 15ad19e · 1 Parent(s): 2c050f0
Files changed (1)
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection
+ import torch
+
+ feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50')
+ dmodel = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
+
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Gradio I/O: the image to crop, the text prompt, and the cropped image + score outputs
+ i = gr.inputs.Image(type="pil")  # pass a PIL image so image.size and the PIL cropping below work
+ t = gr.inputs.Textbox(label="text prompt")
+ o1 = gr.outputs.Image()
+ o2 = gr.outputs.Textbox()
+
+ def extract_image(image, text, num=1):
+
+     # detect candidate objects with DETR (models are loaded once at module level)
+     inputs = feature_extractor(images=image, return_tensors="pt")
+     outputs = dmodel(**inputs)
+
+     # model predicts bounding boxes and corresponding COCO classes
+     logits = outputs.logits
+     bboxes = outputs.pred_boxes
+     probas = outputs.logits.softmax(-1)[0, :, :-1]  # drop the last logit, DETR's "no object" class
+
+     # keep confident detections and rescale the boxes to the original image size
+     keep = probas.max(-1).values > 0.96
+     outs = feature_extractor.post_process(outputs, torch.tensor(image.size[::-1]).unsqueeze(0))
+     bboxes_scaled = outs[0]['boxes'][keep].detach().numpy()
+     labels = outs[0]['labels'][keep].detach().numpy()
+     scores = outs[0]['scores'][keep].detach().numpy()
+
+     # crop every detected region out of the original image
+     images_list = []
+     for i, j in enumerate(bboxes_scaled):
+
+         xmin = int(j[0])
+         ymin = int(j[1])
+         xmax = int(j[2])
+         ymax = int(j[3])
+
+         im_arr = np.array(image)
+         roi = im_arr[ymin:ymax, xmin:xmax]
+         roi_im = Image.fromarray(roi)
+
+         images_list.append(roi_im)
+
+     # score every crop against the text prompt with CLIP
+     inputs = processor(text=[text], images=images_list, return_tensors="pt", padding=True)
+     outputs = model(**inputs)
+     logits_per_text = outputs.logits_per_text
+     probs = logits_per_text.softmax(-1)
+     l_idx = np.argsort(probs[-1].detach().numpy())[::-1][0:num]
+
+     # keep only the top-num crops, then sort them by CLIP score
+     final_ims = []
+     for i, j in enumerate(images_list):
+         if i in l_idx:
+             json_dict = {}
+             json_dict['image'] = images_list[i]
+             json_dict['score'] = probs[-1].detach().numpy()[i]
+             final_ims.append(json_dict)
+
+     fi = sorted(final_ims, key=lambda item: item.get("score"), reverse=True)
+     return fi[0]['image'], fi[0]['score']
+
+ title = "ClipnCrop"
+ description = "Extract sections of an image with a text prompt, using OpenAI's CLIP and Facebook's DETR implemented in Hugging Face Transformers"
+ examples = [['ex1.jpg'], ['ex2.jpg']]
+ article = "<p style='text-align: center'>"
+ gr.Interface(fn=extract_image, inputs=[i, t], outputs=[o1, o2], title=title, description=description, article=article, examples=examples, enable_queue=True).launch()