# mColBERT: colbert/training/training.py
import os
import random
import time
import torch
import torch.nn as nn
import numpy as np
from transformers import AdamW
from colbert.utils.runs import Run
from colbert.utils.amp import MixedPrecisionManager
from colbert.training.lazy_batcher import LazyBatcher
from colbert.training.eager_batcher import EagerBatcher
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints


def train(args):
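    # Fix all RNG seeds for reproducibility; in distributed mode, seed every CUDA device as well.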
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    if args.distributed:
        torch.cuda.manual_seed_all(12345)
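
    # Distributed runs split the global batch size evenly across processes and
    # do not support gradient accumulation (accumsteps must be 1).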
    if args.distributed:
        assert args.bsize % args.nranks == 0, (args.bsize, args.nranks)
        assert args.accumsteps == 1
        args.bsize = args.bsize // args.nranks
        print("Using args.bsize =", args.bsize, "(per process) and args.accumsteps =", args.accumsteps)
    if args.lazy:
        reader = LazyBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)
    else:
        reader = EagerBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)
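
    # Non-zero ranks wait here so that rank 0 downloads/caches the pretrained weights first.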
    if args.rank not in [-1, 0]:
        torch.distributed.barrier()
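
    # mColBERT: build ColBERT on top of multilingual BERT.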
    colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
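
    # Optionally warm-start the model weights from an earlier checkpoint (optimizer state is NOT restored).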
    if args.checkpoint is not None:
        assert args.resume_optimizer is False, "TODO: This would mean reload optimizer too."
        print_message(f"#> Starting from checkpoint {args.checkpoint} -- but NOT the optimizer!")

        checkpoint = torch.load(args.checkpoint, map_location='cpu')

        try:
            colbert.load_state_dict(checkpoint['model_state_dict'])
        except:
            print_message("[WARNING] Loading checkpoint with strict=False")
            colbert.load_state_dict(checkpoint['model_state_dict'], strict=False)
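
    # Rank 0 reaches this barrier last, releasing the ranks that waited above.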
    if args.rank == 0:
        torch.distributed.barrier()

    colbert = colbert.to(DEVICE)
    colbert.train()
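
    # Wrap the model so gradients are synchronized across processes.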
    if args.distributed:
        colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[args.rank],
                                                            output_device=args.rank,
                                                            find_unused_parameters=True)

    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=args.lr, eps=1e-8)
    optimizer.zero_grad()

    amp = MixedPrecisionManager(args.amp)
    criterion = nn.CrossEntropyLoss()
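    # Each row of `scores` holds [positive_score, negative_score], so the target class index is always 0.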
    labels = torch.zeros(args.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = 0.0

    start_batch_idx = 0
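
    # When resuming, fast-forward the reader to the batch index recorded in the checkpoint.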
    if args.resume:
        assert args.checkpoint is not None
        start_batch_idx = checkpoint['batch']

        reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])
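
    # Main training loop: the reader yields each batch as a sequence of sub-batches, one per accumulation step.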
    for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps), reader):
        this_batch_loss = 0.0

        for queries, passages in BatchSteps:
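            # Score each query against its positive and negative passage; view(2, -1).permute(1, 0)
            # yields an [n, 2] matrix with positive scores in column 0.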
            with amp.context():
                scores = colbert(queries, passages).view(2, -1).permute(1, 0)
                loss = criterion(scores, labels[:scores.size(0)])
                loss = loss / args.accumsteps

            if args.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            train_loss += loss.item()
            this_batch_loss += loss.item()
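
        # All sub-batches processed: the optimizer update (and gradient reset) is delegated to the AMP manager.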
        amp.step(colbert, optimizer)
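
        # Only rank 0 (or a single-process run) logs metrics and manages checkpoints;
        # MLflow logging is throttled to every 20 batches.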
        if args.rank < 1:
            avg_loss = train_loss / (batch_idx+1)

            num_examples_seen = (batch_idx - start_batch_idx) * args.bsize * args.nranks
            elapsed = float(time.time() - start_time)

            log_to_mlflow = (batch_idx % 20 == 0)
            Run.log_metric('train/avg_loss', avg_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/batch_loss', this_batch_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/examples', num_examples_seen, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/throughput', num_examples_seen / elapsed, step=batch_idx, log_to_mlflow=log_to_mlflow)

            print_message(batch_idx, avg_loss)
            manage_checkpoints(args, colbert, optimizer, batch_idx+1)