Spaces:
Running
on
Zero
Running
on
Zero
Update inference_i2mv_sdxl.py
Browse files- inference_i2mv_sdxl.py +37 -8
inference_i2mv_sdxl.py
CHANGED
@@ -71,16 +71,44 @@ def prepare_pipeline(
|
|
71 |
return pipe
|
72 |
|
73 |
|
74 |
-
def remove_bg(image, net, transform, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
image_size = image.size
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
pred = preds[0].squeeze()
|
80 |
-
pred_pil = transforms.ToPILImage()(pred)
|
81 |
-
mask = pred_pil.resize(image_size)
|
82 |
image.putalpha(mask)
|
83 |
return image
|
|
|
84 |
|
85 |
|
86 |
def preprocess_image(image: Image.Image, height, width):
|
@@ -150,7 +178,8 @@ def run_pipeline(
|
|
150 |
# Prepare image
|
151 |
reference_image = Image.open(image) if isinstance(image, str) else image
|
152 |
if remove_bg_fn is not None:
|
153 |
-
reference_image = remove_bg_fn(reference_image)
|
|
|
154 |
reference_image = preprocess_image(reference_image, height, width)
|
155 |
elif reference_image.mode == "RGBA":
|
156 |
reference_image = preprocess_image(reference_image, height, width)
|
|
|
71 |
return pipe
|
72 |
|
73 |
|
74 |
+
# def remove_bg(image, net, transform, device):
|
75 |
+
# image_size = image.size
|
76 |
+
# input_images = transform(image).unsqueeze(0).to(device)
|
77 |
+
# with torch.no_grad():
|
78 |
+
# preds = net(input_images)[-1].sigmoid().cpu()
|
79 |
+
# pred = preds[0].squeeze()
|
80 |
+
# pred_pil = transforms.ToPILImage()(pred)
|
81 |
+
# mask = pred_pil.resize(image_size)
|
82 |
+
# image.putalpha(mask)
|
83 |
+
# return image
|
84 |
+
|
85 |
+
|
86 |
+
def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
|
87 |
+
"""
|
88 |
+
Applies a pre-existing mask to an image to make the background transparent.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
image (PIL.Image.Image): The input image.
|
92 |
+
net: Pre-trained neural network (not used but kept for compatibility).
|
93 |
+
transform: Image transformation object (not used but kept for compatibility).
|
94 |
+
device: Device used for inference (not used but kept for compatibility).
|
95 |
+
mask (PIL.Image.Image, optional): The mask to use. Should be the same size
|
96 |
+
as the input image, with values between 0 and 255 (or 0-1).
|
97 |
+
If None, will return image with no changes.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
PIL.Image.Image: The modified image with transparent background.
|
101 |
+
"""
|
102 |
+
if mask is None:
|
103 |
+
return image
|
104 |
+
|
105 |
image_size = image.size
|
106 |
+
if mask.size != image_size:
|
107 |
+
mask = mask.resize(image_size) # Resizing the mask if it is not the same size as image
|
108 |
+
|
|
|
|
|
|
|
109 |
image.putalpha(mask)
|
110 |
return image
|
111 |
+
|
112 |
|
113 |
|
114 |
def preprocess_image(image: Image.Image, height, width):
|
|
|
178 |
# Prepare image
|
179 |
reference_image = Image.open(image) if isinstance(image, str) else image
|
180 |
if remove_bg_fn is not None:
|
181 |
+
# reference_image = remove_bg_fn(reference_image)
|
182 |
+
|
183 |
reference_image = preprocess_image(reference_image, height, width)
|
184 |
elif reference_image.mode == "RGBA":
|
185 |
reference_image = preprocess_image(reference_image, height, width)
|