# adapted from EveryDream2Trainer
import contextlib
import traceback

import torch
from torch.cuda.amp import GradScaler

from StableDiffuser import StableDiffuser


class MemoryEfficiencyWrapper:
    """Context manager that turns on memory-saving features (gradient checkpointing,
    xformers or attention slicing, AMP gradient scaling) on a StableDiffuser for the
    duration of a training block and restores the defaults on exit."""

    def __init__(self, diffuser: StableDiffuser, use_amp: bool, use_xformers: bool,
                 use_gradient_checkpointing: bool):
        self.diffuser = diffuser
        # SD1.x UNets report an attention_head_dim of 8 (either a scalar or [8, 8, 8, 8])
        attention_head_dim = diffuser.unet.config["attention_head_dim"]
        self.is_sd1attn = attention_head_dim == [8, 8, 8, 8] or attention_head_dim == 8

        self.use_amp = use_amp
        self.use_xformers = use_xformers
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def __enter__(self):
        if self.use_gradient_checkpointing:
            self.diffuser.unet.enable_gradient_checkpointing()
            self.diffuser.text_encoder.gradient_checkpointing_enable()

        if self.use_xformers:
            if (self.use_amp and self.is_sd1attn) or (not self.is_sd1attn):
                try:
                    self.diffuser.unet.enable_xformers_memory_efficient_attention()
                    print("Enabled xformers")
                except Exception as ex:
                    print(f"failed to load xformers, using attention slicing instead: {ex}")
                    self.diffuser.unet.set_attention_slice("auto")
            elif not self.use_amp and self.is_sd1attn:
                print("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
                self.diffuser.unet.set_attention_slice("auto")
        else:
            print("xformers disabled via arg, using attention slicing instead")
            self.diffuser.unet.set_attention_slice("auto")

        # keep the VAE in fp16 when AMP is on to save memory; the UNet stays fp32 for training
        self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device,
                                                 dtype=torch.float16 if self.use_amp else torch.float32)
        self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)

        try:
            # torch.compile is left disabled for now; TF32 still gives a speedup on Ampere+ GPUs
            # unet = torch.compile(unet)
            # text_encoder = torch.compile(text_encoder)
            # vae = torch.compile(vae)
            torch.set_float32_matmul_precision('high')
            torch.backends.cudnn.allow_tf32 = True
            # logging.info("Successfully compiled models")
        except Exception as ex:
            print(f"Failed to compile model, continuing anyway, ex: {ex}")

        self.grad_scaler = GradScaler(
            enabled=self.use_amp,
            init_scale=2 ** 17.5,
            growth_factor=2,
            backoff_factor=1.0 / 2,
            growth_interval=25,
        )
        return self  # so the wrapper can be bound with `with ... as wrapper`

    def step(self, optimizer, loss):
        # scale the loss, backpropagate, take an optimizer step, then update the scaler
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
            # returning None/False lets the exception propagate after printing it;
            # return True here to suppress it instead
        self.diffuser.unet.disable_gradient_checkpointing()
        try:
            self.diffuser.text_encoder.gradient_checkpointing_disable()
        except AttributeError:
            # self.diffuser.text_encoder has likely been `del`eted already
            pass
        self.diffuser.unet.disable_xformers_memory_efficient_attention()
        self.diffuser.unet.set_attention_slice("auto")
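

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only).  A minimal example of how the wrapper
# might be driven from a training loop: `dataloader` and `compute_loss` are
# hypothetical placeholders supplied by the caller, and the optimizer choice
# and learning rate are assumptions, not part of this module.
# ---------------------------------------------------------------------------
def _example_training_loop(diffuser: StableDiffuser, dataloader, compute_loss):
    optimizer = torch.optim.AdamW(diffuser.unet.parameters(), lr=1e-5)  # assumed hyperparameters
    with MemoryEfficiencyWrapper(diffuser,
                                 use_amp=True,
                                 use_xformers=True,
                                 use_gradient_checkpointing=True) as wrapper:
        for batch in dataloader:
            optimizer.zero_grad()
            loss = compute_loss(diffuser, batch)  # hypothetical loss function
            wrapper.step(optimizer, loss)         # scale -> backward -> optimizer step -> scaler update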