flaviagiammarino commited on
Commit
6d43de4
1 Parent(s): bfaa649

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -23,11 +23,11 @@ The training was performed for 100 epochs with a batch size of 160 using the Ada
23
 
24
  ```python
25
  import requests
26
- import torch
27
  import numpy as np
28
  import matplotlib.pyplot as plt
29
  from PIL import Image
30
- from transformers import SamModel, SamProcessor
 
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
@@ -40,7 +40,7 @@ input_boxes = [95., 255., 190., 350.]
40
 
41
  inputs = processor(raw_image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
42
  outputs = model(**inputs, multimask_output=False)
43
- masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
44
 
45
  def show_mask(mask, ax, random_color):
46
  if random_color:
@@ -62,7 +62,7 @@ show_box(input_boxes, ax[0])
62
  ax[0].set_title("Input Image and Bounding Box")
63
  ax[0].axis("off")
64
  ax[1].imshow(np.array(raw_image))
65
- show_mask(masks[0], ax=ax[1], random_color=False)
66
  show_box(input_boxes, ax[1])
67
  ax[1].set_title("MedSAM Segmentation")
68
  ax[1].axis("off")
 
23
 
24
  ```python
25
  import requests
 
26
  import numpy as np
27
  import matplotlib.pyplot as plt
28
  from PIL import Image
29
+ from transformers import SamModel, SamProcessor, SamImageProcessor
30
+ import torch
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
 
40
 
41
  inputs = processor(raw_image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
42
  outputs = model(**inputs, multimask_output=False)
43
+ probs = processor.image_processor.post_process_masks(outputs.pred_masks.sigmoid().cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu(), binarize=False)
44
 
45
  def show_mask(mask, ax, random_color):
46
  if random_color:
 
62
  ax[0].set_title("Input Image and Bounding Box")
63
  ax[0].axis("off")
64
  ax[1].imshow(np.array(raw_image))
65
+ show_mask(mask=probs[0] > 0.5, ax=ax[1], random_color=False)
66
  show_box(input_boxes, ax[1])
67
  ax[1].set_title("MedSAM Segmentation")
68
  ax[1].axis("off")