# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> trainer
@IDE     PyCharm
@Author  fx221@cam.ac.uk
@Date    29/01/2024 15:04
=================================================='''
import datetime
import os
import os.path as osp
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from tqdm import tqdm

from tools.common import save_args_yaml, merge_tags
from tools.metrics import compute_iou, compute_precision, SeqIOU, compute_corr_incorr, compute_seg_loss_weight
from tools.metrics import compute_cls_loss_ce, compute_cls_corr


class Trainer:
    def __init__(self, model, train_loader, feat_model=None, eval_loader=None, config=None, img_transforms=None):
        self.model = model
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.config = config
        self.with_aug = self.config['with_aug']
        self.with_cls = False  # self.config['with_cls']
        self.with_sc = False  # self.config['with_sc']
        self.img_transforms = img_transforms
        # frozen feature extractor used to recompute descriptors for augmented images
        self.feat_model = feat_model.cuda().eval() if feat_model is not None else None

        self.init_lr = self.config['lr']
        self.min_lr = self.config['min_lr']
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.AdamW(params=params, lr=self.init_lr)
        self.num_epochs = self.config['epochs']

        if config['resume_path'] is not None:
            log_dir = config['resume_path'].split('/')[-2]
            resume_log = torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')
            self.epoch = resume_log['epoch'] + 1
            if 'iteration' in resume_log.keys():
                self.iteration = resume_log['iteration']
            else:
                self.iteration = len(self.train_loader) * self.epoch
            self.min_loss = resume_log['min_loss']
        else:
            self.iteration = 0
            self.epoch = 0
            self.min_loss = 1e10

        # build a unique log-directory name from the timestamp and training tags
        now = datetime.datetime.now()
        all_tags = [now.strftime("%Y%m%d_%H%M%S")]
        dataset_name = merge_tags(self.config['dataset'], '')
        all_tags = all_tags + [self.config['network'], 'L' + str(self.config['layers']),
                               dataset_name,
                               str(self.config['feature']), 'B' + str(self.config['batch_size']),
                               'K' + str(self.config['max_keypoints']), 'od' + str(self.config['output_dim']),
                               'nc' + str(self.config['n_class'])]
        if self.config['use_mid_feature']:
            all_tags.append('md')
        # if self.with_cls:
        #     all_tags.append(self.config['cls_loss'])
        # if self.with_sc:
        #     all_tags.append(self.config['sc_loss'])
        if self.with_aug:
            all_tags.append('A')
        all_tags.append(self.config['cluster_method'])
        log_dir = merge_tags(tags=all_tags, connection='_')

        if config['local_rank'] == 0:
            self.save_dir = osp.join(self.config['save_path'], log_dir)
            os.makedirs(self.save_dir, exist_ok=True)
            print("save_dir: ", self.save_dir)
            self.log_file = open(osp.join(self.save_dir, "log.txt"), "a+")
            save_args_yaml(args=config, save_path=Path(self.save_dir, "args.yaml"))
            self.writer = SummaryWriter(self.save_dir)

        self.tag = log_dir
        self.do_eval = self.config['do_eval']
        if self.do_eval:
            self.eval_fun = None
            self.seq_metric = SeqIOU(n_class=self.config['n_class'], ignored_sids=[0])
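
    # Learning-rate schedule used in process_epoch(): the rate is decayed per iteration as
    #   lr(t) = clip(init_lr * decay_rate ** (t - decay_iter), min_lr, init_lr),
    # with 'decay_rate', 'decay_iter' and 'min_lr' taken from the config.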

    def preprocess_input(self, pred):
        # move batched tensors to GPU; file names and raw image/depth entries are kept as-is
        for k in pred.keys():
            if k.find('name') >= 0:
                continue
            if k != 'image' and k != 'depth':
                if isinstance(pred[k], torch.Tensor):
                    pred[k] = Variable(pred[k].float().cuda())
                else:
                    pred[k] = Variable(torch.stack(pred[k]).float().cuda())

        if self.with_aug:
            # with augmentation, recompute keypoint scores, local descriptors and
            # global descriptors from the augmented images with the frozen feature model
            new_scores = []
            new_descs = []
            global_descs = []
            with torch.no_grad():
                for i, im in enumerate(pred['image']):
                    img = torch.from_numpy(im[0]).cuda().float().permute(2, 0, 1)
                    if self.img_transforms is not None:
                        img = self.img_transforms(img)[None]
                    else:
                        img = img[None]
                    out = self.feat_model.extract_local_global(data={'image': img})
                    global_descs.append(out['global_descriptors'])
                    seg_scores, seg_descs = self.feat_model.sample(
                        score_map=out['score_map'],
                        semi_descs=out['mid_features'] if self.config['use_mid_feature'] else out['desc_map'],
                        kpts=pred['keypoints'][i],
                        norm_desc=self.config['norm_desc'])  # [D, N]
                    new_scores.append(seg_scores[None])
                    new_descs.append(seg_descs[None])

            pred['global_descriptors'] = global_descs
            pred['scores'] = torch.cat(new_scores, dim=0)
            pred['seg_descriptors'] = torch.cat(new_descs, dim=0).permute(0, 2, 1)  # -> [B, N, D]

    def process_epoch(self):
        self.model.train()

        epoch_cls_losses = []
        epoch_seg_losses = []
        epoch_losses = []
        epoch_acc_corr = []
        epoch_acc_incorr = []
        epoch_cls_acc = []
        epoch_sc_losses = []

        for bidx, pred in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
            self.preprocess_input(pred)
            if 0 <= self.config['its_per_epoch'] <= bidx:
                break

            data = self.model(pred)
            pred = {**pred, **data}

            # weighted segmentation loss; the background class (id 0) is down-weighted
            seg_loss = compute_seg_loss_weight(pred=pred['prediction'],
                                               target=pred['gt_seg'],
                                               background_id=0,
                                               weight_background=0.1)
            acc_corr, acc_incorr = compute_corr_incorr(pred=pred['prediction'],
                                                       target=pred['gt_seg'],
                                                       ignored_ids=[0])

            if self.with_cls:
                pred_cls_dist = pred['classification']
                gt_cls_dist = pred['gt_cls_dist']
                if len(pred_cls_dist.shape) > 2:
                    gt_cls_dist_full = gt_cls_dist.unsqueeze(-1).repeat(1, 1, pred_cls_dist.shape[-1])
                else:
                    gt_cls_dist_full = gt_cls_dist.unsqueeze(-1)
                cls_loss = compute_cls_loss_ce(pred=pred_cls_dist, target=gt_cls_dist_full)
                loss = seg_loss + cls_loss
                # gt_n_seg = pred['gt_n_seg']
                cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist)
            else:
                loss = seg_loss
                cls_loss = torch.zeros_like(seg_loss)
                cls_acc = torch.zeros_like(seg_loss)

            if self.with_sc:
                # spatial-consistency loss not implemented; with_sc is disabled in __init__
                pass
            else:
                sc_loss = torch.zeros_like(seg_loss)

            epoch_losses.append(loss.item())
            epoch_seg_losses.append(seg_loss.item())
            epoch_cls_losses.append(cls_loss.item())
            epoch_sc_losses.append(sc_loss.item())
            epoch_acc_corr.append(acc_corr.item())
            epoch_acc_incorr.append(acc_incorr.item())
            epoch_cls_acc.append(cls_acc.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.iteration += 1

            # exponential learning-rate decay, clipped to [min_lr, init_lr]
            lr = min(self.config['lr'] * self.config['decay_rate'] ** (self.iteration - self.config['decay_iter']),
                     self.config['lr'])
            if lr < self.min_lr:
                lr = self.min_lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            if self.config['local_rank'] == 0 and bidx % self.config['log_intervals'] == 0:
                print_text = 'Epoch [{:d}/{:d}], Step [{:d}/{:d}/{:d}], Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]'.format(
                    self.epoch,
                    self.num_epochs, bidx,
                    len(self.train_loader),
                    self.iteration,
                    seg_loss.item(),
                    cls_loss.item(),
                    sc_loss.item(),
                    loss.item(),
                    np.mean(epoch_acc_corr),
                    np.mean(epoch_acc_incorr),
                    np.mean(epoch_cls_acc)
                )
                print(print_text)
                self.log_file.write(print_text + '\n')

                info = {
                    'lr': lr,
                    'loss': loss.item(),
                    'cls_loss': cls_loss.item(),
                    'sc_loss': sc_loss.item(),
                    'acc_corr': acc_corr.item(),
                    'acc_incorr': acc_incorr.item(),
                    'acc_cls': cls_acc.item(),
                }
                for k, v in info.items():
                    self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.iteration)

        if self.config['local_rank'] == 0:
            print_text = 'Epoch [{:d}/{:d}], AVG Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]\n'.format(
                self.epoch,
                self.num_epochs,
                np.mean(epoch_seg_losses),
                np.mean(epoch_cls_losses),
                np.mean(epoch_sc_losses),
                np.mean(epoch_losses),
                np.mean(epoch_acc_corr),
                np.mean(epoch_acc_incorr),
                np.mean(epoch_cls_acc),
            )
            print(print_text)
            self.log_file.write(print_text + '\n')
            self.log_file.flush()

        return np.mean(epoch_losses)
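
    # Evaluation notes: eval_seg() computes per-image IoU and precision of the predicted
    # keypoint segmentation against 'gt_seg' (the background class id 0 is ignored) and,
    # depending on whether the file name contains 'night', accumulates separate day/night
    # averages that are printed, logged and written to TensorBoard.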

    def eval_seg(self, loader):
        print('Start to do evaluation...')
        self.model.eval()
        self.seq_metric.clear()

        mean_iou_day = []
        mean_iou_night = []
        mean_prec_day = []
        mean_prec_night = []
        mean_cls_day = []
        mean_cls_night = []

        for bid, pred in tqdm(enumerate(loader), total=len(loader)):
            # move batched tensors to GPU; file names and raw image/depth entries are kept as-is
            for k in pred.keys():
                if k.find('name') >= 0:
                    continue
                if k != 'image' and k != 'depth':
                    if isinstance(pred[k], torch.Tensor):
                        pred[k] = Variable(pred[k].float().cuda())
                    elif isinstance(pred[k], np.ndarray):
                        pred[k] = Variable(torch.from_numpy(pred[k]).float()[None].cuda())
                    else:
                        pred[k] = Variable(torch.stack(pred[k]).float().cuda())

            if self.with_aug:
                with torch.no_grad():
                    if isinstance(pred['image'][0], list):
                        img = pred['image'][0][0]
                    else:
                        img = pred['image'][0]
                    img = torch.from_numpy(img).cuda().float().permute(2, 0, 1)
                    if self.img_transforms is not None:
                        img = self.img_transforms(img)[None]
                    else:
                        img = img[None]
                    encoder_out = self.feat_model.extract_local_global(data={'image': img})
                    global_descriptors = [encoder_out['global_descriptors']]
                    pred['global_descriptors'] = global_descriptors

                    if self.config['use_mid_feature']:
                        scores, descs = self.feat_model.sample(score_map=encoder_out['score_map'],
                                                               semi_descs=encoder_out['mid_features'],
                                                               kpts=pred['keypoints'][0],
                                                               norm_desc=self.config['norm_desc'])
                        # print('eval: ', scores.shape, descs.shape)
                        pred['scores'] = scores[None]
                        pred['seg_descriptors'] = descs[None].permute(0, 2, 1)  # -> [B, N, D]
                    else:
                        pred['seg_descriptors'] = pred['descriptors']

            image_name = pred['file_name'][0]
            with torch.no_grad():
                out = self.model(pred)
            pred = {**pred, **out}

            pred_seg = torch.max(pred['prediction'], dim=-1)[1]  # prediction [B, N, C] -> class ids [B, N]
            pred_seg = pred_seg[0].cpu().numpy()
            gt_seg = pred['gt_seg'][0].cpu().numpy()

            iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=self.config['n_class'], ignored_ids=[0])
            prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0])

            if self.with_cls:
                pred_cls_dist = pred['classification']
                gt_cls_dist = pred['gt_cls_dist']
                cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist).item()
            else:
                cls_acc = 0.

            # split metrics into day/night according to the file name
            if image_name.find('night') >= 0:
                mean_iou_night.append(iou)
                mean_prec_night.append(prec)
                mean_cls_night.append(cls_acc)
            else:
                mean_iou_day.append(iou)
                mean_prec_day.append(prec)
                mean_cls_day.append(cls_acc)

        print_txt = 'Eval Epoch {:d}, iou day/night {:.3f}/{:.3f}, prec day/night {:.3f}/{:.3f}, cls day/night {:.3f}/{:.3f}'.format(
            self.epoch, np.mean(mean_iou_day), np.mean(mean_iou_night),
            np.mean(mean_prec_day), np.mean(mean_prec_night),
            np.mean(mean_cls_day), np.mean(mean_cls_night))
        self.log_file.write(print_txt + '\n')
        print(print_txt)

        info = {
            'mean_iou_day': np.mean(mean_iou_day),
            'mean_iou_night': np.mean(mean_iou_night),
            'mean_prec_day': np.mean(mean_prec_day),
            'mean_prec_night': np.mean(mean_prec_night),
        }
        for k, v in info.items():
            self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.epoch)

        return np.mean(mean_prec_night)
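
    # Training-loop notes: train() runs process_epoch() once per epoch; on rank 0 it
    # optionally evaluates every 'eval_n_epoch' epochs, saves a checkpoint
    # '<network>.<epoch>.pth' after every epoch, and, depending on the comparison of
    # hist_values[-1] against the tracked min_value, copies that checkpoint to
    # '<tag>.best.pth'.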

    def train(self):
        if self.config['local_rank'] == 0:
            print('Start to train the model from epoch: {:d}'.format(self.epoch))

        hist_values = []
        min_value = self.min_loss
        epoch = self.epoch
        while epoch < self.num_epochs:
            if self.config['with_dist']:
                self.train_loader.sampler.set_epoch(epoch=epoch)
            self.epoch = epoch
            train_loss = self.process_epoch()

            # skip book-keeping if the epoch returned no valid loss (INF/NAN)
            if train_loss is None:
                continue

            if self.config['local_rank'] == 0:
                if self.do_eval and self.epoch % self.config['eval_n_epoch'] == 0:  # and self.epoch >= 50:
                    eval_ratio = self.eval_seg(loader=self.eval_loader)
                    hist_values.append(eval_ratio)  # higher better
                else:
                    hist_values.append(-train_loss)  # lower better

                checkpoint_path = os.path.join(self.save_dir,
                                               '%s.%02d.pth' % (self.config['network'], self.epoch))
                checkpoint = {
                    'epoch': self.epoch,
                    'iteration': self.iteration,
                    'model': self.model.state_dict(),
                    'min_loss': min_value,
                }
                # for multi-gpu training
                if len(self.config['gpu']) > 1:
                    checkpoint['model'] = self.model.module.state_dict()
                torch.save(checkpoint, checkpoint_path)

                if hist_values[-1] < min_value:
                    min_value = hist_values[-1]
                    best_checkpoint_path = os.path.join(
                        self.save_dir,
                        '%s.best.pth' % (self.tag)
                    )
                    shutil.copy(checkpoint_path, best_checkpoint_path)

            # important!!!
            epoch += 1

        if self.config['local_rank'] == 0:
            self.log_file.close()
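

# A minimal usage sketch (hypothetical): the concrete segmentation model, data loaders,
# frozen feature extractor and `config` dict are constructed elsewhere in the project
# and are only assumed here.
#
#   trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model,
#                     eval_loader=eval_loader, config=config, img_transforms=img_transforms)
#   trainer.train()
#
# `config` is expected to provide at least the keys read above, e.g. 'lr', 'min_lr',
# 'epochs', 'batch_size', 'n_class', 'save_path', 'resume_path', 'local_rank', 'gpu',
# 'decay_rate', 'decay_iter', 'log_intervals', 'eval_n_epoch', 'do_eval', 'with_dist',
# 'with_aug', 'use_mid_feature', 'norm_desc', 'its_per_epoch' and 'cluster_method'.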