File size: 10,279 Bytes
5231633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
from .. import WarpCore
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
from abc import abstractmethod
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import DataLoader
from gdf import GDF
import numpy as np
from tqdm import tqdm
import wandb

import webdataset as wds
from webdataset.handlers import warn_and_continue
from torch.distributed import barrier
from enum import Enum

class TargetReparametrization(Enum):
    """Space in which ``DiffusionCore.forward_pass`` computes the diffusion loss.

    The generator's raw prediction is converted with ``gdf.undiffuse`` into
    either the noise estimate (``EPSILON``) or the clean-latent estimate
    (``X0``) and compared against the matching target.
    """

    EPSILON = 'epsilon'  # loss between predicted noise and the sampled noise
    X0 = 'x0'  # loss between predicted clean latent and the encoded latents

class DiffusionCore(WarpCore):
    """Abstract WarpCore training template for (latent) diffusion models.

    Subclasses provide the data source (``webdataset_path`` /
    ``webdataset_filters`` / ``webdataset_preprocessors``), the latent
    encoding (``encode_latents`` / ``decode_latents``), the generator
    conditioning (``get_conditions``) and the sampling routine (``sample``).
    This class supplies the WebDataset + DataLoader setup, the GDF-based loss,
    the gradient-accumulating training loop with EMA updates and logging, and
    checkpointing.
    """

    @dataclass(frozen=True)
    class Config(WarpCore.Config):
        # TRAINING PARAMS
        lr: float = EXPECTED_TRAIN
        grad_accum_steps: int = EXPECTED_TRAIN
        batch_size: int = EXPECTED_TRAIN  # global batch; split over ranks and accumulation steps in setup_data
        updates: int = EXPECTED_TRAIN  # optimizer updates; train() runs updates * grad_accum_steps iterations
        warmup_updates: int = EXPECTED_TRAIN
        save_every: int = 500  # checkpoint/sample interval, in optimizer updates
        backup_every: int = 20000  # extra suffixed checkpoint interval, in info.total_steps
        use_fsdp: bool = True

        # EMA UPDATE
        ema_start_iters: int = None  # up to this iteration the EMA update uses beta=0
        ema_iters: int = None  # EMA update interval, in iterations
        ema_beta: float = None

        # GDF setting
        gdf_target_reparametrization: TargetReparametrization = None  # epsilon or x0 (enum member or its string value); None = GDF's own target

    @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
    class Info(WarpCore.Info):
        ema_loss: float = None  # running average of the batch loss (decay 0.99), updated every iteration

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        generator : nn.Module = EXPECTED
        generator_ema : nn.Module = None # optional; when set, train() keeps it as an EMA of generator

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        generator : any = EXPECTED

    @dataclass(frozen=True)
    class Schedulers(WarpCore.Schedulers):
        generator: any = None  # optional; None entries are not stepped

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        gdf: GDF = EXPECTED  # provides diffuse()/undiffuse() used by forward_pass
        sampling_configs: dict = EXPECTED

    # --------------------------------------------
    info: Info
    config: Config

    @abstractmethod
    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        """Encode ``batch`` into the latents that are noised by ``gdf.diffuse`` in ``forward_pass``.

        NOTE(review): forward_pass averages the loss over dims 1-3, so a 4-D,
        batch-first tensor is assumed — confirm for new subclasses.
        """
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        """Presumably the inverse of ``encode_latents`` (not called by this class)."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
        """Return a mapping of keyword arguments passed to ``models.generator`` with the noised latents.

        ``is_eval`` / ``is_unconditional`` are left at their defaults by this class.
        """
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def webdataset_path(self, extras: Extras):
        """Return the shard path(s)/URL(s) given to ``wds.WebDataset``."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def webdataset_filters(self, extras: Extras):
        """Return the predicate applied to raw samples via ``WebDataset.select``."""
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def webdataset_preprocessors(self, extras: Extras):
        """Return a sequence of ``(sample_key, transform, output_name)`` triples.

        ``sample_key`` selects a decoded field, ``transform`` is applied to it,
        and the result is stored under ``output_name`` in each batch dict.
        """
        raise NotImplementedError("This method needs to be overriden")

    @abstractmethod
    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        """Generate samples with the current models; ``train`` calls it after each checkpoint."""
        raise NotImplementedError("This method needs to be overriden")
    # -------------

    def setup_data(self, extras: Extras) -> WarpCore.Data:
        """Build the WebDataset pipeline and a DataLoader over it.

        Shards come from ``webdataset_path`` with ``resampled=True`` (shards are
        drawn with replacement, so the iterator does not run out), are filtered
        with ``webdataset_filters``, shuffled with a 690-sample buffer, decoded
        with ``"pilrgb"`` and mapped through ``webdataset_preprocessors`` into a
        dict keyed by each preprocessor's output name.

        The per-rank, per-iteration batch size is
        ``batch_size // (world_size * grad_accum_steps)``, so one optimizer
        update sees ``batch_size`` samples in total (up to integer rounding).

        Raises:
            ValueError: if that per-iteration batch size rounds down to zero.
        """
        # SETUP DATASET
        dataset_path = self.webdataset_path(extras)
        preprocessors = self.webdataset_preprocessors(extras)
        filters = self.webdataset_filters(extras)

        # Warn about and skip samples that fail in any stage instead of aborting the run.
        handler = warn_and_continue
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(filters).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})

        # SETUP DATALOADER
        real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
        if real_batch_size < 1:
            raise ValueError(
                f"batch_size={self.config.batch_size} is too small for world_size={self.world_size} "
                f"and grad_accum_steps={self.config.grad_accum_steps}: per-iteration batch size would be {real_batch_size}"
            )
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
        )

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))

    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
        """Draw the next batch, run the generator and compute the diffusion loss.

        Returns:
            ``(loss, loss_adjusted)``: ``loss`` is the unweighted per-sample MSE
            (mean over dims 1-3); ``loss_adjusted`` is its ``loss_weight``-weighted
            batch mean divided by ``grad_accum_steps``, ready for ``backward()``
            under gradient accumulation.

        Raises:
            ValueError: if ``gdf_target_reparametrization`` is set to something
                that is not a ``TargetReparametrization`` member or value.
        """
        batch = next(data.iterator)

        # Conditioning, latent encoding and noising are not differentiated through.
        with torch.no_grad():
            conditions = self.get_conditions(batch, models, extras)
            latents = self.encode_latents(batch, models, extras)
            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

        # Accept both enum members and their string values (e.g. from a YAML config).
        reparam = self.config.gdf_target_reparametrization
        if reparam is not None:
            reparam = TargetReparametrization(reparam)

        # FORWARD PASS
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            pred = models.generator(noised, noise_cond, **conditions)
            if reparam == TargetReparametrization.EPSILON:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
                target = noise
            elif reparam == TargetReparametrization.X0:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
                target = latents
            # With no reparametrization the raw prediction is compared with GDF's own target.
            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

        return loss, loss_adjusted

    def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        """Run iterations ``info.iter + 1`` .. ``updates * grad_accum_steps``.

        Gradients are accumulated over ``grad_accum_steps`` iterations; the
        intermediate backward passes run under ``models.generator.no_sync()``.
        On the last iteration of each window (or the final iteration) gradients
        are clipped to norm 1.0 and every non-``None`` optimizer and scheduler
        is stepped. The loop also updates the EMA generator, logs metrics, and
        periodically checkpoints and samples.
        """
        start_iter = self.info.iter + 1
        max_iters = self.config.updates * self.config.grad_accum_steps
        if self.is_main_node:
            print(f"STARTING AT STEP: {start_iter}/{max_iters}")

        iterations = range(start_iter, max_iters + 1)
        pbar = tqdm(iterations) if self.is_main_node else iterations # <--- DDP: progress bar on the main node only
        models.generator.train()

        # Placeholder until the first optimizer step computes a real norm; without it,
        # iterations before the first step (grad_accum_steps > 1 or resuming mid-window)
        # would reference an unbound name in the NaN check and logging below.
        grad_norm = torch.tensor(0.0)
        for i in pbar:
            # FORWARD PASS
            loss, loss_adjusted = self.forward_pass(data, extras, models)

            # BACKWARD PASS
            if i % self.config.grad_accum_steps == 0 or i == max_iters:
                loss_adjusted.backward()
                grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
                optimizers_dict = optimizers.to_dict()
                for k in optimizers_dict:
                    if optimizers_dict[k] is not None:
                        optimizers_dict[k].step()
                schedulers_dict = schedulers.to_dict()
                for k in schedulers_dict:
                    # Schedulers.generator defaults to None, i.e. no scheduler.
                    if schedulers_dict[k] is not None:
                        schedulers_dict[k].step()
                models.generator.zero_grad(set_to_none=True)
                self.info.total_steps += 1
            else:
                with models.generator.no_sync():
                    loss_adjusted.backward()
            self.info.iter = i

            # UPDATE EMA
            if models.generator_ema is not None and i % self.config.ema_iters == 0:
                # beta=0 up to ema_start_iters; presumably this copies the generator weights
                # into the EMA model (confirm in utils.update_weights_ema).
                update_weights_ema(
                    models.generator_ema, models.generator,
                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
                )

            # UPDATE LOSS METRICS
            loss_value = loss.mean().item()  # read once per iteration instead of on every use
            grad_norm_value = grad_norm.item()
            self.info.ema_loss = loss_value if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss_value * 0.01

            # Parenthesised so only the main node with wandb enabled ever calls wandb.alert.
            if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss_value) or np.isnan(grad_norm_value)):
                wandb.alert(
                    title=f"NaN value encountered in training run {self.info.wandb_run_id}", 
                    text=f"Loss {loss_value} - Grad Norm {grad_norm_value}. Run {self.info.wandb_run_id}",
                    wait_duration=60*30
                )

            if self.is_main_node:
                logs = {
                    'loss': self.info.ema_loss, 
                    'raw_loss': loss_value,
                    'grad_norm': grad_norm_value,
                    'lr': optimizers.generator.param_groups[0]['lr'],
                    'total_steps': self.info.total_steps,
                }

                pbar.set_postfix(logs)
                if self.config.wandb_project is not None:
                    wandb.log(logs)

            if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters:
                # SAVE AND CHECKPOINT STUFF
                # NOTE(review): this NaN test uses the local rank's loss only; if ranks disagree,
                # some ranks enter save_checkpoints' barrier() while others skip it, which may
                # hang. Consider reducing the NaN flag across ranks.
                if np.isnan(loss_value):
                    if self.is_main_node:
                        tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
                        if self.config.wandb_project is not None:
                            wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.info.wandb_run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
                else:
                    self.save_checkpoints(models, optimizers)
                    if self.is_main_node:
                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                    self.sample(models, data, extras)

    def models_to_save(self):
        """Names of the ``Models`` fields that ``save_checkpoints`` writes (``None`` entries are skipped)."""
        return ['generator', 'generator_ema']

    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
        """Save training info, the models named by ``models_to_save`` and every optimizer.

        Args:
            suffix: appended to each checkpoint name; ``None`` means the primary,
                unsuffixed checkpoint.

        When the primary checkpoint is saved at a positive multiple of
        ``backup_every`` optimizer updates (``total_steps > 1``), an extra copy
        is saved with a ``_<total_steps // 1000>k`` suffix.
        """
        barrier()  # all ranks enter the save together
        suffix = '' if suffix is None else suffix
        self.save_info(self.info, suffix=suffix)
        models_dict = models.to_dict()
        optimizers_dict = optimizers.to_dict()
        for key in self.models_to_save():
            model = models_dict[key]
            if model is not None:
                self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
        for key in optimizers_dict:
            optimizer = optimizers_dict[key]
            if optimizer is not None:
                # With FSDP the optimizer state is saved relative to the (sharded) generator.
                self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
        if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
        torch.cuda.empty_cache()