Ashoka74 commited on
Commit
917c9b5
1 Parent(s): 3d304ce

Update inference_i2mv_sdxl.py

Browse files
Files changed (1) hide show
  1. inference_i2mv_sdxl.py +37 -8
inference_i2mv_sdxl.py CHANGED
@@ -71,16 +71,44 @@ def prepare_pipeline(
71
  return pipe
72
 
73
 
74
- def remove_bg(image, net, transform, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  image_size = image.size
76
- input_images = transform(image).unsqueeze(0).to(device)
77
- with torch.no_grad():
78
- preds = net(input_images)[-1].sigmoid().cpu()
79
- pred = preds[0].squeeze()
80
- pred_pil = transforms.ToPILImage()(pred)
81
- mask = pred_pil.resize(image_size)
82
  image.putalpha(mask)
83
  return image
 
84
 
85
 
86
  def preprocess_image(image: Image.Image, height, width):
@@ -150,7 +178,8 @@ def run_pipeline(
150
  # Prepare image
151
  reference_image = Image.open(image) if isinstance(image, str) else image
152
  if remove_bg_fn is not None:
153
- reference_image = remove_bg_fn(reference_image)
 
154
  reference_image = preprocess_image(reference_image, height, width)
155
  elif reference_image.mode == "RGBA":
156
  reference_image = preprocess_image(reference_image, height, width)
 
71
  return pipe
72
 
73
 
74
+ # def remove_bg(image, net, transform, device):
75
+ # image_size = image.size
76
+ # input_images = transform(image).unsqueeze(0).to(device)
77
+ # with torch.no_grad():
78
+ # preds = net(input_images)[-1].sigmoid().cpu()
79
+ # pred = preds[0].squeeze()
80
+ # pred_pil = transforms.ToPILImage()(pred)
81
+ # mask = pred_pil.resize(image_size)
82
+ # image.putalpha(mask)
83
+ # return image
84
+
85
+
86
+ def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
87
+ """
88
+ Applies a pre-existing mask to an image to make the background transparent.
89
+
90
+ Args:
91
+ image (PIL.Image.Image): The input image.
92
+ net: Pre-trained neural network (not used but kept for compatibility).
93
+ transform: Image transformation object (not used but kept for compatibility).
94
+ device: Device used for inference (not used but kept for compatibility).
95
+ mask (PIL.Image.Image, optional): The mask to use. Should be the same size
96
+ as the input image, with values between 0 and 255 (or 0-1).
97
+ If None, will return image with no changes.
98
+
99
+ Returns:
100
+ PIL.Image.Image: The modified image with transparent background.
101
+ """
102
+ if mask is None:
103
+ return image
104
+
105
  image_size = image.size
106
+ if mask.size != image_size:
107
+ mask = mask.resize(image_size) # Resizing the mask if it is not the same size as image
108
+
 
 
 
109
  image.putalpha(mask)
110
  return image
111
+
112
 
113
 
114
  def preprocess_image(image: Image.Image, height, width):
 
178
  # Prepare image
179
  reference_image = Image.open(image) if isinstance(image, str) else image
180
  if remove_bg_fn is not None:
181
+ # reference_image = remove_bg_fn(reference_image)
182
+
183
  reference_image = preprocess_image(reference_image, height, width)
184
  elif reference_image.mode == "RGBA":
185
  reference_image = preprocess_image(reference_image, height, width)