|
""" |
|
@Date: 2021/07/17 |
|
@description: |
|
""" |
|
import os |
|
import torch |
|
import torch.nn as nn |
|
import datetime |
|
|
|
|
|
class BaseModule(nn.Module):
    """Base ``nn.Module`` that adds checkpoint saving/loading and metric tracking.

    Checkpoints are ``.pkl`` files inside ``ckpt_dir``. Files written by
    :meth:`save` are named ``model_<timestamp>_last_<accuracy>_<epoch>.pkl``
    or ``model_<timestamp>_best_<accuracy>_<epoch>.pkl``; :meth:`load` tells
    them apart by the ``_last_`` / ``_best_`` substring.

    Attributes set by ``__init__``:
        ckpt_dir: checkpoint directory passed by the caller (may be ``None``).
        model_lst: sorted ``.pkl`` file names found in ``ckpt_dir`` at
            construction time (empty when there is no directory or no file).
        last_model_path / best_model_path: paths of the most recently
            loaded or saved "last" / "best" checkpoint, ``None`` until known.
        best_accuracy: highest accuracy seen so far (starts at ``-inf``).
        acc_d: ``{metric: {'acc': value, 'epoch': epoch}}`` with the best
            value recorded per metric.
    """

    def __init__(self, ckpt_dir=None):
        """Create the checkpoint directory if needed and index existing files.

        :param ckpt_dir: directory holding the checkpoints; when falsy no
            directory is touched and ``model_lst`` stays empty.
        """
        super().__init__()

        self.ckpt_dir = ckpt_dir
        # Always defined, so load() works even without a checkpoint directory
        # or with a directory that has only just been created.
        self.model_lst = []

        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
            self.model_lst = sorted(
                name for name in os.listdir(ckpt_dir) if name.endswith('.pkl')
            )

        self.last_model_path = None
        self.best_model_path = None
        self.best_accuracy = -float('inf')
        self.acc_d = {}

    def show_parameter_number(self, logger):
        """Log the total and trainable parameter counts of this module.

        :param logger: object exposing ``info``.
        """
        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        logger.info('{} parameter total:{:,}, trainable:{:,}'.format(self._get_name(), total, trainable))

    def _restore_metrics(self, checkpoint):
        """Copy ``accuracy`` and ``acc_d`` from a loaded checkpoint dict onto ``self``."""
        self.best_accuracy = checkpoint['accuracy']
        epoch = checkpoint['epoch']
        # Checkpoints may store ``acc_d`` as None (saved without metrics) or
        # as plain floats; floats are wrapped in the {'acc', 'epoch'} form
        # using the epoch of the checkpoint they were read from.
        stored = checkpoint.get('acc_d') or {}
        self.acc_d = {
            key: {'acc': value, 'epoch': epoch} if isinstance(value, float) else value
            for key, value in stored.items()
        }

    def load(self, device, logger, optimizer=None, best=False):
        """Restore weights (and optionally optimizer state) from ``ckpt_dir``.

        Behaviour depends on what ``model_lst`` contains:

        * nothing: logs a notice and keeps the initial weights;
        * files but none named ``_last_`` / ``_best_``: the first file is
          treated as a bare ``state_dict`` and loaded non-strictly;
        * otherwise the newest (by sorted file name) ``_last_`` checkpoint is
          loaded, or the newest ``_best_`` one when ``best=True``. Metric
          bookkeeping is taken from the best checkpoint whenever one exists.

        :param device: device passed to ``torch.device`` for ``map_location``;
            optimizer state tensors are moved to it as well.
        :param logger: object exposing ``info`` / ``warning`` / ``error``.
        :param optimizer: optional optimizer whose state is restored; only
            used when ``best`` is False.
        :param best: load the best checkpoint instead of the last one.
        :return: epoch to resume from (``checkpoint['epoch'] + 1``), ``0``
            when no resumable checkpoint was found, or ``None`` when only a
            best checkpoint exists and ``best`` is False.
        """
        # NOTE(security): torch.load unpickles the file; only point ckpt_dir
        # at checkpoints from a trusted source.
        if not self.model_lst:
            logger.info('*'*50)
            logger.info("Empty model folder! Using initial weights")
            logger.info('*'*50)
            return 0

        map_location = torch.device(device)
        last_model_lst = [name for name in self.model_lst if '_last_' in name]
        best_model_lst = [name for name in self.model_lst if '_best_' in name]

        if not last_model_lst and not best_model_lst:
            # Foreign file: assumed to be a bare state_dict, not a checkpoint dict.
            logger.info('*'*50)
            ckpt_path = os.path.join(self.ckpt_dir, self.model_lst[0])
            logger.info(f"Load: {ckpt_path}")
            state_dict = torch.load(ckpt_path, map_location=map_location)
            self.load_state_dict(state_dict, strict=False)
            logger.info('*'*50)
            return 0

        checkpoint = None
        loaded_from_best = False

        if last_model_lst:
            self.last_model_path = os.path.join(self.ckpt_dir, last_model_lst[-1])
            checkpoint = torch.load(self.last_model_path, map_location=map_location)
            self._restore_metrics(checkpoint)

        if best_model_lst:
            self.best_model_path = os.path.join(self.ckpt_dir, best_model_lst[-1])
            best_checkpoint = torch.load(self.best_model_path, map_location=map_location)
            # Metrics of the best checkpoint take precedence over the last one.
            self._restore_metrics(best_checkpoint)
            if best:
                checkpoint = best_checkpoint
                loaded_from_best = True
        elif best:
            logger.warning("No best checkpoint found, falling back to the last checkpoint")

        if checkpoint is None:
            # Only a best checkpoint exists but the caller asked for the last one.
            logger.error("Invalid checkpoint")
            return None

        self.load_state_dict(checkpoint['net'], strict=False)
        if optimizer is not None and not best:
            logger.info('Load optimizer')
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Move restored optimizer state tensors to the requested device.
            for state in optimizer.state.values():
                for key, value in state.items():
                    if torch.is_tensor(value):
                        state[key] = value.to(device)

        next_epoch = checkpoint['epoch'] + 1
        logger.info('*'*50)
        if loaded_from_best:
            logger.info(f"Load best: {self.best_model_path}")
        else:
            logger.info(f"Load last: {self.last_model_path}")

        logger.info(f"Best accuracy: {self.best_accuracy}")
        logger.info(f"Last epoch: {next_epoch}")
        logger.info('*'*50)
        return next_epoch

    def update_acc(self, acc_d, epoch, logger):
        """Merge this epoch's metrics into ``self.acc_d``, keeping the larger value.

        :param acc_d: ``{metric: value}`` for the current epoch.
        :param epoch: current epoch, stored next to every improved metric.
        :param logger: object exposing ``info``.
        """
        # NOTE(review): every metric is compared with ">", i.e. treated as
        # higher-is-better - confirm this is intended for error-style metrics.
        logger.info("-" * 100)
        for key, value in acc_d.items():
            if key not in self.acc_d or value > self.acc_d[key]['acc']:
                self.acc_d[key] = {'acc': value, 'epoch': epoch}
            record = self.acc_d[key]
            # Logged as "<best value>(<epoch of best>-<current epoch>)".
            logger.info(f"Update ACC: {key} {record['acc']:.4f}({record['epoch']}-{epoch})")
        logger.info("-" * 100)

    def _write_checkpoint(self, checkpoint, file_name, previous_path, replace):
        """Save ``checkpoint`` as ``file_name`` in ``ckpt_dir`` and return its path.

        The previous file is removed only after the new one has been written,
        so a failed save never leaves the directory without a checkpoint.
        """
        new_path = os.path.join(self.ckpt_dir, file_name)
        torch.save(checkpoint, new_path)
        if replace and previous_path and previous_path != new_path and os.path.exists(previous_path):
            os.remove(previous_path)
        return new_path

    def save(self, optim, epoch, accuracy, logger, replace=True, acc_d=None, config=None):
        """Write the "last" checkpoint and, on a new best accuracy, the "best" one.

        :param optim: optimizer whose ``state_dict()`` is stored.
        :param epoch: current epoch (stored and used in the file name).
        :param accuracy: accuracy of this epoch (stored, used in the file
            name and compared against ``self.best_accuracy``).
        :param logger: object exposing ``info``.
        :param replace: when True, the previously tracked last/best file is
            deleted once the new one has been written.
        :param acc_d: other evaluation metrics for this epoch, e.g.
            visible_2/3d, full_2/3d, rmse...; merged via :meth:`update_acc`.
        :param config: optional config; only ``config.TRAIN.SAVE_FREQ`` is
            read (the "last" checkpoint is written when
            ``epoch % SAVE_FREQ == 0``). Without a config every epoch is saved.
        :return: None
        """
        if acc_d:
            self.update_acc(acc_d, epoch, logger)

        timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        suffix = f"{accuracy:.4f}_{epoch}.pkl"
        last_name = f"model_{timestamp}_last_{suffix}"
        best_name = f"model_{timestamp}_best_{suffix}"

        # NOTE(review): the raw per-epoch ``acc_d`` is stored, not the
        # accumulated ``self.acc_d`` - confirm that this is intended.
        checkpoint = {
            'net': self.state_dict(),
            'optimizer': optim.state_dict(),
            'epoch': epoch,
            'accuracy': accuracy,
            'acc_d': acc_d
        }

        # NOTE: config.MODEL.SAVE_LAST / SAVE_BEST were short-circuited with
        # "True or ..." in the original code, so saving is always enabled and
        # those flags are not consulted here.
        save_freq = config.TRAIN.SAVE_FREQ if config is not None else 1
        if epoch % save_freq == 0:
            self.last_model_path = self._write_checkpoint(
                checkpoint, last_name, self.last_model_path, replace
            )
            logger.info(f"Saved last model: {self.last_model_path}")

        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_model_path = self._write_checkpoint(
                checkpoint, best_name, self.best_model_path, replace
            )
            logger.info("#" * 100)
            logger.info(f"Saved best model: {self.best_model_path}")
            logger.info("#" * 100)