# memory_efficiency.py (Erasing-Concepts-In-Diffusion)
# adapted from EveryDream2Trainer
import contextlib
import traceback
import torch
from torch.cuda.amp import GradScaler
from StableDiffuser import StableDiffuser


class MemoryEfficiencyWrapper:
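    """Context manager that applies memory-saving settings to a StableDiffuser
    (gradient checkpointing, xformers or sliced attention, AMP loss scaling)
    for the duration of a `with` block and restores the defaults on exit."""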

    def __init__(self,
                 diffuser: StableDiffuser,
                 use_amp: bool,
                 use_xformers: bool,
                 use_gradient_checkpointing: bool):
        self.diffuser = diffuser
        # SD1.x UNets report attention_head_dim as 8 (or [8, 8, 8, 8] in some configs)
        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == [8, 8, 8, 8]
        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == 8 or self.is_sd1attn
        self.use_amp = use_amp
        self.use_xformers = use_xformers
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def __enter__(self):
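        """Turn on gradient checkpointing and xformers attention (falling back
        to attention slicing where needed), cast the VAE for AMP, enable TF32
        matmuls, and create the GradScaler used by step()."""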
        if self.use_gradient_checkpointing:
            self.diffuser.unet.enable_gradient_checkpointing()
            self.diffuser.text_encoder.gradient_checkpointing_enable()
        if self.use_xformers:
            if (self.use_amp and self.is_sd1attn) or (not self.is_sd1attn):
                try:
                    self.diffuser.unet.enable_xformers_memory_efficient_attention()
                    print("Enabled xformers")
                except Exception as ex:
                    print(f"failed to load xformers ({ex}), using attention slicing instead")
                    self.diffuser.unet.set_attention_slice("auto")
            elif (not self.use_amp and self.is_sd1attn):
                print("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
                self.diffuser.unet.set_attention_slice("auto")
        else:
            print("xformers disabled via arg, using attention slicing instead")
            self.diffuser.unet.set_attention_slice("auto")
        # under AMP the VAE (not trained here) can run in fp16; the UNet keeps
        # fp32 weights and relies on the GradScaler below for mixed precision
        self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
        self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
        try:
            # torch.compile is currently left disabled:
            # unet = torch.compile(unet)
            # text_encoder = torch.compile(text_encoder)
            # vae = torch.compile(vae)
            torch.set_float32_matmul_precision('high')
            torch.backends.cudnn.allow_tf32 = True
            # logging.info("Successfully compiled models")
        except Exception as ex:
            print(f"Failed to enable TF32 / compile model, continuing anyway, ex: {ex}")
        self.grad_scaler = GradScaler(
            enabled=self.use_amp,
            init_scale=2 ** 17.5,
            growth_factor=2,
            backoff_factor=1.0 / 2,
            growth_interval=25,
        )
        return self  # lets callers write `with MemoryEfficiencyWrapper(...) as wrapper:`

    def step(self, optimizer, loss):
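        """Backward pass and optimizer step through the GradScaler (which is a
        pass-through when AMP is disabled); gradients are not zeroed here."""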
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()

    def __exit__(self, exc_type, exc_value, tb):
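        """Report any exception raised inside the block, then turn gradient
        checkpointing and xformers back off, reverting to sliced attention."""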
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
            # return False  # uncomment to pass the exception through
        self.diffuser.unet.disable_gradient_checkpointing()
        try:
            self.diffuser.text_encoder.gradient_checkpointing_disable()
        except AttributeError:
            # self.diffuser.text_encoder is likely `del`eted
            pass
        self.diffuser.unet.disable_xformers_memory_efficient_attention()
        self.diffuser.unet.set_attention_slice("auto")
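
# Usage sketch (not part of the original file): roughly how a training loop is
# expected to drive this wrapper. `diffuser`, `optimizer`, `dataloader` and
# `compute_loss` are assumed to be provided by the surrounding training script.
#
#   wrapper = MemoryEfficiencyWrapper(diffuser,
#                                     use_amp=True,
#                                     use_xformers=True,
#                                     use_gradient_checkpointing=True)
#   with wrapper:
#       for batch in dataloader:
#           optimizer.zero_grad()
#           loss = compute_loss(batch)     # hypothetical loss function
#           wrapper.step(optimizer, loss)  # scaled backward + optimizer step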