Harshithtd commited on
Commit
6f359a3
1 Parent(s): cbc8634

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import SamModel, SamProcessor
6
+ from gradio_image_prompter import ImagePrompter
7
+ import spaces
8
+
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
11
+ slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
12
+
13
+ @spaces.GPU
14
+ def sam_box_inference(image, x_min, y_min, x_max, y_max):
15
+ inputs = slimsam_processor(
16
+ Image.fromarray(image),
17
+ input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
18
+ return_tensors="pt"
19
+ ).to(device)
20
+
21
+ with torch.no_grad():
22
+ outputs = slimsam_model(**inputs)
23
+
24
+ mask = slimsam_processor.image_processor.post_process_masks(
25
+ outputs.pred_masks.cpu(),
26
+ inputs["original_sizes"].cpu(),
27
+ inputs["reshaped_input_sizes"].cpu()
28
+ )[0][0][0].numpy()
29
+ mask = mask[np.newaxis, ...]
30
+ print(mask)
31
+ print(mask.shape)
32
+ return [(mask, "mask")]
33
+
34
+ def infer_box(prompts):
35
+ image = prompts["image"]
36
+ if image is None:
37
+ gr.Error("Please upload an image and draw a box before submitting.")
38
+ points = prompts["points"][0]
39
+ if points is None:
40
+ gr.Error("Please draw a box before submitting.")
41
+ print(points)
42
+ return [(image, sam_box_inference(image, points[0], points[1], points[3], points[4]))]
43
+
44
+ with gr.Blocks(title="SlimSAM Box Prompt") as demo:
45
+ gr.Markdown("# SlimSAM Box Prompt")
46
+ gr.Markdown("In this demo, you can upload an image and draw a box for SlimSAM to process.")
47
+
48
+ with gr.Row():
49
+ with gr.Column():
50
+ im = ImagePrompter()
51
+ btn = gr.Button("Submit")
52
+ with gr.Column():
53
+ output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
54
+
55
+ btn.click(infer_box, inputs=im, outputs=[output_box_slimsam])
56
+
57
+ demo.launch(debug=True)