chaojiemao commited on
Commit
1828f85
1 Parent(s): 5b0cd30

Update ace_inference.py

Browse files
Files changed (1) hide show
  1. ace_inference.py +11 -10
ace_inference.py CHANGED
@@ -330,6 +330,7 @@ class ACEInference(DiffusionInference):
330
  history_io=None,
331
  tar_index=0,
332
  **kwargs):
 
333
  input_image, input_mask = image, mask
334
  g = torch.Generator(device=we.device_id)
335
  seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
@@ -396,9 +397,9 @@ class ACEInference(DiffusionInference):
396
  if use_ace and (not is_txt_image or refiner_scale <= 0):
397
  ctx, null_ctx = {}, {}
398
  # Get Noise Shape
399
- if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
400
  x = self.encode_first_stage(image)
401
- if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
402
  'first_stage_model',
403
  skip_loaded=True)
404
  noise = [
@@ -414,7 +415,7 @@ class ACEInference(DiffusionInference):
414
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
415
 
416
  # Encode Prompt
417
- if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
418
  function_name, dtype = self.get_function_info(self.cond_stage_model)
419
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
420
  function_name)(prompt)
@@ -424,14 +425,14 @@ class ACEInference(DiffusionInference):
424
  function_name)(n_prompt)
425
  null_cont, null_cont_mask = self.cond_stage_embeddings(
426
  prompt, edit_image, null_cont, null_cont_mask)
427
- if use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
428
  'cond_stage_model',
429
  skip_loaded=False)
430
  ctx['crossattn'] = cont
431
  null_ctx['crossattn'] = null_cont
432
 
433
  # Encode Edit Images
434
- if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
435
  edit_image = [to_device(i, strict=False) for i in edit_image]
436
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
437
  e_img, e_mask = [], []
@@ -442,14 +443,14 @@ class ACEInference(DiffusionInference):
442
  m = [None] * len(u)
443
  e_img.append(self.encode_first_stage(u, **kwargs))
444
  e_mask.append([self.interpolate_func(i) for i in m])
445
- if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
446
  'first_stage_model',
447
  skip_loaded=True)
448
  null_ctx['edit'] = ctx['edit'] = e_img
449
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
450
 
451
  # Diffusion Process
452
- if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
453
  function_name, dtype = self.get_function_info(self.diffusion_model)
454
  with torch.autocast('cuda',
455
  enabled=dtype in ('float16', 'bfloat16'),
@@ -490,15 +491,15 @@ class ACEInference(DiffusionInference):
490
  guide_rescale=guide_rescale,
491
  return_intermediate=None,
492
  **kwargs)
493
- if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
494
  'diffusion_model',
495
  skip_loaded=False)
496
 
497
  # Decode to Pixel Space
498
- if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
499
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
500
  x_samples = self.decode_first_stage(samples)
501
- if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
502
  'first_stage_model',
503
  skip_loaded=False)
504
  x_samples = [x.squeeze(0) for x in x_samples]
 
330
  history_io=None,
331
  tar_index=0,
332
  **kwargs):
333
+ print(kwargs)
334
  input_image, input_mask = image, mask
335
  g = torch.Generator(device=we.device_id)
336
  seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
 
397
  if use_ace and (not is_txt_image or refiner_scale <= 0):
398
  ctx, null_ctx = {}, {}
399
  # Get Noise Shape
400
+ if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
401
  x = self.encode_first_stage(image)
402
+ if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
403
  'first_stage_model',
404
  skip_loaded=True)
405
  noise = [
 
415
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
416
 
417
  # Encode Prompt
418
+ if self.use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
419
  function_name, dtype = self.get_function_info(self.cond_stage_model)
420
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
421
  function_name)(prompt)
 
425
  function_name)(n_prompt)
426
  null_cont, null_cont_mask = self.cond_stage_embeddings(
427
  prompt, edit_image, null_cont, null_cont_mask)
428
+ if self.use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
429
  'cond_stage_model',
430
  skip_loaded=False)
431
  ctx['crossattn'] = cont
432
  null_ctx['crossattn'] = null_cont
433
 
434
  # Encode Edit Images
435
+ if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
436
  edit_image = [to_device(i, strict=False) for i in edit_image]
437
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
438
  e_img, e_mask = [], []
 
443
  m = [None] * len(u)
444
  e_img.append(self.encode_first_stage(u, **kwargs))
445
  e_mask.append([self.interpolate_func(i) for i in m])
446
+ if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
447
  'first_stage_model',
448
  skip_loaded=True)
449
  null_ctx['edit'] = ctx['edit'] = e_img
450
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
451
 
452
  # Diffusion Process
453
+ if self.use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
454
  function_name, dtype = self.get_function_info(self.diffusion_model)
455
  with torch.autocast('cuda',
456
  enabled=dtype in ('float16', 'bfloat16'),
 
491
  guide_rescale=guide_rescale,
492
  return_intermediate=None,
493
  **kwargs)
494
+ if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
495
  'diffusion_model',
496
  skip_loaded=False)
497
 
498
  # Decode to Pixel Space
499
+ if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
500
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
501
  x_samples = self.decode_first_stage(samples)
502
+ if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
503
  'first_stage_model',
504
  skip_loaded=False)
505
  x_samples = [x.squeeze(0) for x in x_samples]