Ashoka74 commited on
Commit
6b7695c
1 Parent(s): 5c1cf77

Update inference_i2mv_sdxl.py

Browse files
Files changed (1) hide show
  1. inference_i2mv_sdxl.py +22 -21
inference_i2mv_sdxl.py CHANGED
@@ -70,19 +70,6 @@ def prepare_pipeline(
70
 
71
  return pipe
72
 
73
-
74
- # def remove_bg(image, net, transform, device):
75
- # image_size = image.size
76
- # input_images = transform(image).unsqueeze(0).to(device)
77
- # with torch.no_grad():
78
- # preds = net(input_images)[-1].sigmoid().cpu()
79
- # pred = preds[0].squeeze()
80
- # pred_pil = transforms.ToPILImage()(pred)
81
- # mask = pred_pil.resize(image_size)
82
- # image.putalpha(mask)
83
- # return image
84
-
85
-
86
  def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
87
  """
88
  Applies a pre-existing mask to an image to make the background transparent.
@@ -108,21 +95,35 @@ def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = No
108
 
109
  image.putalpha(mask)
110
  return image
111
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  def preprocess_image(image: Image.Image, height, width):
115
 
116
- # alpha = image[..., 3] > 0
117
  # alpha = image
118
 
119
- if image.mode in ("RGBA", "LA"):
120
- image = np.array(image)
121
- alpha = image[..., 3] # Extract the alpha channel
122
- elif image.mode in ("RGB"):
123
- image = np.array(image)
124
  # Create default alpha for non-alpha images
125
- alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255 # Create
126
  H, W = alpha.shape
127
  # get the bounding box of alpha
128
  y, x = np.where(alpha)
 
70
 
71
  return pipe
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
74
  """
75
  Applies a pre-existing mask to an image to make the background transparent.
 
95
 
96
  image.putalpha(mask)
97
  return image
98
+
99
+
100
+ def remove_bg(image, net, transform, device):
101
+ image_size = image.size
102
+ input_images = transform(image).unsqueeze(0).to(device)
103
+ with torch.no_grad():
104
+ preds = net(input_images)[-1].sigmoid().cpu()
105
+ pred = preds[0].squeeze()
106
+ pred_pil = transforms.ToPILImage()(pred)
107
+ mask = pred_pil.resize(image_size)
108
+ image.putalpha(mask)
109
+ return image
110
+
111
+
112
+
113
 
114
 
115
  def preprocess_image(image: Image.Image, height, width):
116
 
117
+ alpha = image[..., 3] > 0
118
  # alpha = image
119
 
120
+ #if image.mode in ("RGBA", "LA"):
121
+ # image = np.array(image)
122
+ # alpha = image[..., 3] # Extract the alpha channel
123
+ #elif image.mode in ("RGB"):
124
+ # image = np.array(image)
125
  # Create default alpha for non-alpha images
126
+ # alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255 # Create
127
  H, W = alpha.shape
128
  # get the bounding box of alpha
129
  y, x = np.where(alpha)