import copy
import datetime
import os
import random
import time

import numpy as np
import torch
from tqdm import tqdm

from openrec.losses import build_loss
from openrec.metrics import build_metric
from openrec.modeling import build_model
from openrec.optimizer import build_optimizer
from openrec.postprocess import build_post_process
from tools.data import build_dataloader
from tools.utils.ckpt import load_ckpt, save_ckpt
from tools.utils.logging import get_logger
from tools.utils.stats import TrainingStats
from tools.utils.utility import AverageMeter

__all__ = ['Trainer']
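
# Typical usage (a sketch; assumes a config wrapper exposing `.cfg`,
# `.print_cfg` and `.save`, as consumed below, and the path is illustrative):
#
#     cfg = Config('configs/rec/svtrv2/svtrv2_base.yml')
#     trainer = Trainer(cfg, mode='train_eval')
#     trainer.train()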


def get_parameter_number(model):
    """Return the total and trainable parameter counts of `model`."""
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters()
                        if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
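
# For example (illustrative):
#     get_parameter_number(torch.nn.Linear(10, 5))
#     -> {'Total': 55, 'Trainable': 55}   # 10 * 5 weights + 5 biases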


class Trainer(object):

    def __init__(self, cfg, mode='train'):
        self.cfg = cfg.cfg
        self.local_rank = (int(os.environ['LOCAL_RANK'])
                           if 'LOCAL_RANK' in os.environ else 0)
        self.set_device(self.cfg['Global']['device'])
        mode = mode.lower()
        assert mode in [
            'train_eval',
            'train',
            'eval',
            'test',
        ], 'mode should be one of train_eval, train, eval or test'
        if torch.cuda.device_count() > 1 and 'train' in mode:
            torch.distributed.init_process_group(backend='nccl')
            torch.cuda.set_device(self.device)
            self.cfg['Global']['distributed'] = True
        else:
            self.cfg['Global']['distributed'] = False
            self.local_rank = 0
        self.cfg['Global']['output_dir'] = self.cfg['Global'].get(
            'output_dir', 'output')
        os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)

        self.writer = None
        if self.local_rank == 0 and self.cfg['Global'][
                'use_tensorboard'] and 'train' in mode:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(self.cfg['Global']['output_dir'])

        self.logger = get_logger(
            'openrec',
            os.path.join(self.cfg['Global']['output_dir'], 'train.log')
            if 'train' in mode else None,
        )
        cfg.print_cfg(self.logger.info)

        if self.cfg['Global']['device'] == 'gpu' and self.device.type == 'cpu':
            self.logger.info('cuda is not available, auto switch to cpu')
        self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0)
        self.all_ema = self.cfg['Global'].get('all_ema', True)
        self.use_ema = self.cfg['Global'].get('use_ema', True)
        self.set_random_seed(self.cfg['Global'].get('seed', 48))

        # build data loader
        self.train_dataloader = None
        if 'train' in mode:
            cfg.save(
                os.path.join(self.cfg['Global']['output_dir'], 'config.yml'),
                self.cfg)
            self.train_dataloader = build_dataloader(self.cfg, 'Train',
                                                     self.logger)
            self.logger.info(
                f'train dataloader has {len(self.train_dataloader)} iters')
        self.valid_dataloader = None
        if 'eval' in mode and self.cfg['Eval']:
            self.valid_dataloader = build_dataloader(self.cfg, 'Eval',
                                                     self.logger)
            self.logger.info(
                f'valid dataloader has {len(self.valid_dataloader)} iters')

        # build post process
        self.post_process_class = build_post_process(self.cfg['PostProcess'],
                                                     self.cfg['Global'])

        # build model
        # for rec algorithm
        char_num = self.post_process_class.get_character_num()
        self.cfg['Architecture']['Decoder']['out_channels'] = char_num
        self.model = build_model(self.cfg['Architecture'])
        self.logger.info(get_parameter_number(model=self.model))
        self.model = self.model.to(self.device)

        if self.local_rank == 0:
            ema_model = build_model(self.cfg['Architecture'])
            self.ema_model = ema_model.to(self.device)
            self.ema_model.eval()

        use_sync_bn = self.cfg['Global'].get('use_sync_bn', False)
        if use_sync_bn:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)
            self.logger.info('convert_sync_batchnorm')
        # build loss
        self.loss_class = build_loss(self.cfg['Loss'])

        self.optimizer, self.lr_scheduler = None, None
        if self.train_dataloader is not None:
            # build optim
            self.optimizer, self.lr_scheduler = build_optimizer(
                self.cfg['Optimizer'],
                self.cfg['LRScheduler'],
                epochs=self.cfg['Global']['epoch_num'],
                step_each_epoch=len(self.train_dataloader),
                model=self.model,
            )

        self.eval_class = build_metric(self.cfg['Metric'])

        self.status = load_ckpt(self.model, self.cfg, self.optimizer,
                                self.lr_scheduler)

        if self.cfg['Global']['distributed']:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, [self.local_rank], find_unused_parameters=False)

        # amp
        self.scaler = (torch.cuda.amp.GradScaler() if self.cfg['Global'].get(
            'use_amp', False) else None)
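        # GradScaler guards float16 training against gradient underflow: the
        # loss is scaled up before backward() and gradients are unscaled again
        # before the optimizer step (see the AMP branch in train() below).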
        self.logger.info(
            f'run with torch {torch.__version__} and device {self.device}')

    def load_params(self, params):
        self.model.load_state_dict(params)

    def set_random_seed(self, seed):
        torch.manual_seed(seed)  # seed the CPU RNG
        if self.device.type == 'cuda':
            torch.backends.cudnn.benchmark = True
            torch.cuda.manual_seed(seed)  # seed the current GPU
            torch.cuda.manual_seed_all(seed)  # seed all GPUs
        random.seed(seed)
        np.random.seed(seed)
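        # Note: cudnn.benchmark=True favors throughput over exact
        # reproducibility; for stricter determinism one would instead set
        # torch.backends.cudnn.deterministic = True and benchmark = False.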

    def set_device(self, device):
        if device == 'gpu' and torch.cuda.is_available():
            device = torch.device(f'cuda:{self.local_rank}')
        else:
            device = torch.device('cpu')
        self.device = device

    def train(self):
        cal_metric_during_train = self.cfg['Global'].get(
            'cal_metric_during_train', False)
        log_smooth_window = self.cfg['Global']['log_smooth_window']
        epoch_num = self.cfg['Global']['epoch_num']
        print_batch_step = self.cfg['Global']['print_batch_step']
        eval_epoch_step = self.cfg['Global'].get('eval_epoch_step', 1)

        start_eval_epoch = 0
        if self.valid_dataloader is not None:
            if isinstance(eval_epoch_step, list) and len(eval_epoch_step) >= 2:
                start_eval_epoch = eval_epoch_step[0]
                eval_epoch_step = eval_epoch_step[1]
                if len(self.valid_dataloader) == 0:
                    start_eval_epoch = float('inf')
                    self.logger.info(
                        'No images in eval dataset, evaluation during training will be disabled'
                    )
                self.logger.info(
                    f'During the training process, after the {start_eval_epoch}th epoch, '
                    f'an evaluation is run every {eval_epoch_step} epoch(s)')
        else:
            start_eval_epoch = float('inf')

        eval_batch_step = self.cfg['Global']['eval_batch_step']
        global_step = self.status.get('global_step', 0)
        start_eval_step = 0
        if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
            start_eval_step = eval_batch_step[0]
            eval_batch_step = eval_batch_step[1]
            if self.valid_dataloader is None or len(
                    self.valid_dataloader) == 0:
                self.logger.info(
                    'No images in eval dataset, evaluation during training '
                    'will be disabled')
                start_eval_step = float('inf')
            self.logger.info(
                'During the training process, after the {}th iteration, '
                'an evaluation is run every {} iterations'.format(
                    start_eval_step, eval_batch_step))
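        # The two schedules above are driven by Global config entries such as
        # (values illustrative, not defaults):
        #
        #     Global:
        #       eval_batch_step: [2000, 500]  # start at iter 2000, then every 500 iters
        #       eval_epoch_step: [1, 1]       # start at epoch 1, then every epoch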

        start_epoch = self.status.get('epoch', 1)
        best_metric = self.status.get('metrics', {})
        if self.eval_class.main_indicator not in best_metric:
            best_metric[self.eval_class.main_indicator] = 0
        # deepcopy so that resetting the EMA best score below cannot clobber
        # best_metric, which may alias the same dict in the checkpoint status
        ema_best_metric = copy.deepcopy(self.status.get('metrics', {}))
        ema_best_metric[self.eval_class.main_indicator] = 0

        train_stats = TrainingStats(log_smooth_window, ['lr'])
        self.model.train()

        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        best_iter = 0
        ema_step = 1
        ema_eval_iter = 0
        loss_avg = 0.
        reader_start = time.time()
        eta_meter = AverageMeter()

        for epoch in range(start_epoch, epoch_num + 1):
            if self.train_dataloader.dataset.need_reset:
                self.train_dataloader = build_dataloader(
                    self.cfg,
                    'Train',
                    self.logger,
                    epoch=epoch % 20 if epoch % 20 != 0 else 20,
                )
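                # `epoch % 20 if epoch % 20 != 0 else 20` maps the epoch onto
                # a 1..20 cycle; presumably the dataset uses this index when
                # it regenerates/reshuffles itself after signalling need_reset.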

            for idx, batch in enumerate(self.train_dataloader):
                batch = [t.to(self.device) for t in batch]
                self.optimizer.zero_grad()
                train_reader_cost += time.time() - reader_start
                # use amp
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.model(batch[0], data=batch[1:])
                        loss = self.loss_class(preds, batch)
                    self.scaler.scale(loss['loss']).backward()
                    if self.grad_clip_val > 0:
                        # unscale first so clipping sees the true gradients
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    preds = self.model(batch[0], data=batch[1:])
                    loss = self.loss_class(preds, batch)
                    avg_loss = loss['loss']
                    avg_loss.backward()
                    if self.grad_clip_val > 0:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.optimizer.step()

                if cal_metric_during_train:  # only rec and cls need
                    post_result = self.post_process_class(preds,
                                                          batch,
                                                          training=True)
                    self.eval_class(post_result, batch, training=True)
                    metric = self.eval_class.get_metric()
                    train_stats.update(metric)

                train_batch_time = time.time() - reader_start
                train_batch_cost += train_batch_time
                eta_meter.update(train_batch_time)
                global_step += 1
                total_samples += len(batch[0])
                self.lr_scheduler.step()

                # EMA: in the last ~10% of epochs, fold snapshots of the
                # weights into ema_model (rank 0 only)
                if self.local_rank == 0 and self.use_ema and epoch > (
                        epoch_num - epoch_num // 10):
                    with torch.no_grad():
                        loss_curr = loss['loss'].detach().cpu().numpy().mean()
                        loss_avg = ((loss_avg * (ema_step - 1)) +
                                    loss_curr) / ema_step
                        if ema_step == 1:
                            ema_state_dict = copy.deepcopy(
                                self.model.module.state_dict()
                                if self.cfg['Global']['distributed'] else
                                self.model.state_dict())
                            self.ema_model.load_state_dict(ema_state_dict)
                        elif loss_curr <= loss_avg or self.all_ema:
                            current_weight = copy.deepcopy(
                                self.model.module.state_dict()
                                if self.cfg['Global']['distributed'] else
                                self.model.state_dict())
                            k1 = 1 / (ema_step + 1)
                            k2 = 1 - k1
                            for k, v in ema_state_dict.items():
                                v = v * k2 + current_weight[k] * k1
                                ema_state_dict[k] = v
                            self.ema_model.load_state_dict(ema_state_dict)
                        ema_step += 1
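                        # Note on the update above: with k1 = 1/(n+1) and
                        # k2 = n/(n+1), `ema <- k2 * ema + k1 * w` keeps the
                        # EMA weights close to a plain (equal-weight) average
                        # of the folded snapshots, rather than an exponential
                        # moving average with a fixed decay such as 0.999.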

                    if global_step > start_eval_step and (
                            global_step -
                            start_eval_step) % eval_batch_step == 0:
                        ema_cur_metric = self.eval_ema()
                        ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
                        self.logger.info(ema_cur_metric_str)
                        state = {
                            'epoch': epoch,
                            'global_step': global_step,
                            'state_dict': self.ema_model.state_dict(),
                            'optimizer': None,
                            'scheduler': None,
                            'config': self.cfg,
                            'metrics': ema_cur_metric,
                        }
                        save_path = os.path.join(
                            self.cfg['Global']['output_dir'],
                            'ema_' + str(ema_eval_iter) + '.pth')
                        torch.save(state, save_path)
                        self.logger.info(f'save ema ckpt to {save_path}')
                        ema_eval_iter += 1
                        if (ema_cur_metric[self.eval_class.main_indicator] >=
                                ema_best_metric[
                                    self.eval_class.main_indicator]):
                            ema_best_metric.update(ema_cur_metric)
                            ema_best_metric['best_epoch'] = epoch
                            best_ema_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
                            self.logger.info(best_ema_str)

                # logger
                stats = {
                    k: float(v) if v.ndim == 0 else
                    v.detach().cpu().numpy().mean()
                    for k, v in loss.items()
                }
                stats['lr'] = self.lr_scheduler.get_last_lr()[0]
                train_stats.update(stats)

                if self.writer is not None:
                    for k, v in train_stats.get().items():
                        self.writer.add_scalar(f'TRAIN/{k}', v, global_step)

                if self.local_rank == 0 and (
                        (global_step > 0 and global_step % print_batch_step == 0)
                        or (idx >= len(self.train_dataloader) - 1)):
                    logs = train_stats.log()
                    # remaining iterations in this and all later epochs,
                    # times the smoothed per-iteration time
                    eta_sec = (
                        (epoch_num + 1 - epoch) * len(self.train_dataloader) -
                        idx - 1) * eta_meter.avg
                    eta_sec_format = str(
                        datetime.timedelta(seconds=int(eta_sec)))
                    strs = (
                        f'epoch: [{epoch}/{epoch_num}], global_step: {global_step}, {logs}, '
                        f'avg_reader_cost: {train_reader_cost / print_batch_step:.5f} s, '
                        f'avg_batch_cost: {train_batch_cost / print_batch_step:.5f} s, '
                        f'avg_samples: {total_samples / print_batch_step}, '
                        f'ips: {total_samples / train_batch_cost:.5f} samples/s, '
                        f'eta: {eta_sec_format}')
                    self.logger.info(strs)
                    total_samples = 0
                    train_reader_cost = 0.0
                    train_batch_cost = 0.0
                reader_start = time.time()

                # eval
                if (global_step > start_eval_step and
                    (global_step - start_eval_step) % eval_batch_step
                        == 0) and self.local_rank == 0:
                    cur_metric = self.eval()
                    cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
                    self.logger.info(cur_metric_str)

                    # logger metric
                    if self.writer is not None:
                        for k, v in cur_metric.items():
                            if isinstance(v, (float, int)):
                                self.writer.add_scalar(f'EVAL/{k}',
                                                       cur_metric[k],
                                                       global_step)
                    if (cur_metric[self.eval_class.main_indicator] >=
                            best_metric[self.eval_class.main_indicator]):
                        best_metric.update(cur_metric)
                        best_metric['best_epoch'] = epoch
                        if self.writer is not None:
                            self.writer.add_scalar(
                                f'EVAL/best_{self.eval_class.main_indicator}',
                                best_metric[self.eval_class.main_indicator],
                                global_step,
                            )
                        if epoch > (epoch_num - epoch_num // 10 - 2):
                            save_ckpt(self.model,
                                      self.cfg,
                                      self.optimizer,
                                      self.lr_scheduler,
                                      epoch,
                                      global_step,
                                      best_metric,
                                      is_best=True,
                                      prefix='best_' + str(best_iter))
                            best_iter += 1
                        save_ckpt(self.model,
                                  self.cfg,
                                  self.optimizer,
                                  self.lr_scheduler,
                                  epoch,
                                  global_step,
                                  best_metric,
                                  is_best=True,
                                  prefix=None)
                    best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
                    self.logger.info(best_str)

            if self.local_rank == 0 and epoch > start_eval_epoch and (
                    epoch - start_eval_epoch) % eval_epoch_step == 0:
                cur_metric = self.eval()
                cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
                self.logger.info(cur_metric_str)

                # logger metric
                if self.writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            self.writer.add_scalar(f'EVAL/{k}', cur_metric[k],
                                                   global_step)
                if (cur_metric[self.eval_class.main_indicator] >=
                        best_metric[self.eval_class.main_indicator]):
                    best_metric.update(cur_metric)
                    best_metric['best_epoch'] = epoch
                    if self.writer is not None:
                        self.writer.add_scalar(
                            f'EVAL/best_{self.eval_class.main_indicator}',
                            best_metric[self.eval_class.main_indicator],
                            global_step,
                        )
                    if epoch > (epoch_num - epoch_num // 10 - 2):
                        save_ckpt(self.model,
                                  self.cfg,
                                  self.optimizer,
                                  self.lr_scheduler,
                                  epoch,
                                  global_step,
                                  best_metric,
                                  is_best=True,
                                  prefix='best_' + str(best_iter))
                        best_iter += 1
                    save_ckpt(self.model,
                              self.cfg,
                              self.optimizer,
                              self.lr_scheduler,
                              epoch,
                              global_step,
                              best_metric,
                              is_best=True,
                              prefix=None)
                best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
                self.logger.info(best_str)

            if self.local_rank == 0:
                save_ckpt(self.model,
                          self.cfg,
                          self.optimizer,
                          self.lr_scheduler,
                          epoch,
                          global_step,
                          best_metric,
                          is_best=False,
                          prefix=None)
                if epoch > (epoch_num - epoch_num // 10 - 2):
                    save_ckpt(self.model,
                              self.cfg,
                              self.optimizer,
                              self.lr_scheduler,
                              epoch,
                              global_step,
                              best_metric,
                              is_best=False,
                              prefix='epoch_' + str(epoch))
                if self.use_ema and epoch > (epoch_num - epoch_num // 10):
                    ema_cur_metric = self.eval_ema()
                    ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
                    self.logger.info(ema_cur_metric_str)
                    state = {
                        'epoch': epoch,
                        'global_step': global_step,
                        'state_dict': self.ema_model.state_dict(),
                        'optimizer': None,
                        'scheduler': None,
                        'config': self.cfg,
                        'metrics': ema_cur_metric,
                    }
                    save_path = os.path.join(
                        self.cfg['Global']['output_dir'],
                        'ema_' + str(ema_eval_iter) + '.pth')
                    torch.save(state, save_path)
                    self.logger.info(f'save ema ckpt to {save_path}')
                    ema_eval_iter += 1
                    if (ema_cur_metric[self.eval_class.main_indicator] >=
                            ema_best_metric[self.eval_class.main_indicator]):
                        ema_best_metric.update(ema_cur_metric)
                        ema_best_metric['best_epoch'] = epoch
                        best_ema_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
                        self.logger.info(best_ema_str)

        best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
        self.logger.info(best_str)
        if self.writer is not None:
            self.writer.close()
        if self.cfg['Global']['distributed']:
            torch.distributed.destroy_process_group()

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            total_frame = 0.0
            total_time = 0.0
            pbar = tqdm(
                total=len(self.valid_dataloader),
                desc='eval model:',
                position=0,
                leave=True,
            )
            sum_images = 0
            for idx, batch in enumerate(self.valid_dataloader):
                batch = [t.to(self.device) for t in batch]
                start = time.time()
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.model(batch[0], data=batch[1:])
                else:
                    preds = self.model(batch[0], data=batch[1:])
                total_time += time.time() - start
                # obtain usable results from the post-processing method and
                # evaluate this batch
                post_result = self.post_process_class(preds, batch)
                self.eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(batch[0])
                sum_images += 1
            # get the final metric, e.g. acc or hmean
            metric = self.eval_class.get_metric()
            pbar.close()
        self.model.train()
        metric['fps'] = total_frame / total_time
        return metric

    def eval_ema(self):
        # ema_model is kept in eval mode permanently, so no mode switching
        # is needed here
        with torch.no_grad():
            total_frame = 0.0
            total_time = 0.0
            pbar = tqdm(
                total=len(self.valid_dataloader),
                desc='eval ema_model:',
                position=0,
                leave=True,
            )
            sum_images = 0
            for idx, batch in enumerate(self.valid_dataloader):
                batch = [t.to(self.device) for t in batch]
                start = time.time()
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.ema_model(batch[0], data=batch[1:])
                else:
                    preds = self.ema_model(batch[0], data=batch[1:])
                total_time += time.time() - start
                # obtain usable results from the post-processing method and
                # evaluate this batch
                post_result = self.post_process_class(preds, batch)
                self.eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(batch[0])
                sum_images += 1
            # get the final metric, e.g. acc or hmean
            metric = self.eval_class.get_metric()
            pbar.close()
        metric['fps'] = total_frame / total_time
        return metric

    def test_dataloader(self):
        starttime = time.time()
        count = 0
        try:
            for data in self.train_dataloader:
                count += 1
                batch_time = time.time() - starttime
                starttime = time.time()
                self.logger.info(
                    f'reader: {count}, {data[0].shape}, {batch_time}')
        except Exception:
            import traceback
            self.logger.info(traceback.format_exc())
        self.logger.info(f'finish reader: {count}, Success!')