# Pyramid-Flow / trainer_misc/fsdp_trainer.py
import math
import sys
from typing import Iterable
import torch
import torch.nn as nn
import accelerate
from .utils import MetricLogger, SmoothedValue


def update_ema_for_dit(model, model_ema, accelerator, decay):
"""Apply exponential moving average update.
The weights are updated in-place as follow:
w_ema = w_ema * decay + (1 - decay) * w
Args:
model: active model that is being optimized
model_ema: running average model
decay: exponential decay parameter
"""
with torch.no_grad():
msd = accelerator.get_state_dict(model)
for k, ema_v in model_ema.state_dict().items():
if k in msd:
model_v = msd[k].detach().to(ema_v.device, dtype=ema_v.dtype)
ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
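

# A minimal sketch of the EMA rule above, kept as a comment so nothing runs at
# import time; the tensors and the 0.9 decay are made-up values for illustration:
#
#     ema_v = torch.tensor(1.0)      # current EMA weight
#     model_v = torch.tensor(0.0)    # current live weight
#     ema_v.copy_(ema_v * 0.9 + (1.0 - 0.9) * model_v)
#     # ema_v is now 0.9, i.e. the EMA copy drifts slowly toward the live model

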
def get_decay(optimization_step: int, ema_decay: float) -> float:
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - 1)
if step <= 0:
return 0.0
cur_decay_value = (1 + step) / (10 + step)
cur_decay_value = min(cur_decay_value, ema_decay)
cur_decay_value = max(cur_decay_value, 0.0)
return cur_decay_value
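

# Illustrative values of the warm-up schedule above, assuming ema_decay = 0.9999
# (the training loop below currently bypasses this helper and uses a fixed
# ema_decay instead):
#
#     get_decay(1, 0.9999)     -> 0.0                (no EMA contribution yet)
#     get_decay(2, 0.9999)     -> 2 / 11    ~= 0.182
#     get_decay(101, 0.9999)   -> 101 / 110 ~= 0.918
#     get_decay(10001, 0.9999) -> ~= 0.9991          (still below the 0.9999 cap)

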
def train_one_epoch_with_fsdp(
runner,
model_ema: torch.nn.Module,
accelerator: accelerate.Accelerator,
model_dtype: str,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
lr_schedule_values,
device: torch.device,
epoch: int,
clip_grad: float = 1.0,
start_steps=None,
args=None,
print_freq=20,
iters_per_epoch=2000,
ema_decay=0.9999,
use_temporal_pyramid=True,
):
runner.dit.train()
metric_logger = MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('min_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
train_loss = 0.0
print("Start training epoch {}, {} iters per inner epoch. Training dtype {}".format(epoch, iters_per_epoch, model_dtype))
for step in metric_logger.log_every(range(iters_per_epoch), print_freq, header):
if step >= iters_per_epoch:
break
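        # Assign the learning rate for the current global step from the precomputed
        # schedule, scaled by each param group's optional lr_scale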
if lr_schedule_values is not None:
for i, param_group in enumerate(optimizer.param_groups):
param_group["lr"] = lr_schedule_values[start_steps] * param_group.get("lr_scale", 1.0)
for _ in range(args.gradient_accumulation_steps):
with accelerator.accumulate(runner.dit):
                # Fetch the next data sample and move the inputs to the accelerator device
samples = next(data_loader)
video = samples['video'].to(accelerator.device)
text = samples['text']
identifier = samples['identifier']
                # Run the forward pass through the runner under accelerate
loss, log_loss = runner(video, text, identifier,
use_temporal_pyramid=use_temporal_pyramid, accelerator=accelerator)
                # Abort training if the loss is NaN or Inf
                loss_value = loss.item()
                if not math.isfinite(loss_value):
                    accelerator.print("Loss is {}, stopping training".format(loss_value))
                    sys.exit(1)
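                # Gather the loss across all processes so the logged value reflects
                # the global average for this accumulation step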
avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
accelerator.backward(loss)
# clip the gradient
if accelerator.sync_gradients:
params_to_clip = runner.dit.parameters()
grad_norm = accelerator.clip_grad_norm_(params_to_clip, clip_grad)
                # Skip the optimizer step for abnormal data points that produce an extremely high loss
                if train_loss >= 2.0:
                    accelerator.print(f"Found an abnormal data sample with extremely high loss {train_loss}, skipping the parameter update")
                    # Zero out the gradients and do not update the weights
                    optimizer.zero_grad()
                    train_loss = 0.001  # placeholder loss value for logging
else:
optimizer.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
            # Update the EMA model every 100 optimization steps
if model_ema is not None and start_steps % 100 == 0:
# cur_ema_decay = get_decay(start_steps, ema_decay)
cur_ema_decay = ema_decay
update_ema_for_dit(runner.dit, model_ema, accelerator, decay=cur_ema_decay)
start_steps += 1
# Report to tensorboard
accelerator.log({"train_loss": train_loss}, step=start_steps)
metric_logger.update(loss=train_loss)
train_loss = 0.0
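        # Track the min/max learning rate across param groups for logging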
min_lr = 10.
max_lr = 0.
for group in optimizer.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])
metric_logger.update(lr=max_lr)
metric_logger.update(min_lr=min_lr)
weight_decay_value = None
for group in optimizer.param_groups:
if group["weight_decay"] > 0:
weight_decay_value = group["weight_decay"]
metric_logger.update(weight_decay=weight_decay_value)
metric_logger.update(grad_norm=grad_norm)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
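

# A rough usage sketch of the trainer above (illustrative only; `runner`,
# `model_ema`, `train_loader`, `optimizer`, `lr_values`, and `args` are
# placeholders that the surrounding Pyramid-Flow training script is expected to
# construct). Because the loop calls next() on data_loader, an iterator is
# passed rather than the DataLoader itself:
#
#     accelerator = accelerate.Accelerator(
#         gradient_accumulation_steps=args.gradient_accumulation_steps)
#     train_iter = iter(train_loader)
#     for epoch in range(args.epochs):
#         stats = train_one_epoch_with_fsdp(
#             runner, model_ema, accelerator, "bf16", train_iter, optimizer,
#             lr_values, accelerator.device, epoch,
#             clip_grad=1.0, start_steps=epoch * 2000, args=args,
#             iters_per_epoch=2000, ema_decay=0.9999, use_temporal_pyramid=True,
#         )
#         accelerator.print("Epoch {} stats: {}".format(epoch, stats))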