Spaces:
Running
on
Zero
Running
on
Zero
Update inference_i2mv_sdxl.py
Browse files- inference_i2mv_sdxl.py +22 -21
inference_i2mv_sdxl.py
CHANGED
@@ -70,19 +70,6 @@ def prepare_pipeline(
|
|
70 |
|
71 |
return pipe
|
72 |
|
73 |
-
|
74 |
-
# def remove_bg(image, net, transform, device):
|
75 |
-
# image_size = image.size
|
76 |
-
# input_images = transform(image).unsqueeze(0).to(device)
|
77 |
-
# with torch.no_grad():
|
78 |
-
# preds = net(input_images)[-1].sigmoid().cpu()
|
79 |
-
# pred = preds[0].squeeze()
|
80 |
-
# pred_pil = transforms.ToPILImage()(pred)
|
81 |
-
# mask = pred_pil.resize(image_size)
|
82 |
-
# image.putalpha(mask)
|
83 |
-
# return image
|
84 |
-
|
85 |
-
|
86 |
def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
|
87 |
"""
|
88 |
Applies a pre-existing mask to an image to make the background transparent.
|
@@ -108,21 +95,35 @@ def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = No
|
|
108 |
|
109 |
image.putalpha(mask)
|
110 |
return image
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
|
114 |
def preprocess_image(image: Image.Image, height, width):
|
115 |
|
116 |
-
|
117 |
# alpha = image
|
118 |
|
119 |
-
if image.mode in ("RGBA", "LA"):
|
120 |
-
|
121 |
-
|
122 |
-
elif image.mode in ("RGB"):
|
123 |
-
|
124 |
# Create default alpha for non-alpha images
|
125 |
-
|
126 |
H, W = alpha.shape
|
127 |
# get the bounding box of alpha
|
128 |
y, x = np.where(alpha)
|
|
|
70 |
|
71 |
return pipe
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
|
74 |
"""
|
75 |
Applies a pre-existing mask to an image to make the background transparent.
|
|
|
95 |
|
96 |
image.putalpha(mask)
|
97 |
return image
|
98 |
+
|
99 |
+
|
100 |
+
def remove_bg(image, net, transform, device):
|
101 |
+
image_size = image.size
|
102 |
+
input_images = transform(image).unsqueeze(0).to(device)
|
103 |
+
with torch.no_grad():
|
104 |
+
preds = net(input_images)[-1].sigmoid().cpu()
|
105 |
+
pred = preds[0].squeeze()
|
106 |
+
pred_pil = transforms.ToPILImage()(pred)
|
107 |
+
mask = pred_pil.resize(image_size)
|
108 |
+
image.putalpha(mask)
|
109 |
+
return image
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
|
114 |
|
115 |
def preprocess_image(image: Image.Image, height, width):
|
116 |
|
117 |
+
alpha = image[..., 3] > 0
|
118 |
# alpha = image
|
119 |
|
120 |
+
#if image.mode in ("RGBA", "LA"):
|
121 |
+
# image = np.array(image)
|
122 |
+
# alpha = image[..., 3] # Extract the alpha channel
|
123 |
+
#elif image.mode in ("RGB"):
|
124 |
+
# image = np.array(image)
|
125 |
# Create default alpha for non-alpha images
|
126 |
+
# alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255 # Create
|
127 |
H, W = alpha.shape
|
128 |
# get the bounding box of alpha
|
129 |
y, x = np.where(alpha)
|