aryswisnu commited on
Commit
28c2469
·
1 Parent(s): cfe8e4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
2
  import gradio as gr
3
- from PIL import Image, ImageFilter
4
  import torch
5
  import matplotlib.pyplot as plt
6
  import torch
@@ -44,28 +44,33 @@ def get_masks(prompts, img, threhsold):
44
  return masks
45
 
46
 
47
- def extract_image(img, pos_prompts, neg_prompts, threshold, blur_radius=5):
48
  positive_masks = get_masks(pos_prompts, img, threshold)
49
  negative_masks = get_masks(neg_prompts, img, threshold)
50
 
51
- # combine masks into one masks, logic OR
52
  pos_mask = np.any(np.stack(positive_masks), axis=0)
53
  neg_mask = np.any(np.stack(negative_masks), axis=0)
54
  final_mask = pos_mask & ~neg_mask
55
 
56
- # apply Gaussian blur for feathering
57
- final_mask_img = Image.fromarray((final_mask * 255).astype(np.uint8), "L")
58
- final_mask_img = final_mask_img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
59
- final_mask = np.array(final_mask_img) / 255
60
- final_mask = final_mask > threshold
61
 
62
- # extract the final image
63
- final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
64
- inverse_mask = np.invert(final_mask)
65
- output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
66
- output_image.paste(img, mask=final_mask)
67
 
68
- return output_image, final_mask, inverse_mask
 
 
 
 
 
 
 
69
 
70
 
71
 
 
1
  from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
2
  import gradio as gr
3
+ from PIL import Image
4
  import torch
5
  import matplotlib.pyplot as plt
6
  import torch
 
44
  return masks
45
 
46
 
47
+ def extract_image(img, pos_prompts, neg_prompts, threshold, alpha_value=0.5):
48
  positive_masks = get_masks(pos_prompts, img, threshold)
49
  negative_masks = get_masks(neg_prompts, img, threshold)
50
 
51
+ # combine masks into one mask, logic OR
52
  pos_mask = np.any(np.stack(positive_masks), axis=0)
53
  neg_mask = np.any(np.stack(negative_masks), axis=0)
54
  final_mask = pos_mask & ~neg_mask
55
 
56
+ # threshold the mask
57
+ bmask = final_mask > threshold
58
+ # zero out values below the threshold
59
+ final_mask[final_mask < threshold] = 0
 
60
 
61
+ # convert PIL image to RGBA numpy array
62
+ img_np = np.array(img.convert("RGBA"))
63
+ # create an empty RGBA image with the same size
64
+ output_image = np.zeros_like(img_np)
 
65
 
66
+ # apply the final_mask as alpha channel on the output image
67
+ output_image[:, :, :3] = img_np[:, :, :3]
68
+ output_image[:, :, 3] = (final_mask * 255 * alpha_value).astype(np.uint8)
69
+
70
+ # convert the output_image back to a PIL.Image object
71
+ output_image = Image.fromarray(output_image, "RGBA")
72
+
73
+ return output_image, final_mask, bmask
74
 
75
 
76