# adapted from EveryDream2Trainer
import contextlib
import traceback

import torch
from torch.cuda.amp import GradScaler

from StableDiffuser import StableDiffuser


class MemoryEfficiencyWrapper:
    """Context manager that toggles memory-saving training settings on a StableDiffuser:
    gradient checkpointing, xformers / attention slicing, and an AMP GradScaler."""

    def __init__(self,
                 diffuser: StableDiffuser,
                 use_amp: bool,
                 use_xformers: bool,
                 use_gradient_checkpointing: bool):
        self.diffuser = diffuser
        # SD1.x-style attention reports attention_head_dim as 8 (scalar or per-block list)
        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == [8, 8, 8, 8]
        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == 8 or self.is_sd1attn

        self.use_amp = use_amp
        self.use_xformers = use_xformers
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def __enter__(self):
        if self.use_gradient_checkpointing:
            self.diffuser.unet.enable_gradient_checkpointing()
            self.diffuser.text_encoder.gradient_checkpointing_enable()

        if self.use_xformers:
            # xformers is only attempted where it is expected to work: SD1.x attention
            # needs AMP enabled, other (e.g. SD2.x) attention can use it either way
            if (self.use_amp and self.is_sd1attn) or (not self.is_sd1attn):
                try:
                    self.diffuser.unet.enable_xformers_memory_efficient_attention()
                    print("Enabled xformers")
                except Exception as ex:
                    print(f"failed to load xformers, using attention slicing instead: {ex}")
                    self.diffuser.unet.set_attention_slice("auto")
            elif (not self.use_amp and self.is_sd1attn):
                print("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
                self.diffuser.unet.set_attention_slice("auto")
        else:
            print("xformers disabled via arg, using attention slicing instead")
            self.diffuser.unet.set_attention_slice("auto")

        self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device,
                                                 dtype=torch.float16 if self.use_amp else torch.float32)
        self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
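        # note: under AMP the VAE is cast to fp16 while the trained UNet keeps fp32 weights;
        # the training loop is presumably expected to run UNet forward passes under autocast,
        # which pairs with the GradScaler created below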

        try:
            # unet = torch.compile(unet)
            # text_encoder = torch.compile(text_encoder)
            # vae = torch.compile(vae)
            torch.set_float32_matmul_precision('high')
            torch.backends.cudnn.allow_tf32 = True
            # logging.info("Successfully compiled models")
        except Exception as ex:
            print(f"Failed to compile model, continuing anyway, ex: {ex}")

        self.grad_scaler = GradScaler(
            enabled=self.use_amp,
            init_scale=2 ** 17.5,
            growth_factor=2,
            backoff_factor=1.0 / 2,
            growth_interval=25,
        )
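
        # returning self lets callers write `with MemoryEfficiencyWrapper(...) as wrapper:`
        # (added for convenience; the original adaptation returns nothing from __enter__)
        return self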

    def step(self, optimizer, loss):
        # scale the loss for AMP, backprop, then step the optimizer and update the scaler
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
            # note: the exception still propagates after cleanup (this method returns None);
            # return True here to suppress it instead
        # restore defaults so the diffuser behaves normally outside the wrapper
        self.diffuser.unet.disable_gradient_checkpointing()
        try:
            self.diffuser.text_encoder.gradient_checkpointing_disable()
        except AttributeError:
            # self.diffuser.text_encoder has likely been `del`eted by this point
            pass
        self.diffuser.unet.disable_xformers_memory_efficient_attention()
        self.diffuser.unet.set_attention_slice("auto")