Spaces:
Running
Running
File size: 1,742 Bytes
dbf8b7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from tqdm import tqdm
from dkm.utils.utils import to_cuda
def train_step(train_batch, model, objective, optimizer, **kwargs):
    """Run one optimisation step on a single batch.

    The gradients are cleared, the model is evaluated on the batch, the
    objective is back-propagated and the optimizer takes one step.

    Args:
        train_batch: Batch passed both to ``model`` and to ``objective``.
        model: Callable producing the prediction from ``train_batch``.
        objective: Callable ``objective(prediction, train_batch)`` returning
            a loss tensor that supports ``.backward()`` and ``.item()``.
        optimizer: Optimizer with ``zero_grad()`` and ``step()``.
        **kwargs: Extra keyword arguments (callers pass e.g. ``lr_scheduler``
            and ``n``). They are accepted so those call sites work, and are
            otherwise ignored.

    Returns:
        dict: ``"train_out"`` holds the raw model output, and ``"train_loss"``
        holds the loss converted with ``.item()`` (a Python scalar).
    """
    optimizer.zero_grad()
    prediction = model(train_batch)
    loss = objective(prediction, train_batch)
    loss.backward()
    optimizer.step()
    loss_value = loss.item()
    return dict(train_out=prediction, train_loss=loss_value)
def train_k_steps(
    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True
):
    """Train for ``k`` consecutive steps, pulling one batch per step.

    Steps are numbered ``n_0, n_0 + 1, ..., n_0 + k - 1``. Each number is
    forwarded to ``train_step`` as ``n``, and ``train_step`` ignores it.

    Args:
        n_0: Index of the first step.
        k: Number of steps to run.
        dataloader: Advanced with ``next()``, so it must be an iterator (not
            merely an iterable). NOTE(review): presumably an infinite/cycled
            loader supplied by the caller — confirm.
        model: Model; put into training mode before every step.
        objective: Loss callable forwarded to ``train_step``.
        optimizer: Optimizer forwarded to ``train_step``.
        lr_scheduler: Stepped once after every training step.
        progress_bar: When falsy, the tqdm progress bar is disabled.
    """
    for step in tqdm(range(n_0, n_0 + k), disable=not progress_bar):
        raw_batch = next(dataloader)
        # Training mode is (re)set on every iteration, before the batch is moved.
        model.train(True)
        # to_cuda comes from dkm.utils.utils; presumably moves the batch to
        # the GPU — see that helper for the exact behavior.
        train_step(
            train_batch=to_cuda(raw_batch),
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            n=step,
        )
        lr_scheduler.step()
def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    epoch=None,
):
    """Run one full pass over ``dataloader``, taking one optimisation step per batch.

    Args:
        dataloader: Iterable of batches, wrapped in a tqdm progress bar that
            refreshes at most every 5 seconds.
        model: Model; switched to training mode once before the loop.
        objective: Loss callable forwarded to ``train_step``.
        optimizer: Optimizer forwarded to ``train_step``.
        lr_scheduler: Optional. When given, ``step()`` is called after every
            batch (per-batch, not per-epoch, scheduling). When ``None``, no
            scheduling is done. Previously, ``None`` crashed on the first
            batch despite being the default.
        epoch: Epoch label, used only for the printed message and the
            returned dict.

    Returns:
        dict: The same ``model``, ``optimizer``, ``lr_scheduler`` and ``epoch``
        objects that were passed in, under keys of the same names.
    """
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        # to_cuda comes from dkm.utils.utils; presumably moves the batch to
        # the GPU — see that helper for the exact behavior.
        batch = to_cuda(batch)
        train_step(
            train_batch=batch, model=model, objective=objective, optimizer=optimizer
        )
        # Honour the optional default: only step a scheduler that exists.
        if lr_scheduler is not None:
            lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }
def train_k_epochs(
    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
):
    """Call ``train_epoch`` for each epoch in ``start_epoch .. end_epoch`` (inclusive).

    The training components are passed unchanged to every epoch. The dicts
    returned by ``train_epoch`` are discarded, and this function returns ``None``.
    """
    shared = dict(
        dataloader=dataloader,
        model=model,
        objective=objective,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )
    # `end_epoch` is inclusive, hence the +1 on the exclusive range bound.
    stop = end_epoch + 1
    for current in range(start_epoch, stop):
        train_epoch(epoch=current, **shared)
|