vidit98 committed
Commit 87a26ac
1 Parent(s): 90d54c5

fix session state issue

Files changed (2):
  1. app.py +22 -12
  2. requirements.txt +1 -1
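
In short: the demo previously instantiated ImageComp('edit_app') once, at UI-build time, and wired its bound methods straight into the Gradio event handlers, so every visitor shared (and overwrote) the same object's state. The commit wraps the instance in gr.State, which gives each browser session its own copy, and routes callbacks through module-level wrapper functions that receive the per-session object as their first input. The helper methods are also renamed with a leading underscore (process_mask → _process_mask, edit_app → _edit_app, edit → _edit). A runnable sketch of the pattern follows the app.py diff.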
app.py CHANGED
@@ -119,7 +119,7 @@ class ImageComp:
         self.baseoutput = output.astype(np.uint8)
         return self.baseoutput
 
-    def process_mask(self, mask, panoptic_mask, segmask):
+    def _process_mask(self, mask, panoptic_mask, segmask):
         panoptic_mask_ = panoptic_mask + 1
         mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape))
         mask_ = torch.tensor(mask_)
@@ -137,29 +137,29 @@ class ImageComp:
         return final_mask, obj_class
 
 
-    def edit_app(self, input_mask, ref_mask, whole_ref):
+    def _edit_app(self, input_mask, ref_mask, whole_ref):
         input_pmask = self.input_pmask
         input_segmask = self.input_segmask
 
         if whole_ref:
             reference_mask = torch.ones(self.ref_pmask.shape).cuda()
         else:
-            reference_mask, _ = self.process_mask(ref_mask, self.ref_pmask, self.ref_segmask)
+            reference_mask, _ = self._process_mask(ref_mask, self.ref_pmask, self.ref_segmask)
 
-        edit_mask, _ = self.process_mask(input_mask, self.input_pmask, self.input_segmask)
+        edit_mask, _ = self._process_mask(input_mask, self.input_pmask, self.input_segmask)
         ma = torch.max(input_pmask)
         input_pmask[edit_mask == 1] = ma + 1
         return reference_mask, input_pmask, input_segmask, edit_mask, ma
 
 
-    def edit(self, input_mask, ref_mask, whole_ref=False, inter=1):
+    def _edit(self, input_mask, ref_mask, whole_ref=False, inter=1):
         input_img = (self.input_img/127.5 - 1)
         input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
 
         reference_img = (self.ref_img/127.5 - 1)
         reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
 
-        reference_mask, input_pmask, input_segmask, region_mask, ma = self.edit_app(input_mask, ref_mask, whole_ref)
+        reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(input_mask, ref_mask, whole_ref)
 
         input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
         _, mean_feat_inpt, one_hot_inpt, empty_mask_flag_inpt = model.get_appearance(input_img, input_pmask, return_all=True)
@@ -182,7 +182,7 @@ class ImageComp:
     def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
                 num_samples, ddim_steps, guess_mode, strength,
                 scale_s, scale_f, scale_t, seed, eta, masking=True,whole_ref=False,inter=1):
-        structure, appearance, mask, img = self.edit(input_mask, ref_mask,
-                                                     whole_ref=whole_ref, inter=inter)
+        structure, appearance, mask, img = self._edit(input_mask, ref_mask,
+                                                      whole_ref=whole_ref, inter=inter)
 
         null_structure = torch.zeros(structure.shape).cuda() - 1
@@ -242,6 +242,17 @@ class ImageComp:
         return [] + results
 
 
+def init_input_canvas_wrapper(obj, *args):
+    return obj.init_input_canvas(*args)
+
+def init_ref_canvas_wrapper(obj, *args):
+    return obj.init_ref_canvas(*args)
+
+def process_wrapper(obj, *args):
+    return obj.process(*args)
+
+
+
 css = """
 h1 {
   text-align: center;
@@ -293,14 +304,14 @@ def create_app_demo():
     """)
     with gr.Column():
         with gr.Row():
-            img_edit = ImageComp('edit_app')
+            img_edit = gr.State(ImageComp('edit_app'))
             with gr.Column():
                 btn1 = gr.Button("Input Image")
                 input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
             with gr.Column():
                 btn2 = gr.Button("Select Object to Edit")
                 input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
-        input_image.change(fn=img_edit.init_input_canvas, inputs=[input_image], outputs=[input_mask], queue=False)
+        input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask], queue=False)
 
     # with gr.Row():
     with gr.Column():
@@ -310,7 +321,7 @@ def create_app_demo():
             btn4 = gr.Button("Select Reference Object")
             reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy", tool="sketch")
 
-        ref_img.change(fn=img_edit.init_ref_canvas, inputs=[ref_img], outputs=[reference_mask], queue=False)
+        ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[reference_mask], queue=False)
 
     with gr.Row():
         prompt = gr.Textbox(label="Prompt", value='A picture of truck')
@@ -325,7 +336,6 @@ def create_app_demo():
 
     with gr.Accordion("Advanced options", open=False):
         num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
-        image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
        strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
        guess_mode = gr.Checkbox(label='Guess Mode', value=False)
        ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
@@ -351,7 +361,7 @@ def create_app_demo():
     )
     ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
            scale_s, scale_f, scale_t, seed, eta, masking, whole_ref, interpolation]
-    run_button.click(fn=img_edit.process, inputs=ips, outputs=[result_gallery])
+    run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
 
 
 
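The session-state pattern above is easy to reuse: gr.State deep-copies its initial value for every new browser session, and a State component included in inputs=[...] delivers that session's copy to the handler as a positional argument. Below is a minimal, self-contained sketch of the same wiring, with a hypothetical SessionCounter standing in for ImageComp (Gradio 3.25-era Blocks API assumed):

import gradio as gr

# Hypothetical stand-in for ImageComp: any mutable object whose state
# must stay private to one browser session.
class SessionCounter:
    def __init__(self):
        self.count = 0

    def bump(self, text):
        self.count += 1
        return f"{text!r} seen {self.count} time(s) in this session"

# Module-level wrapper, mirroring process_wrapper in the commit: the
# session's state object arrives as the first input, so the handler is
# no longer a bound method on an instance shared by every visitor.
def bump_wrapper(obj, *args):
    return obj.bump(*args)

with gr.Blocks() as demo:
    counter = gr.State(SessionCounter())  # each session gets its own copy
    text = gr.Textbox(label="Text")
    out = gr.Textbox(label="Result")
    run = gr.Button("Run")
    # Prepend the state component to the usual inputs, exactly as
    # run_button.click does with [img_edit, *ips] in app.py.
    run.click(fn=bump_wrapper, inputs=[counter, text], outputs=[out])

demo.launch()
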
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 addict==2.4.0
 albumentations==1.3.0
 einops==0.3.0
-gradio==3.17.1
+gradio==3.25.0
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
 kornia==0.6.0
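
The only dependency change is the Gradio pin, 3.17.1 → 3.25.0, presumably so the gr.State wiring above behaves correctly on the deployed Space. A quick sanity check that an environment matches the pin (minimal sketch):

import gradio as gr

# The installed version should match requirements.txt.
print(gr.__version__)  # expected: 3.25.0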