sagar007 committed on
Commit
9a34a8b
1 Parent(s): f86e74e

Update app.py

Files changed (1)
  1. app.py +73 -83
app.py CHANGED
@@ -2,50 +2,74 @@ import gradio as gr
  import torch
  import cv2
  import numpy as np
- from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
  from PIL import Image
- from scipy.ndimage import label, center_of_mass
+ import matplotlib.pyplot as plt
+ import io
+ from ultralytics import FastSAM
+ from ultralytics.models.fastsam import FastSAMPrompt
 
  # Set up device
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
- # Load SAM model and processor
- sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
- sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
-
- # Load BLIP model and processor for image-to-text
- blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
-
- def process_mask(mask, target_size):
-     if mask.ndim > 2:
-         mask = mask.squeeze()
-     if mask.ndim > 2:
-         mask = mask[0]
-     mask = (mask > 0.5).astype(np.uint8) * 255
-     mask_image = Image.fromarray(mask)
-     mask_image = mask_image.resize(target_size, Image.NEAREST)
-     return np.array(mask_image) > 0
-
- def is_cat_like(mask, image_area):
-     labeled, num_features = label(mask)
-     if num_features == 0:
-         return False
-
-     largest_component = (labeled == (np.bincount(labeled.flatten())[1:].argmax() + 1))
-     area = largest_component.sum()
-
-     # Check if the area is reasonable for a cat (between 5% and 30% of image)
-     if not (0.05 * image_area < area < 0.3 * image_area):
-         return False
-
-     # Check if the shape is roughly elliptical
-     cy, cx = center_of_mass(largest_component)
-     major_axis = max(largest_component.shape)
-     minor_axis = min(largest_component.shape)
-     aspect_ratio = major_axis / minor_axis
-
-     return 1.5 < aspect_ratio < 3 # Most cats have an aspect ratio in this range
+ # Load FastSAM model
+ model = FastSAM("FastSAM-s.pt") # or FastSAM-x.pt
+
+ def fig2img(fig):
+     buf = io.BytesIO()
+     fig.savefig(buf)
+     buf.seek(0)
+     img = Image.open(buf)
+     return img
+
+ def plot(annotations, prompt_process, mask_random_color=True, better_quality=True, retina=True, with_contours=True):
+     for ann in annotations:
+         image = ann.orig_img[..., ::-1] # BGR to RGB
+         original_h, original_w = ann.orig_shape
+         fig = plt.figure(figsize=(original_w / 100, original_h / 100))
+         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+         plt.margins(0, 0)
+         plt.gca().xaxis.set_major_locator(plt.NullLocator())
+         plt.gca().yaxis.set_major_locator(plt.NullLocator())
+         plt.imshow(image)
+
+         if ann.masks is not None:
+             masks = ann.masks.data
+             if better_quality:
+                 if isinstance(masks[0], torch.Tensor):
+                     masks = np.array(masks.cpu())
+                 for i, mask in enumerate(masks):
+                     mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+                     masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+
+             prompt_process.fast_show_mask(
+                 masks,
+                 plt.gca(),
+                 random_color=mask_random_color,
+                 bbox=None,
+                 points=None,
+                 pointlabel=None,
+                 retinamask=retina,
+                 target_height=original_h,
+                 target_width=original_w,
+             )
+
+             if with_contours:
+                 contour_all = []
+                 temp = np.zeros((original_h, original_w, 1))
+                 for i, mask in enumerate(masks):
+                     mask = mask.astype(np.uint8)
+                     if not retina:
+                         mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
+                     contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                     contour_all.extend(iter(contours))
+                 cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+                 color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
+                 contour_mask = temp / 255 * color.reshape(1, 1, -1)
+                 plt.imshow(contour_mask)
+
+         plt.axis("off")
+         plt.close()
+         return fig2img(fig)
 
  def segment_image(input_image, object_name):
      try:
@@ -53,55 +77,21 @@ def segment_image(input_image, object_name):
              return None, "Please upload an image before submitting."
 
          input_image = Image.fromarray(input_image).convert("RGB")
-         original_size = input_image.size
-         if not original_size or 0 in original_size:
-             return None, "Invalid image size. Please upload a different image."
-
-         # Generate detailed image caption
-         blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
-         caption = blip_model.generate(**blip_inputs, max_length=50)
-         caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
-
-         # Process the image with SAM
-         sam_inputs = sam_processor(input_image, return_tensors="pt").to(device)
-
-         # Generate masks
-         with torch.no_grad():
-             sam_outputs = sam_model(**sam_inputs)
-
-         # Post-process masks
-         masks = sam_processor.image_processor.post_process_masks(
-             sam_outputs.pred_masks.cpu(),
-             sam_inputs["original_sizes"].cpu(),
-             sam_inputs["reshaped_input_sizes"].cpu()
-         )
-
-         # Find the mask that best matches the specified object
-         best_mask = None
-         best_score = -1
-         image_area = original_size[0] * original_size[1]
-
-         cat_related_words = ['cat', 'kitten', 'feline', 'tabby', 'kitty']
-         caption_contains_cat = any(word in caption_text.lower() for word in cat_related_words)
-
-         for mask in masks[0]:
-             mask_binary = mask.numpy() > 0.5
-             if is_cat_like(mask_binary, image_area) and caption_contains_cat:
-                 mask_area = mask_binary.sum()
-                 if mask_area > best_score:
-                     best_mask = mask_binary
-                     best_score = mask_area
-
-         if best_mask is None:
-             return input_image, f"Could not find a suitable '{object_name}' in the image."
-
-         combined_mask = process_mask(best_mask, original_size)
-
-         # Overlay the mask on the original image
-         result_image = np.array(input_image)
-         mask_rgb = np.zeros_like(result_image)
-         mask_rgb[combined_mask] = [255, 0, 0] # Red color for the mask
-         result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
+
+         # Run FastSAM model
+         everything_results = model(input_image, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
+
+         # Prepare a Prompt Process object
+         prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
+
+         # Use text prompt to segment the specified object
+         results = prompt_process.text_prompt(text=object_name)
+
+         if not results:
+             return input_image, f"Could not find '{object_name}' in the image."
+
+         # Plot the results
+         result_image = plot(annotations=results, prompt_process=prompt_process)
 
          return result_image, f"Segmented '{object_name}' in the image."
@@ -116,11 +106,11 @@ iface = gr.Interface(
          gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)")
      ],
      outputs=[
-         gr.Image(type="numpy", label="Segmented Image"),
+         gr.Image(type="pil", label="Segmented Image"),
          gr.Textbox(label="Status")
      ],
-     title="Segment Anything Model (SAM) with Object Specification",
-     description="Upload an image and specify an object to segment."
+     title="FastSAM Segmentation with Object Specification",
+     description="Upload an image and specify an object to segment using FastSAM."
  )
 
  # Launch the interface
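
For context, here is a minimal sketch of the new FastSAM text-prompt flow outside Gradio. It assumes an ultralytics release that still ships `FastSAMPrompt` (text prompting additionally requires the CLIP package) and that the `FastSAM-s.pt` weights are available locally; the image path is illustrative, not part of the commit:

```python
from PIL import Image
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

# Load the small FastSAM checkpoint, as app.py does
model = FastSAM("FastSAM-s.pt")

# Segment everything once, then filter the candidate masks with a free-text prompt
image = Image.open("cat.jpg").convert("RGB")  # illustrative input image
everything_results = model(image, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
prompt_process = FastSAMPrompt(image, everything_results, device="cpu")
results = prompt_process.text_prompt(text="cat")  # same call the updated app uses

if results:
    print(f"Found {len(results)} annotation(s) for the prompt")
else:
    print("No mask matched the text prompt")
```

The returned annotations are ultralytics Results objects; in app.py they are rendered to a PIL image for the Gradio output via the matplotlib-based `plot()` helper added in this commit.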