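"""Training-loop utilities: single AMP training steps, k-step loops, and
epoch-based loops, with gradient/parameter-norm logging to wandb.

Global step bookkeeping goes through ``roma.GLOBAL_STEP`` and ``roma.STEP_SIZE``,
which are assumed to be initialised by the experiment entry point.
"""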
from tqdm import tqdm
from roma.utils.utils import to_cuda
import roma
import torch
import wandb


def log_param_statistics(named_parameters, norm_type=2):
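    """Log the total parameter norm and total gradient norm to wandb at the
    current ``roma.GLOBAL_STEP``, and print the names of any parameters whose
    gradients contain NaNs or infs. Only parameters with a non-None ``.grad``
    are considered.
    """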
    named_parameters = [(n, p) for n, p in named_parameters if p.grad is not None]
    names = [n for n, _ in named_parameters]
    grads = [p.grad for _, p in named_parameters]
    weight_norms = [p.norm(p=norm_type) for _, p in named_parameters]
    param_norm = torch.stack(weight_norms).norm(p=norm_type)
    device = grads[0].device
    grad_norms = torch.stack(
        [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    )
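    # Flag parameters whose gradient norm is NaN or inf (e.g. after an AMP overflow).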
    nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
    nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
    total_grad_norm = torch.norm(grad_norms, norm_type)
    if torch.any(nans_or_infs):
        print(f"These params have nan or inf grads: {nan_inf_names}")
    wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
    wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)


def train_step(
    train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
):
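    """Run one optimisation step on ``train_batch`` using AMP gradient scaling.

    Extra keyword arguments (e.g. ``lr_scheduler`` and ``n`` passed by
    ``train_k_steps``) are accepted and ignored via ``**kwargs``.
    """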
    optimizer.zero_grad()
    out = model(train_batch)
    loss = objective(out, train_batch)
    grad_scaler.scale(loss).backward()
    grad_scaler.unscale_(optimizer)
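    # Gradients are unscaled above, so the logged and clipped norms are the true values.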
    log_param_statistics(model.named_parameters())
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), grad_clip_norm
    )  # clip the (unscaled) gradient norm to grad_clip_norm
    grad_scaler.step(optimizer)
    grad_scaler.update()
    wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
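    # Keep the AMP loss scale from dropping below 1 (uses GradScaler's private _scale).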
    if grad_scaler._scale < 1.0:
        grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale)
    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE  # increment global step
    return {"train_out": out, "train_loss": loss.item()}


def train_k_steps(
    n_0,
    k,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler,
    progress_bar=True,
    grad_clip_norm=1.0,
    warmup=None,
    ema_model=None,
):
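    """Run ``k`` training steps starting from global step ``n_0``.

    ``dataloader`` must be an iterator (``next`` is called on it directly).
    The EMA model and the LR scheduler (optionally dampened by ``warmup``) are
    updated after every step; the tqdm progress bar is shown only on rank 0
    and only if ``progress_bar`` is True.
    """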
    for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
        batch = next(dataloader)
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            n=n,
            grad_clip_norm=grad_clip_norm,
        )
        if ema_model is not None:
            ema_model.update()
        if warmup is not None:
            with warmup.dampening():
                lr_scheduler.step()
        else:
            lr_scheduler.step()
        for grp, lr in enumerate(lr_scheduler.get_last_lr()):
            wandb.log({f"lr_group_{grp}": lr})


def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    grad_scaler=None,
    epoch=None,
):
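    """Train for one full pass over ``dataloader``.

    An AMP ``grad_scaler`` must be provided (it is forwarded to ``train_step``),
    and the LR scheduler is stepped once at the end of the epoch.
    """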
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            grad_scaler=grad_scaler,
        )
    lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }


def train_k_epochs(
    start_epoch,
    end_epoch,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler,
):
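    """Train from ``start_epoch`` to ``end_epoch`` (inclusive), one epoch at a time."""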
    for epoch in range(start_epoch, end_epoch + 1):
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            epoch=epoch,
        )