not-lain commited on
Commit
dcfda89
·
1 Parent(s): ff8d629

update outpainting

Browse files
Files changed (1) hide show
  1. app.py +41 -147
app.py CHANGED
@@ -5,8 +5,7 @@ from loadimg import load_img
5
  from torchvision import transforms
6
  from transformers import AutoModelForImageSegmentation
7
  from diffusers import FluxFillPipeline
8
- from PIL import Image, ImageDraw
9
- from diffusers.utils import load_image
10
 
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
 
@@ -38,169 +37,52 @@ def can_expand(source_width, source_height, target_width, target_height, alignme
38
 
39
  def prepare_image_and_mask(
40
  image,
41
- width,
42
- height,
43
- overlap_percentage,
44
- resize_percentage,
45
- alignment,
46
- overlap_left,
47
- overlap_right,
48
- overlap_top,
49
- overlap_bottom,
50
  ):
51
- target_size = (width, height)
52
-
53
- scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
54
- new_width = int(image.width * scale_factor)
55
- new_height = int(image.height * scale_factor)
56
-
57
- source = image.resize((new_width, new_height), Image.LANCZOS)
58
-
59
- resize_percentage = 50
60
-
61
- # Calculate new dimensions based on percentage
62
- resize_factor = resize_percentage / 100
63
- new_width = int(source.width * resize_factor)
64
- new_height = int(source.height * resize_factor)
65
-
66
- # Ensure minimum size of 64 pixels
67
- new_width = max(new_width, 64)
68
- new_height = max(new_height, 64)
69
-
70
- # Resize the image
71
- source = source.resize((new_width, new_height), Image.LANCZOS)
72
-
73
- # Calculate the overlap in pixels based on the percentage
74
- overlap_x = int(new_width * (overlap_percentage / 100))
75
- overlap_y = int(new_height * (overlap_percentage / 100))
76
-
77
- # Ensure minimum overlap of 1 pixel
78
- overlap_x = max(overlap_x, 1)
79
- overlap_y = max(overlap_y, 1)
80
-
81
- # Calculate margins based on alignment
82
- if alignment == "Middle":
83
- margin_x = (target_size[0] - new_width) // 2
84
- margin_y = (target_size[1] - new_height) // 2
85
- elif alignment == "Left":
86
- margin_x = 0
87
- margin_y = (target_size[1] - new_height) // 2
88
- elif alignment == "Right":
89
- margin_x = target_size[0] - new_width
90
- margin_y = (target_size[1] - new_height) // 2
91
- elif alignment == "Top":
92
- margin_x = (target_size[0] - new_width) // 2
93
- margin_y = 0
94
- elif alignment == "Bottom":
95
- margin_x = (target_size[0] - new_width) // 2
96
- margin_y = target_size[1] - new_height
97
-
98
- # Adjust margins to eliminate gaps
99
- margin_x = max(0, min(margin_x, target_size[0] - new_width))
100
- margin_y = max(0, min(margin_y, target_size[1] - new_height))
101
-
102
- # Create a new background image and paste the resized source image
103
- background = Image.new("RGB", target_size, (255, 255, 255))
104
- background.paste(source, (margin_x, margin_y))
105
-
106
- # Create the mask
107
- mask = Image.new("L", target_size, 255)
108
- mask_draw = ImageDraw.Draw(mask)
109
-
110
- # Calculate overlap areas
111
- white_gaps_patch = 2
112
-
113
- left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
114
- right_overlap = (
115
- margin_x + new_width - overlap_x
116
- if overlap_right
117
- else margin_x + new_width - white_gaps_patch
118
- )
119
- top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
120
- bottom_overlap = (
121
- margin_y + new_height - overlap_y
122
- if overlap_bottom
123
- else margin_y + new_height - white_gaps_patch
124
- )
125
-
126
- if alignment == "Left":
127
- left_overlap = margin_x + overlap_x if overlap_left else margin_x
128
- elif alignment == "Right":
129
- right_overlap = (
130
- margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
131
- )
132
- elif alignment == "Top":
133
- top_overlap = margin_y + overlap_y if overlap_top else margin_y
134
- elif alignment == "Bottom":
135
- bottom_overlap = (
136
- margin_y + new_height - overlap_y
137
- if overlap_bottom
138
- else margin_y + new_height
139
- )
140
-
141
- # Draw the mask
142
- mask_draw.rectangle(
143
- [(left_overlap, top_overlap), (right_overlap, bottom_overlap)], fill=0
144
  )
145
-
 
146
  return background, mask
147
 
148
 
149
  def inpaint(
150
  image,
151
- width,
152
- height,
153
- overlap_percentage,
154
- num_inference_steps,
155
- custom_resize_percentage,
156
- prompt_input,
157
- alignment,
158
- overlap_left,
159
- overlap_right,
160
- overlap_top,
161
- overlap_bottom,
162
  progress=gr.Progress(track_tqdm=True),
163
  ):
164
  background, mask = prepare_image_and_mask(
165
- image,
166
- width,
167
- height,
168
- overlap_percentage,
169
- custom_resize_percentage,
170
- alignment,
171
- overlap_left,
172
- overlap_right,
173
- overlap_top,
174
- overlap_bottom,
175
  )
176
 
177
- if not can_expand(background.width, background.height, width, height, alignment):
178
- alignment = "Middle"
179
-
180
- cnet_image = background.copy()
181
- cnet_image.paste(0, (0, 0), mask)
182
-
183
- final_prompt = prompt_input
184
-
185
  # generator = torch.Generator(device="cuda").manual_seed(42)
186
 
187
  result = pipe(
188
- prompt=final_prompt,
189
- height=height,
190
- width=width,
191
- image=cnet_image,
192
  mask_image=mask,
193
- num_inference_steps=num_inference_steps,
194
  guidance_scale=30,
195
  ).images[0]
196
 
197
  result = result.convert("RGBA")
198
- cnet_image.paste(result, (0, 0), mask)
199
 
200
- return cnet_image
201
 
202
-
203
- @spaces.GPU
204
  def rmbg(image, url):
205
  if image is None:
206
  image = url
@@ -217,16 +99,28 @@ def rmbg(image, url):
217
  return image
218
 
219
 
220
- def placeholder(img,*args):
221
- return img
 
 
222
 
223
 
224
  rmbg_tab = gr.Interface(
225
- fn=rmbg, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
226
  )
227
 
228
  outpaint_tab = gr.Interface(
229
- fn=placeholder, inputs=["image","text",gr.Dropdown(["down","up","left","right","middle"])], outputs=["image"], api_name="outpainting"
 
 
 
 
 
 
 
 
 
 
230
  )
231
 
232
  demo = gr.TabbedInterface(
 
5
  from torchvision import transforms
6
  from transformers import AutoModelForImageSegmentation
7
  from diffusers import FluxFillPipeline
8
+ from PIL import Image, ImageOps
 
9
 
10
  torch.set_float32_matmul_precision(["high", "highest"][0])
11
 
 
37
 
38
  def prepare_image_and_mask(
39
  image,
40
+ padding_top=0,
41
+ padding_bottom=0,
42
+ padding_left=0,
43
+ padding_right=0,
 
 
 
 
 
44
  ):
45
+ image = load_img(image).convert("RGB")
46
+ # expand image (left,top,right,bottom)
47
+ background = ImageOps.expand(
48
+ image,
49
+ border=(padding_left, padding_top, padding_right, padding_bottom),
50
+ fill="white",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
+ mask = Image.new("RGB", image.size, "black")
53
+ mask = ImageOps.expand(mask, border=(0, 20, 0, 0), fill="white")
54
  return background, mask
55
 
56
 
57
  def inpaint(
58
  image,
59
+ padding_top=0,
60
+ padding_bottom=0,
61
+ padding_left=0,
62
+ padding_right=0,
63
+ prompt="",
 
 
 
 
 
 
64
  progress=gr.Progress(track_tqdm=True),
65
  ):
66
  background, mask = prepare_image_and_mask(
67
+ image, padding_top, padding_bottom, padding_left, padding_right
 
 
 
 
 
 
 
 
 
68
  )
69
 
 
 
 
 
 
 
 
 
70
  # generator = torch.Generator(device="cuda").manual_seed(42)
71
 
72
  result = pipe(
73
+ prompt=prompt,
74
+ height=background.height,
75
+ width=background.width,
76
+ image=background,
77
  mask_image=mask,
78
+ num_inference_steps=28,
79
  guidance_scale=30,
80
  ).images[0]
81
 
82
  result = result.convert("RGBA")
83
+ return result
84
 
 
85
 
 
 
86
  def rmbg(image, url):
87
  if image is None:
88
  image = url
 
99
  return image
100
 
101
 
102
+ @spaces.GPU
103
+ def main(*args, **kwargs):
104
+ print(args, kwargs)
105
+ return None
106
 
107
 
108
  rmbg_tab = gr.Interface(
109
+ fn=main, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
110
  )
111
 
112
  outpaint_tab = gr.Interface(
113
+ fn=main,
114
+ inputs=[
115
+ "image",
116
+ gr.Slider(label="padding top"),
117
+ gr.Slider(label="padding bottom"),
118
+ gr.Slider(label="padding left"),
119
+ gr.Slider(label="padding right"),
120
+ gr.Text(label="prompt"),
121
+ ],
122
+ outputs=["image"],
123
+ api_name="outpainting",
124
  )
125
 
126
  demo = gr.TabbedInterface(