# adapted from EveryDream2Trainer
import traceback

import torch
from torch.cuda.amp import GradScaler

from StableDiffuser import StableDiffuser


class MemoryEfficiencyWrapper:
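    """Context manager that toggles memory-saving features (gradient
    checkpointing, xformers attention or attention slicing, AMP grad
    scaling, TF32 matmuls) around a training run on a StableDiffuser."""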

    def __init__(self,
                 diffuser: StableDiffuser,
                 use_amp: bool,
                 use_xformers: bool,
                 use_gradient_checkpointing: bool):
        self.diffuser = diffuser
        # SD1.x UNets report attention_head_dim as 8 (or [8, 8, 8, 8] in some configs)
        attention_head_dim = diffuser.unet.config["attention_head_dim"]
        self.is_sd1attn = attention_head_dim == 8 or attention_head_dim == [8, 8, 8, 8]

        self.use_amp = use_amp
        self.use_xformers = use_xformers
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def __enter__(self):
        if self.use_gradient_checkpointing:
            self.diffuser.unet.enable_gradient_checkpointing()
            self.diffuser.text_encoder.gradient_checkpointing_enable()

        if self.use_xformers:
            if (self.use_amp and self.is_sd1attn) or (not self.is_sd1attn):
                try:
                    self.diffuser.unet.enable_xformers_memory_efficient_attention()
                    print("Enabled xformers")
                except Exception as ex:
                    print(f"failed to load xformers ({ex}), using attention slicing instead")
                    self.diffuser.unet.set_attention_slice("auto")
            elif (not self.use_amp and self.is_sd1attn):
                print("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
                self.diffuser.unet.set_attention_slice("auto")
        else:
            print("xformers disabled via arg, using attention slicing instead")
            self.diffuser.unet.set_attention_slice("auto")

        # cast the VAE to fp16 under AMP to save memory; the UNet stays fp32
        # so the GradScaler can manage its gradients
        self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
        self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)

        try:
            # torch.compile is left disabled; uncomment to experiment:
            # self.diffuser.unet = torch.compile(self.diffuser.unet)
            # self.diffuser.text_encoder = torch.compile(self.diffuser.text_encoder)
            # self.diffuser.vae = torch.compile(self.diffuser.vae)
            torch.set_float32_matmul_precision('high')
            torch.backends.cudnn.allow_tf32 = True
        except Exception as ex:
            print(f"Failed to enable TF32 settings, continuing anyway, ex: {ex}")

        # more aggressive scaler settings than the PyTorch defaults: a higher
        # starting scale and a much shorter growth interval
        self.grad_scaler = GradScaler(
            enabled=self.use_amp,
            init_scale=2 ** 17.5,
            growth_factor=2,
            backoff_factor=1.0 / 2,
            growth_interval=25,
        )
        return self

    def step(self, optimizer, loss):
        # scale the loss for AMP, backprop, then step the optimizer and
        # update the scale; the caller is expected to zero gradients between steps
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
        # returning None (falsy) lets any exception propagate after cleanup
        self.diffuser.unet.disable_gradient_checkpointing()
        try:
            self.diffuser.text_encoder.gradient_checkpointing_disable()
        except AttributeError:
            # self.diffuser.text_encoder is likely `del`eted
            pass

        # disable xformers and fall back to attention slicing on exit
        self.diffuser.unet.disable_xformers_memory_efficient_attention()
        self.diffuser.unet.set_attention_slice("auto")
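

# Minimal usage sketch. `dataloader`, `optimizer`, and `compute_loss` are
# illustrative names, not part of this module:
#
#   diffuser = StableDiffuser(...)
#   optimizer = torch.optim.AdamW(diffuser.unet.parameters(), lr=1e-5)
#   with MemoryEfficiencyWrapper(diffuser, use_amp=True, use_xformers=True,
#                                use_gradient_checkpointing=True) as wrapper:
#       for batch in dataloader:
#           optimizer.zero_grad()
#           with torch.autocast("cuda", enabled=wrapper.use_amp):
#               loss = compute_loss(diffuser, batch)
#           wrapper.step(optimizer, loss)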