import os
import time
import datetime
import glob
import shutil

import yaml
import numpy as np
import torch
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter

from utils.net_util import to_cuda
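
# NOTE (assumption, not from the original file): to_cuda is expected to move a
# batch produced by the dataloader (tensors, possibly nested in dicts/lists)
# onto the GPU; it is applied to every batch inside BaseTrainer.train() below.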


def worker_init_fn(worker_id):  # set numpy's random seed per dataloader worker
    seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(seed + worker_id)
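

# --- Illustrative only (not part of the original codebase) ---
# BaseTrainer.update_lr() below assumes each entry of lr_schedule_dict exposes
# a get_learning_rate(iter_idx) method; this minimal exponential-decay schedule
# is one possible implementation of that interface, with made-up defaults.
class ExpDecayLRSchedule:
    def __init__(self, init_lr, decay_rate=0.1, decay_steps=100000, min_lr=1e-6):
        self.init_lr = init_lr
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps
        self.min_lr = min_lr

    def get_learning_rate(self, iter_idx):
        # lr = init_lr * decay_rate ** (iter_idx / decay_steps), floored at min_lr
        lr = self.init_lr * self.decay_rate ** (iter_idx / self.decay_steps)
        return max(lr, self.min_lr)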


class BaseTrainer:
    """Skeleton trainer: subclasses plug in a dataset, networks, optimizers and
    learning-rate schedules, and implement forward_one_pass()."""

    def __init__(self, opt):
        self.opt = opt
        self.dataset = None
        self.network = None
        self.net_dict = {}          # name -> torch.nn.Module
        self.optm_dict = {}         # name -> torch.optim.Optimizer
        self.update_keys = None     # None means "update all entries"
        self.lr_schedule_dict = {}  # name -> schedule with get_learning_rate(iter_idx)
        self.iter_idx = 0
        self.epoch_idx = 0
        self.iter_num = 9999999999  # hard cap on total iterations
        self.loss_weight = self.opt['train']['loss_weight']

    def load_pretrained(self, path, dict_):
        data = torch.load(path)
        for k in dict_:
            if k in data:
                print('# Loading %s...' % k)
                dict_[k].load_state_dict(data[k])
            else:
                print('# %s not found!' % k)
        return data.get('epoch_idx', None)

    def load_ckpt(self, path, load_optm=True):
        # a checkpoint is a folder holding net.pt and (optionally) optm.pt
        epoch_idx = self.load_pretrained(path + '/net.pt', self.net_dict)
        if load_optm:
            if os.path.exists(path + '/optm.pt'):
                self.load_pretrained(path + '/optm.pt', self.optm_dict)
            else:
                print('# Optimizer not found!')
        return epoch_idx

    def save_trained(self, path, dict_):
        data = {}
        for k in dict_:
            data[k] = dict_[k].state_dict()
        data['epoch_idx'] = self.epoch_idx
        torch.save(data, path)

    def save_ckpt(self, path, save_optm=True):
        self.save_trained(path + '/net.pt', self.net_dict)
        if save_optm:
            self.save_trained(path + '/optm.pt', self.optm_dict)

    def zero_grad(self):
        update_keys = self.optm_dict.keys() if self.update_keys is None else self.update_keys
        for k in update_keys:
            self.optm_dict[k].zero_grad()

    def step(self):
        update_keys = self.optm_dict.keys() if self.update_keys is None else self.update_keys
        for k in update_keys:
            self.optm_dict[k].step()

    def update_lr(self, iter_idx):
        lr_dict = {}
        update_keys = self.optm_dict.keys() if self.update_keys is None else self.update_keys
        for k in update_keys:
            lr = self.lr_schedule_dict[k].get_learning_rate(iter_idx)
            for param_group in self.optm_dict[k].param_groups:
                param_group['lr'] = lr
            lr_dict[k] = lr
        return lr_dict

    def set_dataset(self, dataset):
        self.dataset = dataset

    def set_network(self, network):
        self.network = network

    def set_net_dict(self, net_dict):
        self.net_dict = net_dict

    def set_optm_dict(self, optm_dict):
        self.optm_dict = optm_dict

    def set_update_keys(self, update_keys):
        self.update_keys = update_keys

    def set_lr_schedule_dict(self, lr_schedule_dict):
        self.lr_schedule_dict = lr_schedule_dict

    def set_train(self, flag=True):
        if flag:
            for k, net in self.net_dict.items():
                # guard against update_keys being None (i.e. update everything)
                if self.update_keys is None or k in self.update_keys:
                    net.train()
                else:
                    net.eval()
        else:
            for net in self.net_dict.values():
                net.eval()

    def train(self):
        # logging: one timestamped run directory per training session
        os.makedirs(self.opt['train']['net_ckpt_dir'], exist_ok=True)
        log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)
        with open(log_dir + '/config_bk.yaml', 'w') as fp:
            yaml.dump(self.opt, fp, sort_keys=False)

        self.set_train()
        self.dataset.training = True
        batch_size = self.opt['train'].get('batch_size', 1)
        num_workers = self.opt['train'].get('num_workers', 0)
        dataloader = torch.utils.data.DataLoader(self.dataset,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=num_workers,
                                                 worker_init_fn=worker_init_fn,
                                                 drop_last=True)
        self.batch_num = len(self.dataset) // batch_size

        # optionally save an untrained checkpoint when starting from scratch
        if self.opt['train'].get('save_init_ckpt', False) and self.opt['train'].get('start_epoch', 0) == 0:
            init_folder = self.opt['train']['net_ckpt_dir'] + '/init_ckpt'
            if not os.path.exists(init_folder) or self.opt['train'].get('start_epoch', 0) == 0:
                os.makedirs(init_folder, exist_ok=True)
                self.save_ckpt(init_folder, save_optm=False)
            else:
                print('# Init checkpoint has been saved!')

        # resume: prefer an explicitly given checkpoint, else epoch_latest if present
        if self.opt['train'].get('prev_ckpt') is not None:
            loaded_epoch = self.load_ckpt(self.opt['train']['prev_ckpt'])
        else:
            prev_ckpt_path = self.opt['train']['net_ckpt_dir'] + '/epoch_latest'
            loaded_epoch = self.load_ckpt(prev_ckpt_path) if os.path.exists(prev_ckpt_path) else None
        if loaded_epoch is not None:
            start_epoch = loaded_epoch + 1
        else:
            start_epoch = self.opt['train'].get('start_epoch', 0)
        end_epoch = self.opt['train'].get('end_epoch', 999)

        forward_one_pass = self.forward_one_pass  # local alias to avoid repeated attribute lookups

        for epoch_idx in range(start_epoch, end_epoch):
            self.epoch_idx = epoch_idx
            self.update_config_before_epoch(epoch_idx)
            epoch_losses = dict()
            time0 = time.time()
            for batch_idx, items in enumerate(dataloader):
                iter_idx = batch_idx + self.batch_num * epoch_idx
                self.iter_idx = iter_idx
                lr_dict = self.update_lr(iter_idx)
                items = to_cuda(items)
                loss, batch_losses = forward_one_pass(items)
                # zero_grad()/loss.backward()/step() are not called here;
                # presumably subclasses run the backward pass and optimizer
                # update inside forward_one_pass()

                # record batch loss
                log_info = 'epoch %d, batch %d, ' % (epoch_idx, batch_idx)
                log_info += 'lr: '
                for k in lr_dict.keys():
                    log_info += '%s %e, ' % (k, lr_dict[k])
                for key in batch_losses.keys():
                    log_info += '%s: %f, ' % (key, batch_losses[key])
                    writer.add_scalar('%s/Batch' % key, batch_losses[key], iter_idx)
                    if key in epoch_losses:
                        epoch_losses[key] += batch_losses[key]
                    else:
                        epoch_losses[key] = batch_losses[key]
                print(log_info)
                with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp:
                    # record loss weights once, at the start of each epoch
                    if batch_idx == 0:
                        loss_weights_info = ''
                        for k in self.opt['train']['loss_weight'].keys():
                            loss_weights_info += '%s: %f, ' % (k, self.opt['train']['loss_weight'][k])
                        fp.write('# Loss weights: \n' + loss_weights_info + '\n')
                    fp.write(log_info + '\n')

                if iter_idx % self.opt['train']['ckpt_interval']['batch'] == 0 and iter_idx != 0:
                    # keep only the most recent batch checkpoint
                    for folder in glob.glob(self.opt['train']['net_ckpt_dir'] + '/batch_*'):
                        shutil.rmtree(folder)
                    model_folder = self.opt['train']['net_ckpt_dir'] + '/batch_%d' % iter_idx
                    os.makedirs(model_folder, exist_ok=True)
                    self.save_ckpt(model_folder, save_optm=False)

                if iter_idx % self.opt['train']['eval_interval'] == 0 and iter_idx != 0:
                    self.mini_test()
                    self.set_train()
                    time1 = time.time()
                    print('%f secs since the last evaluation' % (time1 - time0))
                    time0 = time1

                if iter_idx == self.iter_num:
                    return
""" EPOCH """ | |
# record epoch loss | |
for key in epoch_losses.keys(): | |
epoch_losses[key] /= self.batch_num | |
writer.add_scalar('%s/Epoch' % key, epoch_losses[key], epoch_idx) | |
if epoch_idx % self.opt['train']['ckpt_interval']['epoch'] == 0: | |
model_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_%d' % epoch_idx | |
os.makedirs(model_folder, exist_ok = True) | |
self.save_ckpt(model_folder) | |
if self.batch_num > 50: | |
latest_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_latest' | |
os.makedirs(latest_folder, exist_ok = True) | |
self.save_ckpt(latest_folder) | |
writer.close() | |

    def mini_test(self):
        """Evaluation hook called every eval_interval iterations during training."""
        pass

    def forward_one_pass(self, items):
        """Run one training step on a batch; must return (loss, batch_losses_dict)."""
        raise NotImplementedError('"forward_one_pass" method is not implemented!')

    def update_config_before_epoch(self, epoch_idx):
        """Hook for adjusting self.opt before each epoch."""
        pass
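

# --- Illustrative usage sketch (not part of the original codebase) ---
# Assumptions: `network`, `dataset` and the config path below are placeholders,
# and forward_one_pass() performs the backward pass and optimizer step itself,
# as the commented-out zero_grad/backward/step calls in train() suggest.
#
# class MyTrainer(BaseTrainer):
#     def forward_one_pass(self, items):
#         pred = self.net_dict['net'](items['input'])
#         loss = self.loss_weight['l2'] * torch.nn.functional.mse_loss(pred, items['target'])
#         self.zero_grad()
#         loss.backward()
#         self.step()
#         return loss, {'l2': loss.item()}
#
# with open('configs/example.yaml') as f:  # hypothetical config file
#     opt = yaml.safe_load(f)
# trainer = MyTrainer(opt)
# trainer.set_dataset(dataset)
# trainer.set_net_dict({'net': network})
# trainer.set_optm_dict({'net': torch.optim.Adam(network.parameters(), lr=1e-4)})
# trainer.set_lr_schedule_dict({'net': ExpDecayLRSchedule(init_lr=1e-4)})
# trainer.train()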