from tqdm import tqdm

from roma.utils.utils import to_cuda

import roma
import torch
import wandb


def log_param_statistics(named_parameters, norm_type=2):
    """Log total gradient and parameter norms to wandb and flag NaN/Inf gradients."""
    named_parameters = list(named_parameters)
    # Only consider parameters that actually received a gradient this step.
    grads = [p.grad for n, p in named_parameters if p.grad is not None]
    weight_norms = [
        p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None
    ]
    names = [n for n, p in named_parameters if p.grad is not None]
    param_norm = torch.stack(weight_norms).norm(p=norm_type)
    device = grads[0].device
    grad_norms = torch.stack(
        [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    )
    nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
    nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
    total_grad_norm = torch.norm(grad_norms, norm_type)
    if torch.any(nans_or_infs):
        print(f"These params have nan or inf grads: {nan_inf_names}")
    wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
    wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)


def train_step(
    train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
):
    """Run one AMP training step: forward, backward, gradient clipping, optimizer step.

    Extra keyword arguments (e.g. ``lr_scheduler``, ``n``) are accepted via
    ``**kwargs`` and ignored here; the learning-rate schedule is stepped by the caller.
    """
    optimizer.zero_grad()
    out = model(train_batch)
    l = objective(out, train_batch)
    grad_scaler.scale(l).backward()
    # Unscale before logging/clipping so the reported and clipped norms are the true ones.
    grad_scaler.unscale_(optimizer)
    log_param_statistics(model.named_parameters())
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
    grad_scaler.step(optimizer)
    grad_scaler.update()
    # Note: ._scale is a private GradScaler attribute; it is read here for logging
    # and reset so the loss scale never drops below 1.
    wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
    if grad_scaler._scale < 1.0:
        grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale)
    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE
    return {"train_out": out, "train_loss": l.item()}
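

# Illustrative usage sketch (not part of the original module); see the docstring
# for the assumptions it makes about the model/objective interfaces.
def _example_train_step_usage():
    """Minimal sketch of how ``train_step`` expects to be wired.

    Everything below (toy model, toy objective, random batch) is an assumption
    made for illustration, not RoMa's actual setup. A CUDA device is assumed,
    and wandb is initialised in disabled mode because ``train_step`` and
    ``log_param_statistics`` call ``wandb.log``.
    """

    class ToyModel(torch.nn.Module):
        # ``forward`` receives the whole batch dict, since train_step calls model(train_batch).
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(8, 1)

        def forward(self, batch):
            return self.net(batch["x"])

    def toy_objective(out, batch):
        # Placeholder objective with the expected (out, batch) signature.
        return torch.nn.functional.mse_loss(out, batch["y"])

    wandb.init(mode="disabled")  # no-op logging, keeps the wandb.log calls happy
    model = ToyModel().cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    grad_scaler = torch.cuda.amp.GradScaler()
    batch = {
        "x": torch.randn(4, 8, device="cuda"),
        "y": torch.randn(4, 1, device="cuda"),
    }
    stats = train_step(batch, model, toy_objective, optimizer, grad_scaler)
    print(stats["train_loss"])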


def train_k_steps(
    n_0,
    k,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler,
    progress_bar=True,
    grad_clip_norm=1.0,
    warmup=None,
    ema_model=None,
):
    """Run ``k`` training steps starting at global step ``n_0``.

    ``dataloader`` must be an iterator (batches are drawn with ``next``). The
    learning-rate scheduler is stepped once per iteration, optionally dampened
    by a ``warmup`` scheduler, and the per-group learning rates are logged to wandb.
    """
    for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
        batch = next(dataloader)
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            n=n,
            grad_clip_norm=grad_clip_norm,
        )
        if ema_model is not None:
            ema_model.update()
        if warmup is not None:
            with warmup.dampening():
                lr_scheduler.step()
        else:
            lr_scheduler.step()
        for grp, lr in enumerate(lr_scheduler.get_last_lr()):
            wandb.log({f"lr_group_{grp}": lr})
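

# Illustrative driver sketch (not part of the original module); the optimizer,
# scheduler and step counts below are placeholder choices, not RoMa's defaults.
def _example_train_k_steps_usage(model, objective, dataloader, rounds=10, k=1000):
    """Minimal sketch of an outer loop around ``train_k_steps``.

    ``model``, ``objective`` and ``dataloader`` (an *iterator* of batch dicts that
    ``to_cuda`` can move to the GPU) are assumed to be provided by the caller;
    running ``k`` steps per round leaves room for checkpointing or evaluation
    between rounds.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=rounds * k)
    grad_scaler = torch.cuda.amp.GradScaler()
    for r in range(rounds):
        train_k_steps(
            r * k,
            k,
            dataloader,
            model,
            objective,
            optimizer,
            lr_scheduler,
            grad_scaler,
            warmup=None,  # or a warmup scheduler exposing a .dampening() context manager
        )
        # Checkpointing / evaluation between rounds would go here.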


def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    grad_scaler=None,
    epoch=None,
):
    """Run one full pass over ``dataloader``.

    ``grad_scaler`` is forwarded to ``train_step`` and must be a
    ``torch.cuda.amp.GradScaler`` instance; ``lr_scheduler`` is stepped once
    per epoch.
    """
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            grad_scaler=grad_scaler,
        )
    lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }


def train_k_epochs(
    start_epoch,
    end_epoch,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler=None,
):
    """Train from ``start_epoch`` through ``end_epoch`` (inclusive), one ``train_epoch`` per epoch."""
    for epoch in range(start_epoch, end_epoch + 1):
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            epoch=epoch,
        )