pan-yl commited on
Commit
fd61736
1 Parent(s): 6e11c9d
Files changed (2) hide show
  1. app.py +4 -4
  2. infer.py +6 -20
app.py CHANGED
@@ -427,7 +427,7 @@ class ChatBotUI(object):
427
  def set_callbacks(self, *args, **kwargs):
428
 
429
  ########################################
430
- @spaces.GPU(duration=600)
431
  def change_model(model_name):
432
  if model_name not in self.model_choices:
433
  gr.Info('The provided model name is not a valid choice!')
@@ -577,7 +577,7 @@ class ChatBotUI(object):
577
  outputs=[self.history, self.chatbot, self.text, self.gallery])
578
 
579
  ########################################
580
- @spaces.GPU(duration=600)
581
  def run_chat(message,
582
  extend_prompt,
583
  history,
@@ -796,7 +796,7 @@ class ChatBotUI(object):
796
  outputs=chat_outputs)
797
 
798
  ########################################
799
- @spaces.GPU(duration=120)
800
  def retry_chat(*args):
801
  return run_chat(self.retry_msg, *args)
802
 
@@ -805,7 +805,7 @@ class ChatBotUI(object):
805
  outputs=chat_outputs)
806
 
807
  ########################################
808
- @spaces.GPU(duration=600)
809
  def run_example(task, img, img_mask, ref1, prompt, seed):
810
  edit_image, edit_image_mask, edit_task = [], [], []
811
  if img is not None:
 
427
  def set_callbacks(self, *args, **kwargs):
428
 
429
  ########################################
430
+ @spaces.GPU(duration=60)
431
  def change_model(model_name):
432
  if model_name not in self.model_choices:
433
  gr.Info('The provided model name is not a valid choice!')
 
577
  outputs=[self.history, self.chatbot, self.text, self.gallery])
578
 
579
  ########################################
580
+ @spaces.GPU(duration=60)
581
  def run_chat(message,
582
  extend_prompt,
583
  history,
 
796
  outputs=chat_outputs)
797
 
798
  ########################################
799
+ @spaces.GPU(duration=60)
800
  def retry_chat(*args):
801
  return run_chat(self.retry_msg, *args)
802
 
 
805
  outputs=chat_outputs)
806
 
807
  ########################################
808
+ @spaces.GPU(duration=60)
809
  def run_example(task, img, img_mask, ref1, prompt, seed):
810
  edit_image, edit_image_mask, edit_task = [], [], []
811
  if img is not None:
infer.py CHANGED
@@ -139,6 +139,10 @@ class ACEInference(DiffusionInference):
139
  self.decoder_bias = cfg.get('DECODER_BIAS', 0)
140
  self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
141
 
 
 
 
 
142
  @torch.no_grad()
143
  def encode_first_stage(self, x, **kwargs):
144
  _, dtype = self.get_function_info(self.first_stage_model, 'encode')
@@ -242,12 +246,8 @@ class ACEInference(DiffusionInference):
242
  ctx, null_ctx = {}, {}
243
 
244
  # Get Noise Shape
245
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
246
  image = to_device(image)
247
  x = self.encode_first_stage(image)
248
- self.dynamic_unload(self.first_stage_model,
249
- 'first_stage_model',
250
- skip_loaded=True)
251
  noise = [
252
  torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
253
  for i in x
@@ -261,7 +261,7 @@ class ACEInference(DiffusionInference):
261
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
262
 
263
  # Encode Prompt
264
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
265
  function_name, dtype = self.get_function_info(self.cond_stage_model)
266
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
267
  function_name)(prompt)
@@ -271,14 +271,10 @@ class ACEInference(DiffusionInference):
271
  function_name)(n_prompt)
272
  null_cont, null_cont_mask = self.cond_stage_embeddings(
273
  prompt, edit_image, null_cont, null_cont_mask)
274
- self.dynamic_unload(self.cond_stage_model,
275
- 'cond_stage_model',
276
- skip_loaded=False)
277
  ctx['crossattn'] = cont
278
  null_ctx['crossattn'] = null_cont
279
 
280
  # Encode Edit Images
281
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
282
  edit_image = [to_device(i, strict=False) for i in edit_image]
283
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
284
  e_img, e_mask = [], []
@@ -289,14 +285,11 @@ class ACEInference(DiffusionInference):
289
  m = [None] * len(u)
290
  e_img.append(self.encode_first_stage(u, **kwargs))
291
  e_mask.append([self.interpolate_func(i) for i in m])
292
- self.dynamic_unload(self.first_stage_model,
293
- 'first_stage_model',
294
- skip_loaded=True)
295
  null_ctx['edit'] = ctx['edit'] = e_img
296
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
297
 
298
  # Diffusion Process
299
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
300
  function_name, dtype = self.get_function_info(self.diffusion_model)
301
  with torch.autocast('cuda',
302
  enabled=dtype in ('float16', 'bfloat16'),
@@ -337,17 +330,10 @@ class ACEInference(DiffusionInference):
337
  guide_rescale=guide_rescale,
338
  return_intermediate=None,
339
  **kwargs)
340
- self.dynamic_unload(self.diffusion_model,
341
- 'diffusion_model',
342
- skip_loaded=False)
343
 
344
  # Decode to Pixel Space
345
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
346
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
347
  x_samples = self.decode_first_stage(samples)
348
- self.dynamic_unload(self.first_stage_model,
349
- 'first_stage_model',
350
- skip_loaded=False)
351
 
352
  imgs = [
353
  torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,
 
139
  self.decoder_bias = cfg.get('DECODER_BIAS', 0)
140
  self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
141
 
142
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
143
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
144
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
145
+
146
  @torch.no_grad()
147
  def encode_first_stage(self, x, **kwargs):
148
  _, dtype = self.get_function_info(self.first_stage_model, 'encode')
 
246
  ctx, null_ctx = {}, {}
247
 
248
  # Get Noise Shape
 
249
  image = to_device(image)
250
  x = self.encode_first_stage(image)
 
 
 
251
  noise = [
252
  torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
253
  for i in x
 
261
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
262
 
263
  # Encode Prompt
264
+
265
  function_name, dtype = self.get_function_info(self.cond_stage_model)
266
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
267
  function_name)(prompt)
 
271
  function_name)(n_prompt)
272
  null_cont, null_cont_mask = self.cond_stage_embeddings(
273
  prompt, edit_image, null_cont, null_cont_mask)
 
 
 
274
  ctx['crossattn'] = cont
275
  null_ctx['crossattn'] = null_cont
276
 
277
  # Encode Edit Images
 
278
  edit_image = [to_device(i, strict=False) for i in edit_image]
279
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
280
  e_img, e_mask = [], []
 
285
  m = [None] * len(u)
286
  e_img.append(self.encode_first_stage(u, **kwargs))
287
  e_mask.append([self.interpolate_func(i) for i in m])
288
+
 
 
289
  null_ctx['edit'] = ctx['edit'] = e_img
290
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
291
 
292
  # Diffusion Process
 
293
  function_name, dtype = self.get_function_info(self.diffusion_model)
294
  with torch.autocast('cuda',
295
  enabled=dtype in ('float16', 'bfloat16'),
 
330
  guide_rescale=guide_rescale,
331
  return_intermediate=None,
332
  **kwargs)
 
 
 
333
 
334
  # Decode to Pixel Space
 
335
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
336
  x_samples = self.decode_first_stage(samples)
 
 
 
337
 
338
  imgs = [
339
  torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,