shengqiangShi commited on
Commit
4dfaafd
·
1 Parent(s): 837d4ca

Application file

Browse files
Files changed (1) hide show
  1. app.py +16 -16
app.py CHANGED
@@ -32,15 +32,15 @@ if torch.cuda.is_available():
32
  else:
33
  device = torch.device("cpu")
34
 
35
- model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
36
- processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
37
- model_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
38
- processor_sam = SamProcessor.from_pretrained("facebook/sam-vit-huge")
39
 
40
- # model = Owlv2ForObjectDetection.from_pretrained("owlv2-base-patch16-ensemble").to(device)
41
- # processor = Owlv2Processor.from_pretrained("owlv2-base-patch16-ensemble")
42
- # model_sam = SamModel.from_pretrained("SAM/sam-vit-huge").to(device)
43
- # processor_sam = SamProcessor.from_pretrained("SAM/sam-vit-huge")
44
 
45
 
46
  @spaces.GPU
@@ -75,13 +75,13 @@ def query_image(img, text_queries, score_threshold=0.5):
75
  return sam_image,result_labels
76
 
77
 
78
- def generate_image_with_sam(img, boxes):
79
  img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
80
  inputs = processor_sam(img_pil, return_tensors="pt").to(device)
81
 
82
  image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])
83
 
84
- inputs = processor_sam(img_pil, input_boxes=[boxes], return_tensors="pt").to(device)
85
  inputs["input_boxes"].shape
86
  inputs.pop("pixel_values", None)
87
  inputs.update({"image_embeddings": image_embeddings})
@@ -101,16 +101,16 @@ Split anythings
101
  """
102
  demo = gr.Interface(
103
  fn=query_image,
104
- inputs=[gr.Image(), gr.Textbox(label="Query Text"), gr.Slider(0, 1, value=0.5, label="Score Threshold")],
105
  outputs=gr.AnnotatedImage(),
106
  title="Zero-Shot Object Detection SV3",
107
  description="This interface demonstrates object detection using zero-shot object detection and SAM for image segmentation.",
108
- examples=[
109
- ["images/purple cell.png", "purple cells", 0.05],
110
- ["images/dark_cell.png", "gray cells", 0.1],
111
- ["images/animals.png", "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse", 0.35],
112
 
113
- ],
114
  )
115
 
116
  demo.launch()
 
32
  else:
33
  device = torch.device("cpu")
34
 
35
+ # model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
36
+ # processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
37
+ # model_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
38
+ # processor_sam = SamProcessor.from_pretrained("facebook/sam-vit-huge")
39
 
40
+ model = Owlv2ForObjectDetection.from_pretrained("owlv2-base-patch16-ensemble").to(device)
41
+ processor = Owlv2Processor.from_pretrained("owlv2-base-patch16-ensemble")
42
+ model_sam = SamModel.from_pretrained("SAM/sam-vit-huge").to(device)
43
+ processor_sam = SamProcessor.from_pretrained("SAM/sam-vit-huge")
44
 
45
 
46
  @spaces.GPU
 
75
  return sam_image,result_labels
76
 
77
 
78
+ def generate_image_with_sam(img, input_boxes):
79
  img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
80
  inputs = processor_sam(img_pil, return_tensors="pt").to(device)
81
 
82
  image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])
83
 
84
+ inputs = processor_sam(img_pil, input_boxes=[input_boxes], return_tensors="pt").to(device)
85
  inputs["input_boxes"].shape
86
  inputs.pop("pixel_values", None)
87
  inputs.update({"image_embeddings": image_embeddings})
 
101
  """
102
  demo = gr.Interface(
103
  fn=query_image,
104
+ inputs=[gr.Image(), gr.Textbox(label="Query Text"), gr.Slider(0, 1, value=0.1, label="Score Threshold")],
105
  outputs=gr.AnnotatedImage(),
106
  title="Zero-Shot Object Detection SV3",
107
  description="This interface demonstrates object detection using zero-shot object detection and SAM for image segmentation.",
108
+ # examples=[
109
+ # ["images/purple cell.png", "purple cells", 0.05],
110
+ # ["images/dark_cell.png", "gray cells", 0.1],
111
+ # ["images/animals.png", "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse", 0.35],
112
 
113
+ # ],
114
  )
115
 
116
  demo.launch()