import torch
from tqdm import tqdm

from DeDoDe.utils import to_cuda


def train_step(train_batch, model, objective, optimizer, grad_scaler=None, **kwargs):
    optimizer.zero_grad()
    out = model(train_batch)
    l = objective(out, train_batch)
    if grad_scaler is not None:
        # Mixed-precision path: backprop the scaled loss, unscale before
        # gradient clipping, then let the scaler drive the optimizer step.
        grad_scaler.scale(l).backward()
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
        grad_scaler.step(optimizer)
        grad_scaler.update()
    else:
        l.backward()
        optimizer.step()
    return {"train_out": out, "train_loss": l.item()}


def train_k_steps(
    n_0,
    k,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler=None,
    progress_bar=True,
):
    for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar, mininterval=10.0):
        batch = next(dataloader)
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,  # unused here; absorbed by **kwargs in train_step
            n=n,  # unused here; absorbed by **kwargs in train_step
            grad_scaler=grad_scaler,
        )
        lr_scheduler.step()
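

# Note (illustrative, not part of the original file): train_k_steps pulls
# batches with next(), so hand it an iterator, e.g. one that cycles a torch
# DataLoader endlessly:
#
#     def endless(loader):
#         while True:
#             yield from loader
#
#     train_k_steps(0, 10_000, endless(loader), model, objective, optimizer, lr_scheduler)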


def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    epoch=None,
):
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        batch = to_cuda(batch)
        train_step(
            train_batch=batch, model=model, objective=objective, optimizer=optimizer
        )
    # Step the scheduler once per epoch (train_k_steps steps it per iteration).
    lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }


def train_k_epochs(
    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
):
    for epoch in range(start_epoch, end_epoch + 1):
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            epoch=epoch,
        )
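

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. ToyModel, the
    # mse objective, and the synthetic batches are illustrative stand-ins, and
    # a CUDA device is assumed because to_cuda moves batches to the GPU.
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(8, 1)

        def forward(self, batch):
            # The training loops pass the whole batch dict to the model.
            return self.net(batch["x"])

    def objective(out, batch):
        return torch.nn.functional.mse_loss(out, batch["y"])

    def infinite_batches():
        # train_k_steps pulls batches with next(), so yield forever.
        while True:
            yield {"x": torch.randn(4, 8), "y": torch.randn(4, 1)}

    model = ToyModel().cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    train_k_steps(
        n_0=0,
        k=100,
        dataloader=infinite_batches(),
        model=model,
        objective=objective,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        grad_scaler=torch.cuda.amp.GradScaler(),
    )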