Change prompt from a bounding box to point and click

#2
Files changed (1) hide show
  1. app.py +27 -32
app.py CHANGED
@@ -30,25 +30,23 @@ def load_image(file_path):
30
  return img, H, W
31
 
32
  @torch.no_grad()
33
- def medsam_inference(medsam_model, img_embed, box_1024, H, W):
34
- box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
35
- if len(box_torch.shape) == 2:
36
- box_torch = box_torch[:, None, :] # (B, 1, 4)
37
 
38
- box_torch=box_torch.reshape(1,4)
39
  sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
40
- points=None,
41
- boxes=box_torch,
42
  masks=None,
43
  )
44
 
45
  low_res_logits, _ = medsam_model.mask_decoder(
46
- image_embeddings=img_embed, # (B, 256, 64, 64)
47
- image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
48
- sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
49
- dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
50
  multimask_output=False,
51
- )
52
 
53
  low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
54
 
@@ -58,15 +56,16 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
58
  mode="bilinear",
59
  align_corners=False,
60
  ) # (1, 1, gt.shape)
61
- low_res_pred = low_res_pred.squeeze().cpu().numpy() # (256, 256)
62
  medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
63
  return medsam_seg
64
 
65
  # Function for visualizing images with masks
66
- def visualize(image, mask, box):
67
  fig, ax = plt.subplots(1, 2, figsize=(10, 5))
68
  ax[0].imshow(image, cmap='gray')
69
- ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor="red", facecolor="none"))
 
70
  ax[1].imshow(image, cmap='gray')
71
  ax[1].imshow(mask, alpha=0.5, cmap="jet")
72
  plt.tight_layout()
@@ -78,19 +77,18 @@ def visualize(image, mask, box):
78
  buf.seek(0)
79
  pil_img = Image.open(buf)
80
 
81
- return pil_img
 
82
  # Main function for Gradio app
83
  def process_images(img_dict):
84
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
85
 
86
  # Load and preprocess image
87
- print(img_dict)
88
  img = img_dict['image']
89
- points = img_dict['points'][0] # Accessing the first (and possibly only) set of points
90
- if len(points) >= 6:
91
- x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
92
- else:
93
- raise ValueError("Insufficient data for bounding box coordinates.")
94
  image, H, W = img, img.shape[0], img.shape[1]
95
  if len(image.shape) == 2:
96
  image = np.repeat(image[:, :, None], 3, axis=-1)
@@ -106,20 +104,17 @@ def process_images(img_dict):
106
  medsam_model = medsam_model.to(device)
107
  medsam_model.eval()
108
 
109
- # Generate image embedding
110
- with torch.no_grad():
111
- img_embed = medsam_model.image_encoder(image_tensor)
112
-
113
- # Calculate resized box coordinates
114
- scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
115
- box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
116
 
117
  # Perform inference
118
- mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
119
 
120
  # Visualization
121
- visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
122
  return visualization
 
123
  # Set up Gradio interface
124
  iface = gr.Interface(
125
  fn=process_images,
@@ -130,7 +125,7 @@ iface = gr.Interface(
130
  gr.Image(type="pil", label="Processed Image")
131
  ],
132
  title="ROI Selection with MEDSAM",
133
- description="Upload an image (including NRRD files) and select regions of interest for processing."
134
  )
135
 
136
  # Launch the interface
 
30
  return img, H, W
31
 
32
  @torch.no_grad()
33
+ def medsam_inference(medsam_model, img_embed, points_1024, H, W):
34
+ points_torch = torch.as_tensor(points_1024, dtype=torch.float, device=img_embed.device)
35
+ points_torch = points_torch.reshape(1, -1, 2) # (1, N, 2)
 
36
 
 
37
  sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
38
+ points=points_torch,
39
+ boxes=None,
40
  masks=None,
41
  )
42
 
43
  low_res_logits, _ = medsam_model.mask_decoder(
44
+ image_embeddings=img_embed, # (B, 256, 64, 64)
45
+ image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
46
+ sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
47
+ dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
48
  multimask_output=False,
49
+ )
50
 
51
  low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
52
 
 
56
  mode="bilinear",
57
  align_corners=False,
58
  ) # (1, 1, gt.shape)
59
+ low_res_pred = low_res_pred.squeeze().cpu().numpy() # (H, W)
60
  medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
61
  return medsam_seg
62
 
63
  # Function for visualizing images with masks
64
+ def visualize(image, mask, points):
65
  fig, ax = plt.subplots(1, 2, figsize=(10, 5))
66
  ax[0].imshow(image, cmap='gray')
67
+ for point in points:
68
+ ax[0].plot(point[0], point[1], 'ro') # Mark points on the image
69
  ax[1].imshow(image, cmap='gray')
70
  ax[1].imshow(mask, alpha=0.5, cmap="jet")
71
  plt.tight_layout()
 
77
  buf.seek(0)
78
  pil_img = Image.open(buf)
79
 
80
+ return pil_img
81
+
82
  # Main function for Gradio app
83
  def process_images(img_dict):
84
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
85
 
86
  # Load and preprocess image
 
87
  img = img_dict['image']
88
+ points = img_dict['points']
89
+ if len(points) == 0:
90
+ raise ValueError("No points provided.")
91
+
 
92
  image, H, W = img, img.shape[0], img.shape[1]
93
  if len(image.shape) == 2:
94
  image = np.repeat(image[:, :, None], 3, axis=-1)
 
104
  medsam_model = medsam_model.to(device)
105
  medsam_model.eval()
106
 
107
+ # Calculate resized point coordinates
108
+ scale_factors = np.array([1024 / W, 1024 / H])
109
+ points_1024 = np.array(points) * scale_factors
 
 
 
 
110
 
111
  # Perform inference
112
+ mask = medsam_inference(medsam_model, img_embed, points_1024, H, W)
113
 
114
  # Visualization
115
+ visualization = visualize(image, mask, points)
116
  return visualization
117
+
118
  # Set up Gradio interface
119
  iface = gr.Interface(
120
  fn=process_images,
 
125
  gr.Image(type="pil", label="Processed Image")
126
  ],
127
  title="ROI Selection with MEDSAM",
128
+ description="Upload an image (including NRRD files) and select points of interest for processing."
129
  )
130
 
131
  # Launch the interface