fix session state issue
Files changed:
- app.py (+22 -12)
- requirements.txt (+1 -1)
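This commit fixes a cross-session state leak in the Gradio demo. The module-level ImageComp instance was shared by every connected user, so one session's inputs could clobber another's. The diff below wraps the instance in gr.State, which Gradio copies once per browser session, and reroutes the event callbacks through small module-level wrappers (init_input_canvas_wrapper, init_ref_canvas_wrapper, process_wrapper) that receive the session's own object as their first input. It also drops the unused image_resolution slider and pins gradio to 3.25.0.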
app.py
CHANGED

@@ -119,7 +119,7 @@ class ImageComp:
         self.baseoutput = output.astype(np.uint8)
         return self.baseoutput

-    def …
+    def _process_mask(self, mask, panoptic_mask, segmask):
         panoptic_mask_ = panoptic_mask + 1
         mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape))
         mask_ = torch.tensor(mask_)

@@ -137,29 +137,29 @@ class ImageComp:
         return final_mask, obj_class


-    def …
+    def _edit_app(self, input_mask, ref_mask, whole_ref):
         input_pmask = self.input_pmask
         input_segmask = self.input_segmask

         if whole_ref:
             reference_mask = torch.ones(self.ref_pmask.shape).cuda()
         else:
-            reference_mask, _ = self.…
+            reference_mask, _ = self._process_mask(ref_mask, self.ref_pmask, self.ref_segmask)

-        edit_mask, _ = self.…
+        edit_mask, _ = self._process_mask(input_mask, self.input_pmask, self.input_segmask)
         ma = torch.max(input_pmask)
         input_pmask[edit_mask == 1] = ma + 1
         return reference_mask, input_pmask, input_segmask, edit_mask, ma


-    def …
+    def _edit(self, input_mask, ref_mask, whole_ref=False, inter=1):
         input_img = (self.input_img/127.5 - 1)
         input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)

         reference_img = (self.ref_img/127.5 - 1)
         reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)

-        reference_mask, input_pmask, input_segmask, region_mask, ma = self.…
+        reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(input_mask, ref_mask, whole_ref)

         input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
         _, mean_feat_inpt, one_hot_inpt, empty_mask_flag_inpt = model.get_appearance(input_img, input_pmask, return_all=True)

@@ -182,7 +182,7 @@ class ImageComp:
     def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
                 num_samples, ddim_steps, guess_mode, strength,
                 scale_s, scale_f, scale_t, seed, eta, masking=True,whole_ref=False,inter=1):
-        structure, appearance, mask, img = self.…
+        structure, appearance, mask, img = self._edit(input_mask, ref_mask,
                                                       whole_ref=whole_ref, inter=inter)

         null_structure = torch.zeros(structure.shape).cuda() - 1

@@ -242,6 +242,17 @@ class ImageComp:
         return [] + results


+def init_input_canvas_wrapper(obj, *args):
+    return obj.init_input_canvas(*args)
+
+def init_ref_canvas_wrapper(obj, *args):
+    return obj.init_ref_canvas(*args)
+
+def process_wrapper(obj, *args):
+    return obj.process(*args)
+
+
+
 css = """
 h1 {
     text-align: center;

@@ -293,14 +304,14 @@ def create_app_demo():
     """)
     with gr.Column():
         with gr.Row():
-            img_edit = ImageComp('edit_app')
+            img_edit = gr.State(ImageComp('edit_app'))
             with gr.Column():
                 btn1 = gr.Button("Input Image")
                 input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
             with gr.Column():
                 btn2 = gr.Button("Select Object to Edit")
                 input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
-        input_image.change(fn=…
+        input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask], queue=False)

         # with gr.Row():
         with gr.Column():

@@ -310,7 +321,7 @@ def create_app_demo():
                 btn4 = gr.Button("Select Reference Object")
                 reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy", tool="sketch")

-        ref_img.change(fn=…
+        ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[reference_mask], queue=False)

     with gr.Row():
         prompt = gr.Textbox(label="Prompt", value='A picture of truck')

@@ -325,7 +336,6 @@ def create_app_demo():

     with gr.Accordion("Advanced options", open=False):
         num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
-        image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
         strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
         guess_mode = gr.Checkbox(label='Guess Mode', value=False)
         ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)

@@ -351,7 +361,7 @@ def create_app_demo():
     )
     ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
            scale_s, scale_f, scale_t, seed, eta, masking, whole_ref, interpolation]
-    run_button.click(fn=…
+    run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])



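The wrapper pattern above is Gradio's standard recipe for per-user state: an object seeded into gr.State is deep-copied for each browser session, and event handlers receive that session's copy as an ordinary first input. Below is a minimal, self-contained sketch of the same pattern against the Gradio 3.x API pinned in requirements.txt; the Counter class and bump_wrapper are hypothetical stand-ins for ImageComp and its wrappers, not code from this app.

import gradio as gr

class Counter:
    # Hypothetical stand-in for a per-session object such as ImageComp.
    def __init__(self):
        self.n = 0

    def bump(self, amount):
        self.n += amount
        return f"count for this session: {self.n}"

# Module-level wrapper: the gr.State value arrives as the first input,
# so no bound method ties the handler to a single shared instance.
def bump_wrapper(obj, amount):
    return obj.bump(amount)

with gr.Blocks() as demo:
    counter = gr.State(Counter())  # deep-copied once per browser session
    amount = gr.Number(value=1, label="Amount")
    result = gr.Textbox(label="Result")
    gr.Button("Bump").click(fn=bump_wrapper, inputs=[counter, amount], outputs=[result])

demo.launch()

The pre-fix code appears to have bound handlers directly to the single module-level instance (the truncated old lines show fn= pointing at img_edit), which freezes one shared object into every session's callbacks at definition time; plain functions that take the state object as an argument avoid exactly that.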
requirements.txt
CHANGED

@@ -1,7 +1,7 @@
 addict==2.4.0
 albumentations==1.3.0
 einops==0.3.0
-gradio==3.…
+gradio==3.25.0
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
 kornia==0.6.0
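Pinning an exact 3.x release matters beyond the session fix: app.py relies on Gradio 3-era keyword arguments such as source= and tool= on gr.Image, which later major Gradio versions removed, so an unpinned upgrade would likely fail at startup.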