# Web file-viewer residue (not program code): Realcat | fix: eloftr | ac67bb3 | raw | history blame | 17.1 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> trainer
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 15:04
=================================================='''
import datetime
import os
import os.path as osp
import numpy as np
from pathlib import Path
from tensorboardX import SummaryWriter
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
import shutil
import torch
from torch.autograd import Variable
from tools.common import save_args_yaml, merge_tags
from tools.metrics import compute_iou, compute_precision, SeqIOU, compute_corr_incorr, compute_seg_loss_weight
from tools.metrics import compute_cls_loss_ce, compute_cls_corr
class Trainer:
    """Drive training, evaluation and checkpointing of a segmentation model.

    The trainer owns the AdamW optimizer, an exponentially decayed learning
    rate, text / TensorBoard logging (rank 0 only) and per-epoch checkpoints.

    Config keys read by this class: 'with_aug', 'lr', 'min_lr', 'epochs',
    'resume_path', 'save_path', 'dataset', 'network', 'layers', 'feature',
    'batch_size', 'max_keypoints', 'output_dim', 'n_class',
    'use_mid_feature', 'cluster_method', 'local_rank', 'do_eval',
    'norm_desc', 'its_per_epoch', 'decay_rate', 'decay_iter',
    'log_intervals', 'with_dist', 'eval_n_epoch' and 'gpu'.
    """

    def __init__(self, model, train_loader, feat_model=None, eval_loader=None, config=None, img_transforms=None):
        """Set up optimizer, counters, output directory and loggers.

        Args:
            model: network to optimise; only parameters with
                ``requires_grad`` set are handed to the optimizer.
            train_loader: iterable of training batches (dicts).
            feat_model: optional feature extractor; moved to CUDA and put in
                eval mode. It is used whenever ``config['with_aug']`` is set.
            eval_loader: loader passed to :meth:`eval_seg` during training.
            config: dict-like configuration (see the class docstring).
            img_transforms: optional callable applied to each image tensor
                before it is given to ``feat_model``.

        Note:
            When ``config['resume_path']`` is set, only the epoch, iteration
            and best-value counters are restored here; this constructor does
            not load model weights from the checkpoint.
        """
        self.model = model
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.config = config
        self.with_aug = self.config['with_aug']
        self.with_cls = False  # self.config['with_cls']
        self.with_sc = False  # self.config['with_sc']
        self.img_transforms = img_transforms
        self.feat_model = feat_model.cuda().eval() if feat_model is not None else None

        self.init_lr = self.config['lr']
        self.min_lr = self.config['min_lr']

        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.AdamW(params=params, lr=self.init_lr)
        self.num_epochs = self.config['epochs']

        if config['resume_path'] is not None:
            # Reuse the run directory the checkpoint lives in
            # (resume_path looks like '<run_dir>/<checkpoint file>').
            log_dir = config['resume_path'].split('/')[-2]
            resume_log = torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')
            self.epoch = resume_log['epoch'] + 1
            if 'iteration' in resume_log:
                self.iteration = resume_log['iteration']
            else:
                # Checkpoints without an iteration counter: estimate it.
                self.iteration = len(self.train_loader) * self.epoch
            self.min_loss = resume_log['min_loss']
        else:
            self.iteration = 0
            self.epoch = 0
            self.min_loss = 1e10

            # Fresh run: build a descriptive, timestamped directory name.
            now = datetime.datetime.now()
            all_tags = [now.strftime("%Y%m%d_%H%M%S")]
            dataset_name = merge_tags(self.config['dataset'], '')
            all_tags = all_tags + [self.config['network'], 'L' + str(self.config['layers']),
                                   dataset_name,
                                   str(self.config['feature']), 'B' + str(self.config['batch_size']),
                                   'K' + str(self.config['max_keypoints']), 'od' + str(self.config['output_dim']),
                                   'nc' + str(self.config['n_class'])]
            if self.config['use_mid_feature']:
                all_tags.append('md')
            if self.with_aug:
                all_tags.append('A')
            all_tags.append(self.config['cluster_method'])
            log_dir = merge_tags(tags=all_tags, connection='_')

        self.tag = log_dir

        # Only the main process creates files and loggers.
        if config['local_rank'] == 0:
            self.save_dir = osp.join(self.config['save_path'], log_dir)
            os.makedirs(self.save_dir, exist_ok=True)
            print("save_dir: ", self.save_dir)

            self.log_file = open(osp.join(self.save_dir, "log.txt"), "a+")
            save_args_yaml(args=config, save_path=Path(self.save_dir, "args.yaml"))
            self.writer = SummaryWriter(self.save_dir)

        self.do_eval = self.config['do_eval']
        if self.do_eval:
            self.eval_fun = None
            self.seq_metric = SeqIOU(n_class=self.config['n_class'], ignored_sids=[0])

    @staticmethod
    def _mean_or_nan(values):
        """Return the mean of ``values`` as a float, or NaN when empty.

        Avoids NumPy's "Mean of empty slice" RuntimeWarning.
        """
        if len(values) == 0:
            return float('nan')
        return float(np.mean(values))

    @staticmethod
    def _selection_value(train_loss, eval_precision=None):
        """Return the model-selection value for an epoch (lower is better).

        A finite evaluation precision takes priority and is negated so that
        a higher precision yields a lower value; otherwise the training loss
        is used directly.
        """
        if eval_precision is not None and np.isfinite(eval_precision):
            return -float(eval_precision)
        return float(train_loss)

    @staticmethod
    def _to_cuda_float(value):
        """Convert a batch entry to a float32 CUDA tensor.

        Tensors are cast and moved; NumPy arrays get a leading dimension
        added; anything else is treated as a sequence of tensors and stacked.
        """
        if isinstance(value, torch.Tensor):
            return value.float().cuda()
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value).float()[None].cuda()
        return torch.stack(value).float().cuda()

    def _move_batch_to_gpu(self, pred):
        """Move every numeric entry of ``pred`` to the GPU, in place.

        Entries whose key contains 'name', and the raw 'image' / 'depth'
        entries, are left untouched.
        """
        for key in list(pred):
            if key.find('name') >= 0 or key in ('image', 'depth'):
                continue
            pred[key] = self._to_cuda_float(pred[key])

    def _scheduled_lr(self):
        """Return the learning rate for the current ``self.iteration``.

        The rate is ``lr * decay_rate ** (iteration - decay_iter)``, capped
        above by ``lr`` and below by ``self.min_lr``.
        """
        base_lr = self.config['lr']
        steps = self.iteration - self.config['decay_iter']
        try:
            lr = base_lr * self.config['decay_rate'] ** steps
        except OverflowError:
            # The factor is astronomically large (far before decay_iter);
            # the upper cap would select the base rate anyway.
            lr = base_lr
        return max(min(lr, base_lr), self.min_lr)

    def preprocess_input(self, pred):
        """Prepare a training batch in place.

        Moves numeric entries to the GPU and, when augmentation is enabled,
        recomputes 'global_descriptors', 'scores' and 'seg_descriptors' from
        the raw images with ``self.feat_model``.
        """
        self._move_batch_to_gpu(pred)

        if not self.with_aug:
            return

        new_scores = []
        new_descs = []
        global_descs = []
        with torch.no_grad():
            for i, im in enumerate(pred['image']):
                # assumes im[0] is an HxWxC array -- TODO confirm with the dataset
                img = torch.from_numpy(im[0]).cuda().float().permute(2, 0, 1)
                if self.img_transforms is not None:
                    img = self.img_transforms(img)[None]
                else:
                    img = img[None]

                out = self.feat_model.extract_local_global(data={'image': img})
                global_descs.append(out['global_descriptors'])

                semi_descs = out['mid_features'] if self.config['use_mid_feature'] else out['desc_map']
                seg_scores, seg_descs = self.feat_model.sample(score_map=out['score_map'],
                                                               semi_descs=semi_descs,
                                                               kpts=pred['keypoints'][i],
                                                               norm_desc=self.config['norm_desc'])  # [D, N]
                new_scores.append(seg_scores[None])
                new_descs.append(seg_descs[None])

        pred['global_descriptors'] = global_descs
        pred['scores'] = torch.cat(new_scores, dim=0)
        pred['seg_descriptors'] = torch.cat(new_descs, dim=0).permute(0, 2, 1)  # -> [B, N, D]

    def process_epoch(self):
        """Run one training epoch and return the mean total loss.

        Returns:
            float: mean of the per-iteration losses (NaN if no iteration ran,
            e.g. when ``config['its_per_epoch']`` is 0).
        """
        self.model.train()

        epoch_cls_losses = []
        epoch_seg_losses = []
        epoch_losses = []
        epoch_acc_corr = []
        epoch_acc_incorr = []
        epoch_cls_acc = []
        epoch_sc_losses = []

        is_main = self.config['local_rank'] == 0

        for bidx, pred in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
            # A non-negative its_per_epoch caps the iterations per epoch.
            # Check before preprocessing so no work is wasted on the batch
            # that triggers the stop.
            if 0 <= self.config['its_per_epoch'] <= bidx:
                break

            self.preprocess_input(pred)

            data = self.model(pred)
            pred = {**pred, **data}

            seg_loss = compute_seg_loss_weight(pred=pred['prediction'],
                                               target=pred['gt_seg'],
                                               background_id=0,
                                               weight_background=0.1)
            acc_corr, acc_incorr = compute_corr_incorr(pred=pred['prediction'],
                                                       target=pred['gt_seg'],
                                                       ignored_ids=[0])

            if self.with_cls:
                pred_cls_dist = pred['classification']
                gt_cls_dist = pred['gt_cls_dist']
                if len(pred_cls_dist.shape) > 2:
                    gt_cls_dist_full = gt_cls_dist.unsqueeze(-1).repeat(1, 1, pred_cls_dist.shape[-1])
                else:
                    gt_cls_dist_full = gt_cls_dist.unsqueeze(-1)
                cls_loss = compute_cls_loss_ce(pred=pred_cls_dist, target=gt_cls_dist_full)
                loss = seg_loss + cls_loss
                cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist)
            else:
                loss = seg_loss
                cls_loss = torch.zeros_like(seg_loss)
                cls_acc = torch.zeros_like(seg_loss)

            # No SC loss is implemented; always report zero so the logging
            # below never hits an undefined variable.
            sc_loss = torch.zeros_like(seg_loss)

            epoch_losses.append(loss.item())
            epoch_seg_losses.append(seg_loss.item())
            epoch_cls_losses.append(cls_loss.item())
            epoch_sc_losses.append(sc_loss.item())
            epoch_acc_corr.append(acc_corr.item())
            epoch_acc_incorr.append(acc_incorr.item())
            epoch_cls_acc.append(cls_acc.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.iteration += 1

            # Apply the decayed learning rate for the next step.
            lr = self._scheduled_lr()
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            if is_main and bidx % self.config['log_intervals'] == 0:
                print_text = 'Epoch [{:d}/{:d}], Step [{:d}/{:d}/{:d}], Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]'.format(
                    self.epoch,
                    self.num_epochs, bidx,
                    len(self.train_loader),
                    self.iteration,
                    seg_loss.item(),
                    cls_loss.item(),
                    sc_loss.item(),
                    loss.item(),
                    np.mean(epoch_acc_corr),
                    np.mean(epoch_acc_incorr),
                    np.mean(epoch_cls_acc)
                )
                print(print_text)
                self.log_file.write(print_text + '\n')

                info = {
                    'lr': lr,
                    'loss': loss.item(),
                    'cls_loss': cls_loss.item(),
                    'sc_loss': sc_loss.item(),
                    'acc_corr': acc_corr.item(),
                    'acc_incorr': acc_incorr.item(),
                    'acc_cls': cls_acc.item(),
                }
                for k, v in info.items():
                    self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.iteration)

        mean_loss = self._mean_or_nan(epoch_losses)

        if is_main:
            print_text = 'Epoch [{:d}/{:d}], AVG Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]\n'.format(
                self.epoch,
                self.num_epochs,
                self._mean_or_nan(epoch_seg_losses),
                self._mean_or_nan(epoch_cls_losses),
                self._mean_or_nan(epoch_sc_losses),
                mean_loss,
                self._mean_or_nan(epoch_acc_corr),
                self._mean_or_nan(epoch_acc_incorr),
                self._mean_or_nan(epoch_cls_acc),
            )
            print(print_text)
            self.log_file.write(print_text + '\n')
            self.log_file.flush()

        return mean_loss

    def eval_seg(self, loader):
        """Evaluate segmentation quality on ``loader``.

        Samples whose file name contains 'night' are accumulated separately
        from all others ("day"). Only the first element of each batch is
        scored. Uses ``self.log_file`` and ``self.writer``, which are created
        on rank 0 only.

        Returns:
            float: mean precision over night samples; when there are none,
            mean precision over day samples (NaN if the loader is empty).
        """
        print('Start to do evaluation...')

        self.model.eval()
        self.seq_metric.clear()
        mean_iou_day = []
        mean_iou_night = []
        mean_prec_day = []
        mean_prec_night = []
        mean_cls_day = []
        mean_cls_night = []

        for _, pred in tqdm(enumerate(loader), total=len(loader)):
            self._move_batch_to_gpu(pred)

            if self.with_aug:
                with torch.no_grad():
                    first = pred['image'][0]
                    img = first[0] if isinstance(first, list) else first

                    img = torch.from_numpy(img).cuda().float().permute(2, 0, 1)
                    if self.img_transforms is not None:
                        img = self.img_transforms(img)[None]
                    else:
                        img = img[None]

                    encoder_out = self.feat_model.extract_local_global(data={'image': img})
                    pred['global_descriptors'] = [encoder_out['global_descriptors']]

                    if self.config['use_mid_feature']:
                        scores, descs = self.feat_model.sample(score_map=encoder_out['score_map'],
                                                               semi_descs=encoder_out['mid_features'],
                                                               kpts=pred['keypoints'][0],
                                                               norm_desc=self.config['norm_desc'])
                        pred['scores'] = scores[None]
                        pred['seg_descriptors'] = descs[None].permute(0, 2, 1)  # -> [B, N, D]
                    else:
                        pred['seg_descriptors'] = pred['descriptors']

            image_name = pred['file_name'][0]
            with torch.no_grad():
                out = self.model(pred)
            pred = {**pred, **out}

            # Index of the highest-scoring class for every point.
            pred_seg = torch.max(pred['prediction'], dim=-1)[1]
            pred_seg = pred_seg[0].cpu().numpy()
            gt_seg = pred['gt_seg'][0].cpu().numpy()

            iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=self.config['n_class'], ignored_ids=[0])
            prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0])

            if self.with_cls:
                pred_cls_dist = pred['classification']
                gt_cls_dist = pred['gt_cls_dist']
                cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist).item()
            else:
                cls_acc = 0.

            if image_name.find('night') >= 0:
                mean_iou_night.append(iou)
                mean_prec_night.append(prec)
                mean_cls_night.append(cls_acc)
            else:
                mean_iou_day.append(iou)
                mean_prec_day.append(prec)
                mean_cls_day.append(cls_acc)

        iou_day = self._mean_or_nan(mean_iou_day)
        iou_night = self._mean_or_nan(mean_iou_night)
        prec_day = self._mean_or_nan(mean_prec_day)
        prec_night = self._mean_or_nan(mean_prec_night)

        print_txt = 'Eval Epoch {:d}, iou day/night {:.3f}/{:.3f}, prec day/night {:.3f}/{:.3f}, cls day/night {:.3f}/{:.3f}'.format(
            self.epoch, iou_day, iou_night,
            prec_day, prec_night,
            self._mean_or_nan(mean_cls_day), self._mean_or_nan(mean_cls_night))
        self.log_file.write(print_txt + '\n')
        print(print_txt)

        info = {
            'mean_iou_day': iou_day,
            'mean_iou_night': iou_night,
            'mean_prec_day': prec_day,
            'mean_prec_night': prec_night,
        }
        for k, v in info.items():
            self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.epoch)

        # Without night samples the night precision is NaN, which would
        # disable model selection; fall back to the day precision.
        return prec_night if mean_prec_night else prec_day

    def train(self):
        """Train from ``self.epoch`` up to ``self.num_epochs``.

        On rank 0 a checkpoint is written after every epoch and copied to
        '<tag>.best.pth' whenever the selection value improves. The value is
        the negated evaluation precision on epochs where evaluation runs and
        the mean training loss otherwise (lower is better). Checkpoints
        resumed from runs that used a different convention for 'min_loss'
        may need that field reset.
        """
        is_main = self.config['local_rank'] == 0
        if is_main:
            print('Start to train the model from epoch: {:d}'.format(self.epoch))

        min_value = self.min_loss
        epoch = self.epoch

        try:
            while epoch < self.num_epochs:
                if self.config['with_dist']:
                    self.train_loader.sampler.set_epoch(epoch=epoch)

                self.epoch = epoch

                train_loss = self.process_epoch()

                # process_epoch currently always returns a number; the guard
                # is kept so a None result re-runs the same epoch.
                if train_loss is None:
                    continue

                if is_main:
                    eval_ratio = None
                    if self.do_eval and self.epoch % self.config['eval_n_epoch'] == 0:
                        eval_ratio = self.eval_seg(loader=self.eval_loader)

                    score = self._selection_value(train_loss, eval_ratio)
                    is_best = score < min_value
                    if is_best:
                        min_value = score
                        self.min_loss = min_value

                    checkpoint_path = os.path.join(self.save_dir,
                                                   '%s.%02d.pth' % (self.config['network'], self.epoch))

                    # Unwrap the parallel wrapper for multi-GPU training so
                    # the saved keys carry no 'module.' prefix.
                    model_to_save = self.model
                    if len(self.config['gpu']) > 1 and hasattr(self.model, 'module'):
                        model_to_save = self.model.module

                    checkpoint = {
                        'epoch': self.epoch,
                        'iteration': self.iteration,
                        'model': model_to_save.state_dict(),
                        # Plain float (not a NumPy scalar) and already
                        # including this epoch, so a resume starts from the
                        # correct best value.
                        'min_loss': float(min_value),
                    }
                    torch.save(checkpoint, checkpoint_path)

                    if is_best:
                        best_checkpoint_path = os.path.join(
                            self.save_dir,
                            '%s.best.pth' % (self.tag)
                        )
                        shutil.copy(checkpoint_path, best_checkpoint_path)

                # important!!!
                epoch += 1
        finally:
            # Release logging resources even when training is interrupted.
            if is_main:
                self.log_file.close()
                self.writer.close()