import copy
import datetime
import os
import random
import time
import numpy as np
import torch
from tqdm import tqdm
from openrec.losses import build_loss
from openrec.metrics import build_metric
from openrec.modeling import build_model
from openrec.optimizer import build_optimizer
from openrec.postprocess import build_post_process
from tools.data import build_dataloader
from tools.utils.ckpt import load_ckpt, save_ckpt
from tools.utils.logging import get_logger
from tools.utils.stats import TrainingStats
from tools.utils.utility import AverageMeter
__all__ = ['Trainer']
def get_parameter_number(model):
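"""Return the total and trainable parameter counts of a model."""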
total_num = sum(p.numel() for p in model.parameters())
trainable_num = sum(p.numel() for p in model.parameters()
if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
class Trainer(object):
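"""Builds and runs training, evaluation, and testing for openrec models."""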
def __init__(self, cfg, mode='train'):
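"""Set up the device, dataloaders, model, loss, optimizer, and metrics from cfg."""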
self.cfg = cfg.cfg
self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
self.set_device(self.cfg['Global']['device'])
mode = mode.lower()
assert mode in [
'train_eval',
'train',
'eval',
'test',
], "mode should be one of 'train_eval', 'train', 'eval', or 'test'"
if torch.cuda.device_count() > 1 and 'train' in mode:
torch.distributed.init_process_group(backend='nccl')
torch.cuda.set_device(self.device)
self.cfg['Global']['distributed'] = True
else:
self.cfg['Global']['distributed'] = False
self.local_rank = 0
self.cfg['Global']['output_dir'] = self.cfg['Global'].get(
'output_dir', 'output')
os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)
self.writer = None
if (self.local_rank == 0
and self.cfg['Global'].get('use_tensorboard', False)
and 'train' in mode):
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(self.cfg['Global']['output_dir'])
self.logger = get_logger(
'openrec',
os.path.join(self.cfg['Global']['output_dir'], 'train.log')
if 'train' in mode else None,
)
cfg.print_cfg(self.logger.info)
if self.cfg['Global']['device'] == 'gpu' and self.device.type == 'cpu':
self.logger.info('CUDA is not available, automatically switching to CPU')
self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0)
self.all_ema = self.cfg['Global'].get('all_ema', True)
self.use_ema = self.cfg['Global'].get('use_ema', True)
self.set_random_seed(self.cfg['Global'].get('seed', 48))
# build data loader
self.train_dataloader = None
if 'train' in mode:
cfg.save(
os.path.join(self.cfg['Global']['output_dir'], 'config.yml'),
self.cfg)
self.train_dataloader = build_dataloader(self.cfg, 'Train',
self.logger)
self.logger.info(
f'train dataloader has {len(self.train_dataloader)} iters')
self.valid_dataloader = None
if 'eval' in mode and self.cfg['Eval']:
self.valid_dataloader = build_dataloader(self.cfg, 'Eval',
self.logger)
self.logger.info(
f'valid dataloader has {len(self.valid_dataloader)} iters')
# build post process
self.post_process_class = build_post_process(self.cfg['PostProcess'],
self.cfg['Global'])
# build model
# for rec algorithm
char_num = self.post_process_class.get_character_num()
self.cfg['Architecture']['Decoder']['out_channels'] = char_num
self.model = build_model(self.cfg['Architecture'])
self.logger.info(get_parameter_number(model=self.model))
self.model = self.model.to(self.device)
if self.local_rank == 0:
ema_model = build_model(self.cfg['Architecture'])
self.ema_model = ema_model.to(self.device)
self.ema_model.eval()
use_sync_bn = self.cfg['Global'].get('use_sync_bn', False)
if use_sync_bn:
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
self.model)
self.logger.info('convert_sync_batchnorm')
# build loss
self.loss_class = build_loss(self.cfg['Loss'])
self.optimizer, self.lr_scheduler = None, None
if self.train_dataloader is not None:
# build optim
self.optimizer, self.lr_scheduler = build_optimizer(
self.cfg['Optimizer'],
self.cfg['LRScheduler'],
epochs=self.cfg['Global']['epoch_num'],
step_each_epoch=len(self.train_dataloader),
model=self.model,
)
self.eval_class = build_metric(self.cfg['Metric'])
self.status = load_ckpt(self.model, self.cfg, self.optimizer,
self.lr_scheduler)
if self.cfg['Global']['distributed']:
self.model = torch.nn.parallel.DistributedDataParallel(
self.model, [self.local_rank], find_unused_parameters=False)
# amp
self.scaler = (torch.cuda.amp.GradScaler() if self.cfg['Global'].get(
'use_amp', False) else None)
self.logger.info(
f'run with torch {torch.__version__} and device {self.device}')
def load_params(self, params):
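"""Load a state dict into the underlying model."""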
self.model.load_state_dict(params)
def set_random_seed(self, seed):
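"""Seed the Python, NumPy, and PyTorch RNGs for reproducibility."""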
torch.manual_seed(seed)  # seed the CPU RNG
if self.device.type == 'cuda':
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed(seed)  # seed the current GPU
torch.cuda.manual_seed_all(seed)  # seed all GPUs
random.seed(seed)
np.random.seed(seed)
def set_device(self, device):
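"""Use cuda:<local_rank> when a GPU is requested and available, else CPU."""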
if device == 'gpu' and torch.cuda.is_available():
device = torch.device(f'cuda:{self.local_rank}')
else:
device = torch.device('cpu')
self.device = device
def train(self):
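"""Run the training loop with periodic evaluation, EMA updates, and checkpointing."""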
cal_metric_during_train = self.cfg['Global'].get(
'cal_metric_during_train', False)
log_smooth_window = self.cfg['Global']['log_smooth_window']
epoch_num = self.cfg['Global']['epoch_num']
print_batch_step = self.cfg['Global']['print_batch_step']
eval_epoch_step = self.cfg['Global'].get('eval_epoch_step', 1)
start_eval_epoch = 0
if self.valid_dataloader is not None:
if isinstance(eval_epoch_step, list) and len(eval_epoch_step) >= 2:
start_eval_epoch = eval_epoch_step[0]
eval_epoch_step = eval_epoch_step[1]
if len(self.valid_dataloader) == 0:
start_eval_epoch = 1e111  # effectively disables evaluation
self.logger.info(
'No images in the eval dataset, evaluation during training will be disabled'
)
self.logger.info(
f'During the training process, after the {start_eval_epoch}th epoch, '
f'an evaluation is run every {eval_epoch_step} epoch(s)')
else:
start_eval_epoch = 1e111  # no eval dataloader, never evaluate by epoch
eval_batch_step = self.cfg['Global']['eval_batch_step']
global_step = self.status.get('global_step', 0)
start_eval_step = 0
if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
if self.valid_dataloader is None or len(self.valid_dataloader) == 0:
self.logger.info(
'No images in the eval dataset, evaluation during training '
'will be disabled')
start_eval_step = 1e111  # effectively disables evaluation
self.logger.info(
'During the training process, after the {}th iteration, '
'an evaluation is run every {} iterations'.format(
start_eval_step, eval_batch_step))
start_epoch = self.status.get('epoch', 1)
best_metric = self.status.get('metrics', {})
if self.eval_class.main_indicator not in best_metric:
best_metric[self.eval_class.main_indicator] = 0
# Use an independent copy so EMA bookkeeping cannot mutate best_metric.
ema_best_metric = copy.deepcopy(self.status.get('metrics', {}))
ema_best_metric[self.eval_class.main_indicator] = 0
train_stats = TrainingStats(log_smooth_window, ['lr'])
self.model.train()
total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
best_iter = 0
ema_step = 1
ema_eval_iter = 0
loss_avg = 0.
reader_start = time.time()
eta_meter = AverageMeter()
for epoch in range(start_epoch, epoch_num + 1):
if self.train_dataloader.dataset.need_reset:
self.train_dataloader = build_dataloader(
self.cfg,
'Train',
self.logger,
epoch=epoch % 20 if epoch % 20 != 0 else 20,
)
for idx, batch in enumerate(self.train_dataloader):
batch = [t.to(self.device) for t in batch]
self.optimizer.zero_grad()
train_reader_cost += time.time() - reader_start
# use amp
if self.scaler:
with torch.cuda.amp.autocast():
preds = self.model(batch[0], data=batch[1:])
loss = self.loss_class(preds, batch)
self.scaler.scale(loss['loss']).backward()
if self.grad_clip_val > 0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.grad_clip_val)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
preds = self.model(batch[0], data=batch[1:])
loss = self.loss_class(preds, batch)
avg_loss = loss['loss']
avg_loss.backward()
if self.grad_clip_val > 0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.grad_clip_val)
self.optimizer.step()
if cal_metric_during_train:  # only rec and cls tasks need this
post_result = self.post_process_class(preds,
batch,
training=True)
self.eval_class(post_result, batch, training=True)
metric = self.eval_class.get_metric()
train_stats.update(metric)
train_batch_time = time.time() - reader_start
train_batch_cost += train_batch_time
eta_meter.update(train_batch_time)
global_step += 1
total_samples += len(batch[0])
self.lr_scheduler.step()
# Maintain an EMA copy of the weights on rank 0 over the last 10% of epochs.
if self.local_rank == 0 and self.use_ema and epoch > (
epoch_num - epoch_num // 10):
with torch.no_grad():
loss_curr = loss['loss'].detach().cpu().numpy().mean()
# Running mean of the loss over the EMA window.
loss_avg = (loss_avg * (ema_step - 1) + loss_curr) / ema_step
if ema_step == 1:
# Initialize the EMA weights from the current model.
ema_state_dict = copy.deepcopy(
self.model.module.state_dict()
if self.cfg['Global']['distributed'] else self.model.state_dict())
self.ema_model.load_state_dict(ema_state_dict)
elif loss_curr <= loss_avg or self.all_ema:
# Fold the current weights into the EMA as a cumulative
# moving average: ema = ema * n/(n+1) + w * 1/(n+1).
current_weight = copy.deepcopy(
self.model.module.state_dict()
if self.cfg['Global']['distributed'] else self.model.state_dict())
k1 = 1 / (ema_step + 1)
k2 = 1 - k1
for k, v in ema_state_dict.items():
ema_state_dict[k] = v * k2 + current_weight[k] * k1
self.ema_model.load_state_dict(ema_state_dict)
ema_step += 1
if global_step > start_eval_step and (
global_step -
start_eval_step) % eval_batch_step == 0:
ema_cur_metric = self.eval_ema()
ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
self.logger.info(ema_cur_metric_str)
state = {
'epoch': epoch,
'global_step': global_step,
'state_dict': self.ema_model.state_dict(),
'optimizer': None,
'scheduler': None,
'config': self.cfg,
'metrics': ema_cur_metric,
}
save_path = os.path.join(
self.cfg['Global']['output_dir'],
'ema_' + str(ema_eval_iter) + '.pth')
torch.save(state, save_path)
self.logger.info(f'save ema ckpt to {save_path}')
ema_eval_iter += 1
if (ema_cur_metric[self.eval_class.main_indicator] >=
ema_best_metric[self.eval_class.main_indicator]):
ema_best_metric.update(ema_cur_metric)
ema_best_metric['best_epoch'] = epoch
best_ema_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
self.logger.info(best_ema_str)
# logger
# Scalar losses become floats, tensor losses are reduced to their mean.
stats = {
k: float(v) if v.ndim == 0 else v.detach().cpu().numpy().mean()
for k, v in loss.items()
}
stats['lr'] = self.lr_scheduler.get_last_lr()[0]
train_stats.update(stats)
if self.writer is not None:
for k, v in train_stats.get().items():
self.writer.add_scalar(f'TRAIN/{k}', v, global_step)
if self.local_rank == 0 and (
(global_step > 0 and global_step % print_batch_step == 0)
or (idx >= len(self.train_dataloader) - 1)):
logs = train_stats.log()
eta_sec = (
(epoch_num + 1 - epoch) * len(self.train_dataloader) -
idx - 1) * eta_meter.avg
eta_sec_format = str(
datetime.timedelta(seconds=int(eta_sec)))
strs = (
f'epoch: [{epoch}/{epoch_num}], global_step: {global_step}, {logs}, '
f'avg_reader_cost: {train_reader_cost / print_batch_step:.5f} s, '
f'avg_batch_cost: {train_batch_cost / print_batch_step:.5f} s, '
f'avg_samples: {total_samples / print_batch_step}, '
f'ips: {total_samples / train_batch_cost:.5f} samples/s, '
f'eta: {eta_sec_format}')
self.logger.info(strs)
total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
reader_start = time.time()
# eval
if (global_step > start_eval_step and
(global_step - start_eval_step) % eval_batch_step
== 0) and self.local_rank == 0:
cur_metric = self.eval()
cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
self.logger.info(cur_metric_str)
# logger metric
if self.writer is not None:
for k, v in cur_metric.items():
if isinstance(v, (float, int)):
self.writer.add_scalar(f'EVAL/{k}',
cur_metric[k],
global_step)
if (cur_metric[self.eval_class.main_indicator] >=
best_metric[self.eval_class.main_indicator]):
best_metric.update(cur_metric)
best_metric['best_epoch'] = epoch
if self.writer is not None:
self.writer.add_scalar(
f'EVAL/best_{self.eval_class.main_indicator}',
best_metric[self.eval_class.main_indicator],
global_step,
)
if epoch > (epoch_num - epoch_num // 10 - 2):
save_ckpt(self.model,
self.cfg,
self.optimizer,
self.lr_scheduler,
epoch,
global_step,
best_metric,
is_best=True,
prefix='best_' + str(best_iter))
best_iter += 1
save_ckpt(self.model,
self.cfg,
self.optimizer,
self.lr_scheduler,
epoch,
global_step,
best_metric,
is_best=True,
prefix=None)
best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
self.logger.info(best_str)
if self.local_rank == 0 and epoch > start_eval_epoch and (
epoch - start_eval_epoch) % eval_epoch_step == 0:
cur_metric = self.eval()
cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
self.logger.info(cur_metric_str)
# logger metric
if self.writer is not None:
for k, v in cur_metric.items():
if isinstance(v, (float, int)):
self.writer.add_scalar(f'EVAL/{k}', cur_metric[k],
global_step)
if (cur_metric[self.eval_class.main_indicator] >=
best_metric[self.eval_class.main_indicator]):
best_metric.update(cur_metric)
best_metric['best_epoch'] = epoch
if self.writer is not None:
self.writer.add_scalar(
f'EVAL/best_{self.eval_class.main_indicator}',
best_metric[self.eval_class.main_indicator],
global_step,
)
if epoch > (epoch_num - epoch_num // 10 - 2):
save_ckpt(self.model,
self.cfg,
self.optimizer,
self.lr_scheduler,
epoch,
global_step,
best_metric,
is_best=True,
prefix='best_' + str(best_iter))
best_iter += 1
save_ckpt(self.model,
self.cfg,
self.optimizer,
self.lr_scheduler,
epoch,
global_step,
best_metric,
is_best=True,
prefix=None)
best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
self.logger.info(best_str)
if self.local_rank == 0:
save_ckpt(self.model,
self.cfg,
self.optimizer,
self.lr_scheduler,
epoch,
global_step,
best_metric,
is_best=False,
prefix=None)
if epoch > (epoch_num - epoch_num // 10 - 2):
save_ckpt(self.model,
self.cfg,
self.optimizer,
self.lr_scheduler,
epoch,
global_step,
best_metric,
is_best=False,
prefix='epoch_' + str(epoch))
if self.use_ema and epoch > (epoch_num - epoch_num // 10):
ema_cur_metric = self.eval_ema()
ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
self.logger.info(ema_cur_metric_str)
state = {
'epoch': epoch,
'global_step': global_step,
'state_dict': self.ema_model.state_dict(),
'optimizer': None,
'scheduler': None,
'config': self.cfg,
'metrics': ema_cur_metric,
}
save_path = os.path.join(
self.cfg['Global']['output_dir'],
'ema_' + str(ema_eval_iter) + '.pth')
torch.save(state, save_path)
self.logger.info(f'save ema ckpt to {save_path}')
ema_eval_iter += 1
if (ema_cur_metric[self.eval_class.main_indicator] >=
ema_best_metric[self.eval_class.main_indicator]):
ema_best_metric.update(ema_cur_metric)
ema_best_metric['best_epoch'] = epoch
best_ema_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
self.logger.info(best_ema_str)
best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
self.logger.info(best_str)
if self.writer is not None:
self.writer.close()
if self.cfg['Global']['distributed']:
torch.distributed.destroy_process_group()
def eval(self):
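"""Evaluate the model on the validation set and return the metrics (plus fps)."""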
self.model.eval()
with torch.no_grad():
total_frame = 0.0
total_time = 0.0
pbar = tqdm(
total=len(self.valid_dataloader),
desc='eval model:',
position=0,
leave=True,
)
sum_images = 0
for idx, batch in enumerate(self.valid_dataloader):
batch = [t.to(self.device) for t in batch]
start = time.time()
if self.scaler:
with torch.cuda.amp.autocast():
preds = self.model(batch[0], data=batch[1:])
else:
preds = self.model(batch[0], data=batch[1:])
total_time += time.time() - start
# Obtain usable results from post-processing methods
# Evaluate the results of the current batch
post_result = self.post_process_class(preds, batch)
self.eval_class(post_result, batch)
pbar.update(1)
total_frame += len(batch[0])
sum_images += 1
# Get the final metric, e.g. acc or hmean.
metric = self.eval_class.get_metric()
pbar.close()
self.model.train()
metric['fps'] = total_frame / total_time
return metric
def eval_ema(self):
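"""Evaluate the EMA model on the validation set and return the metrics (plus fps)."""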
# self.ema_model is kept in eval mode, so no train/eval switch is needed here.
with torch.no_grad():
total_frame = 0.0
total_time = 0.0
pbar = tqdm(
total=len(self.valid_dataloader),
desc='eval ema_model:',
position=0,
leave=True,
)
sum_images = 0
for idx, batch in enumerate(self.valid_dataloader):
batch = [t.to(self.device) for t in batch]
start = time.time()
if self.scaler:
with torch.cuda.amp.autocast():
preds = self.ema_model(batch[0], data=batch[1:])
else:
preds = self.ema_model(batch[0], data=batch[1:])
total_time += time.time() - start
# Obtain usable results from post-processing methods
# Evaluate the results of the current batch
post_result = self.post_process_class(preds, batch)
self.eval_class(post_result, batch)
pbar.update(1)
total_frame += len(batch[0])
sum_images += 1
# Get the final metric, e.g. acc or hmean.
metric = self.eval_class.get_metric()
pbar.close()
metric['fps'] = total_frame / total_time
return metric
def test_dataloader(self):
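"""Iterate over the train dataloader once to benchmark reader throughput."""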
starttime = time.time()
count = 0
try:
for data in self.train_dataloader:
count += 1
if count % 1 == 0:  # log every batch; raise the modulus to thin the logs
batch_time = time.time() - starttime
starttime = time.time()
self.logger.info(
f'reader: {count}, {data[0].shape}, {batch_time}')
except Exception:
import traceback
self.logger.info(traceback.format_exc())
self.logger.info(f'finish reader: {count}, Success!')
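# Example usage (a minimal sketch): the `Config` import path and YAML file
# below are illustrative assumptions, not part of this module; any wrapper
# exposing the `cfg`, `print_cfg`, and `save` members used above will do.
#
#     from tools.engine.config import Config  # hypothetical import path
#
#     cfg = Config('configs/rec/svtrv2/config.yml')  # hypothetical config file
#     trainer = Trainer(cfg, mode='train_eval')
#     trainer.train()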