Dragunflie-420 committed
Commit 5560b78
1 Parent(s): 9c63ad9

Upload train.py

Files changed (1)
  1. train.py +269 -0
train.py ADDED
@@ -0,0 +1,269 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
A minimal training script for DiT using PyTorch DDP.
"""
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
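# Note: TF32 runs float32 matmuls/convolutions with reduced-precision mantissas on
# Ampere+ GPUs, trading a small amount of numerical accuracy for throughput.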
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
from collections import OrderedDict
from PIL import Image
from copy import deepcopy
from glob import glob
from time import time
import argparse
import logging
import os

from models import DiT_models
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL


#################################################################################
#                          Training Helper Functions                           #
#################################################################################

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
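    # In-place update: ema_param <- decay * ema_param + (1 - decay) * param.
    # With decay=0.9999, the EMA effectively averages over roughly the last 10k steps.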
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def cleanup():
    """
    End DDP training.
    """
    dist.destroy_process_group()


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.
    """
    if dist.get_rank() == 0:  # real logger
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger (does nothing)
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
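    # Repeatedly halve with a BOX filter while the image is at least twice the target
    # size (cheap anti-aliasing), then BICUBIC-resize so the short side equals
    # image_size, and finally take a centered image_size x image_size crop.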
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


#################################################################################
#                                Training Loop                                  #
#################################################################################

def main(args):
    """
    Trains a new DiT model.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP:
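    # One process per GPU. rank % device_count maps the global rank to a local device
    # index (exact on a single node, and on multi-node launches that assign ranks
    # contiguously per node, as torchrun does). Each rank gets a distinct seed so
    # data shuffling and noise sampling differ across processes.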
    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Setup an experiment folder:
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")  # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    # Create model:
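    # The Stable Diffusion VAE downsamples images 8x, so a 256x256 image becomes a
    # 32x32 latent; DiT trains entirely in this latent space.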
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = DiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    )
    # Note that parameter initialization is done within the DiT constructor
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    model = DDP(model.to(device), device_ids=[device])  # local device index; rank can exceed device_count on multi-node runs
    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
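    # The VAE is pretrained and never updated here; it is only used under torch.no_grad()
    # to encode images into latents. "ema"/"mse" pick between the two fine-tuned
    # sd-vae-ft checkpoints that Stability AI publishes on the Hugging Face Hub.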
    logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
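    # ToTensor maps pixels to [0, 1]; Normalize with mean=std=0.5 rescales them to
    # [-1, 1], the input range the Stable Diffusion VAE expects.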
    dataset = ImageFolder(args.data_path, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // dist.get_world_size()),
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
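    # ImageFolder expects one subdirectory per class under data_path (e.g. the standard
    # ImageNet train/ layout). DistributedSampler gives each rank a disjoint shard, so
    # the per-process batch size above is global_batch_size / world_size.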

    # Prepare models for training:
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                # Map input images to latent space + normalize latents:
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
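            # 0.18215 is the standard Stable Diffusion latent scale factor, bringing the
            # VAE latents to roughly unit variance before diffusion is applied.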
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
            model_kwargs = dict(y=y)
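            # training_losses comes from the ADM-style diffusion codebase this script builds
            # on: it noises x to timestep t and returns a dict with a per-sample "loss"
            # (noise-prediction MSE, plus a variational term when sigma is learned).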
            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            update_ema(ema, model.module)

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

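            # Only rank 0 writes checkpoints; the barrier below keeps the other ranks from
            # racing ahead while the file is saved. The stored EMA weights are the ones
            # the DiT paper uses for sampling and FID evaluation.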
            # Save DiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                dist.barrier()

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    cleanup()


if __name__ == "__main__":
    # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--global-batch-size", type=int, default=256)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")  # Choice doesn't affect training
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--log-every", type=int, default=100)
    parser.add_argument("--ckpt-every", type=int, default=50_000)
    args = parser.parse_args()
    main(args)
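# Example launch, assuming a single node with N GPUs (adjust paths to your setup):
#   torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train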