Ming Li commited on
Commit
b4ce7df
1 Parent(s): a9864eb

set fp32 to avoid errors

Browse files
Files changed (1) hide show
  1. model.py +5 -3
model.py CHANGED
@@ -53,14 +53,16 @@ class Model:
53
  ):
54
  return self.pipe
55
  model_id = CONTROLNET_MODEL_IDS[task_name]
56
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
57
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
58
- base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float32
59
  )
60
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
61
  if self.device.type == "cuda":
62
  pipe.disable_xformers_memory_efficient_attention()
63
  pipe.to(self.device)
 
 
64
  torch.cuda.empty_cache()
65
  gc.collect()
66
  self.base_model_id = base_model_id
@@ -87,7 +89,7 @@ class Model:
87
  torch.cuda.empty_cache()
88
  gc.collect()
89
  model_id = CONTROLNET_MODEL_IDS[task_name]
90
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
91
  controlnet.to(self.device)
92
  torch.cuda.empty_cache()
93
  gc.collect()
 
53
  ):
54
  return self.pipe
55
  model_id = CONTROLNET_MODEL_IDS[task_name]
56
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
57
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
58
+ base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
59
  )
60
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
61
  if self.device.type == "cuda":
62
  pipe.disable_xformers_memory_efficient_attention()
63
  pipe.to(self.device)
64
+ pipe.enable_model_cpu_offload()
65
+ pipe.enable_vae_slicing()
66
  torch.cuda.empty_cache()
67
  gc.collect()
68
  self.base_model_id = base_model_id
 
89
  torch.cuda.empty_cache()
90
  gc.collect()
91
  model_id = CONTROLNET_MODEL_IDS[task_name]
92
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
93
  controlnet.to(self.device)
94
  torch.cuda.empty_cache()
95
  gc.collect()