Spaces:
Running
Running
# -*- coding: UTF-8 -*- | |
'''================================================= | |
@Project -> File pram -> train | |
@IDE PyCharm | |
@Author fx221@cam.ac.uk | |
@Date 29/01/2024 14:26 | |
==================================================''' | |
import argparse | |
import os | |
import os.path as osp | |
import torch | |
import torchvision.transforms.transforms as tvt | |
import yaml | |
import torch.utils.data as Data | |
import torch.multiprocessing as mp | |
import torch.distributed as dist | |
from nets.segnet import SegNet | |
from nets.segnetvit import SegNetViT | |
from dataset.utils import collect_batch | |
from dataset.get_dataset import compose_datasets | |
from tools.common import torch_set_gpu | |
from trainer import Trainer | |
from nets.sfd2 import ResNet4x, DescriptorCompressor | |
from nets.superpoint import SuperPoint | |
# Gradients are explicitly switched on for the whole process.
torch.set_grad_enabled(True)

# Command-line interface: a mandatory YAML config and an optional landmark path.
parser = argparse.ArgumentParser(
    description='PRAM',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--config', required=True, type=str, help='config of specifications')
parser.add_argument('--landmark_path', default=None, type=str, help='path of landmarks')
def load_feat_network(config):
    """Build the local-feature network and the optional descriptor compressor.

    Args:
        config: mapping that provides ``'feature'`` (``'spp'`` or
            ``'resnet4x'``) and ``'feat_dim'``.

    Returns:
        A ``(net, desc_compressor)`` tuple. ``net`` is ``None`` when the
        feature name is not recognised (a message is printed).
        ``desc_compressor`` is ``None`` when ``feat_dim`` is 128 or when no
        compressor weights are listed for the requested dimension.
    """
    feature = config['feature']
    if feature == 'spp':
        net = SuperPoint(config={
            'weight_path': '/scratches/flyer_2/fx221/Research/Code/third_weights/superpoint_v1.pth',
        }).eval()
    elif feature == 'resnet4x':
        net = ResNet4x(inputdim=3, outdim=128)
        checkpoint = torch.load('weights/sfd2_20230511_210205_resnet4x.79.pth', map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'], strict=True)
        net.eval()
    else:
        print('Please input correct feature {:s}'.format(config['feature']))
        net = None

    feat_dim = config['feat_dim']
    if feat_dim == 128:
        # Descriptors are used as-is; nothing to compress.
        return net, None

    # The compressor is instantiated before the weight lookup, exactly as the
    # unsupported-dimension path has always done.
    desc_compressor = DescriptorCompressor(inputdim=128, outdim=feat_dim).eval()
    compressor_weights = {
        64: 'weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O64.pth',
        32: 'weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O32.pth',
    }
    weight_file = compressor_weights.get(feat_dim)
    if weight_file is None:
        # No pretrained compressor for this dimension.
        return net, None

    desc_compressor.load_state_dict(torch.load(weight_file, map_location='cpu'), strict=True)
    return net, desc_compressor
def get_model(config):
    """Instantiate the recognition network described by ``config``.

    The descriptor dimension is 256 when ``config['feature']`` is ``'spp'`` or
    ``config['use_mid_feature']`` is truthy, and 128 otherwise.

    Side effects:
        ``config['with_cls']`` is set to ``False`` after the model has been
        constructed (the model itself receives the original value).

    Args:
        config: mapping with the network hyper-parameters plus
            ``'network'``, ``'local_rank'``, ``'weight_path'``,
            ``'resume_path'``, ``'save_path'`` and ``'eval'``.

    Returns:
        The constructed model, with weights loaded when requested.

    Raises:
        ValueError: if ``config['network']`` is neither ``'segnet'`` nor
            ``'segnetvit'``.
    """
    desc_dim = 256 if config['feature'] == 'spp' else 128
    if config['use_mid_feature']:
        desc_dim = 256

    network_config = {
        'descriptor_dim': desc_dim,
        'n_layers': config['layers'],
        'ac_fn': config['ac_fn'],
        'norm_fn': config['norm_fn'],
        'n_class': config['n_class'],
        'output_dim': config['output_dim'],
        'with_cls': config['with_cls'],
        'with_sc': config['with_sc'],
        'with_score': config['with_score'],
    }

    if config['network'] == 'segnet':
        model = SegNet(network_config)
    elif config['network'] == 'segnetvit':
        model = SegNetViT(network_config)
    else:
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the intended message, so raise a proper exception instead.
        raise ValueError(f"ERROR! {config['network']} model does not exist")
    config['with_cls'] = False

    # Checkpoints are only loaded when local_rank is 0.
    if config['local_rank'] == 0:
        if config['weight_path'] is not None:
            weight_file = osp.join(config['save_path'], config['weight_path'])
            state_dict = torch.load(weight_file, map_location='cpu')['model']
            model.load_state_dict(state_dict, strict=True)
            print('Load weight from {:s}'.format(weight_file))

        # Resume weights are applied only for training and, when both paths
        # are set, overwrite what weight_path loaded.
        if config['resume_path'] is not None and not config['eval']:
            resume_file = osp.join(config['save_path'], config['resume_path'])
            state_dict = torch.load(resume_file, map_location='cpu')['model']
            model.load_state_dict(state_dict, strict=True)
            print('Load resume weight from {:s}'.format(resume_file))

    return model
def setup(rank, world_size):
    """Join the NCCL process group for this worker.

    The rendezvous address is fixed to ``localhost:12355`` through the
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables.

    Args:
        rank: rank of the calling process.
        world_size: total number of processes in the group.
    """
    os.environ.update({
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12355',
    })
    # Initialise the default process group.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms):
    """Per-process worker for distributed training (target of ``mp.spawn``).

    Moves the models to ``cuda:<rank>``, converts batch-norm layers to
    ``SyncBatchNorm``, joins the process group, wraps the model in
    ``DistributedDataParallel`` and runs the ``Trainer``.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: number of spawned processes.
        model: network to train.
        config: configuration mapping; ``'batch_size'`` and ``'workers'`` are
            divided by ``world_size`` and ``'local_rank'`` is set to ``rank``.
        train_set: training dataset, sharded with a ``DistributedSampler``.
        test_set: evaluation dataset; only handed to the trainer on rank 0.
        feat_model: optional feature network (may be ``None``).
        img_transforms: transforms forwarded to the ``Trainer``.
    """
    print('In train_DDP..., rank: ', rank)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    if feat_model is not None:
        feat_model.to(device)
    model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    setup(rank=rank, world_size=world_size)
    try:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_set,
            shuffle=True,
            rank=rank,
            num_replicas=world_size,
            drop_last=True,
        )

        num_workers = config['workers'] // world_size
        loader_kwargs = {}
        if num_workers > 0:
            # DataLoader rejects an explicit prefetch_factor when loading
            # happens in the main process (num_workers == 0).
            loader_kwargs['prefetch_factor'] = 4

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config['batch_size'] // world_size,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=False,  # must be False because a sampler is supplied
            drop_last=True,
            collate_fn=collect_batch,
            sampler=train_sampler,
            **loader_kwargs,
        )

        config['local_rank'] = rank

        # Only rank 0 receives the evaluation set.
        eval_set = test_set if rank == 0 else None

        trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=eval_set,
                          config=config, img_transforms=img_transforms)
        trainer.train()
    finally:
        # Release the process group even if training fails part-way.
        if dist.is_initialized():
            dist.destroy_process_group()
if __name__ == '__main__':
    # Script entry point: read the YAML config, build the feature network and
    # datasets, then either run localization/evaluation or start training
    # (single process or one spawned process per configured GPU).
    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Only use trusted config files, or switch to yaml.safe_load if the
        # configs contain plain data only.
        config = yaml.load(f, Loader=yaml.Loader)
    torch_set_gpu(gpus=config['gpu'])
    if config['local_rank'] == 0:
        print(config)

    # No image normalisation is applied when the feature is 'spp'.
    if config['feature'] == 'spp':
        img_transforms = None
    else:
        img_transforms = tvt.Compose([
            tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # desc_compressor is returned by the loader but not used further below.
    feat_model, desc_compressor = load_feat_network(config=config)

    dataset = config['dataset']
    if config['eval'] or config['loc']:
        if not config['online']:
            from localization.loc_by_rec_eval import loc_by_rec_eval

            test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1)
            config['n_class'] = test_set.n_class

            model = get_model(config=config)
            loc_by_rec_eval(rec_model=model.cuda().eval(),
                            loader=test_set,
                            local_feat=feat_model.cuda().eval(),
                            config=config, img_transforms=img_transforms)
        else:
            from localization.loc_by_rec_online import loc_by_rec_online

            model = get_model(config=config)
            loc_by_rec_online(rec_model=model.cuda().eval(),
                              local_feat=feat_model.cuda().eval(),
                              config=config, img_transforms=img_transforms)
        # exit() is injected by the site module for interactive use; raise
        # SystemExit directly so the script also stops without it.
        raise SystemExit(0)

    train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None)
    if config['do_eval']:
        test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None)
    else:
        test_set = None
    config['n_class'] = train_set.n_class
    model = get_model(config=config)

    if not config['with_dist'] or len(config['gpu']) == 1:
        # Single-process training.
        config['with_dist'] = False
        model = model.cuda()
        train_loader = Data.DataLoader(dataset=train_set,
                                       shuffle=True,
                                       batch_size=config['batch_size'],
                                       drop_last=True,
                                       collate_fn=collect_batch,
                                       num_workers=config['workers'])
        if test_set is not None:
            test_loader = Data.DataLoader(dataset=test_set,
                                          shuffle=False,
                                          batch_size=1,
                                          drop_last=False,
                                          collate_fn=collect_batch,
                                          num_workers=4)
        else:
            test_loader = None
        trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader,
                          config=config, img_transforms=img_transforms)
        trainer.train()
    else:
        # One worker process per configured GPU; each runs train_DDP.
        mp.spawn(train_DDP, nprocs=len(config['gpu']),
                 args=(len(config['gpu']), model, config, train_set, test_set, feat_model, img_transforms),
                 join=True)