# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> train
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 03/04/2024 16:33
=================================================='''
import argparse
import os
import os.path as osp
import torch
import torchvision.transforms.transforms as tvt
import yaml
import torch.utils.data as Data
import torch.multiprocessing as mp
import torch.distributed as dist

from nets.sfd2 import load_sfd2
from nets.segnet import SegNet
from nets.segnetvit import SegNetViT
from nets.load_segnet import load_segnet
from dataset.utils import collect_batch
from dataset.get_dataset import compose_datasets
from tools.common import torch_set_gpu
from trainer import Trainer


def get_model(config):
    """Build the segmentation network selected by ``config['network']``.

    The descriptor dimension is 256 when ``config['feature'] == 'spp'`` or
    when ``config['use_mid_feature']`` is truthy, and 128 otherwise.

    Args:
        config: dict read from the YAML configuration. The keys 'feature',
            'use_mid_feature', 'layers', 'ac_fn', 'norm_fn', 'n_class',
            'output_dim', 'with_score' and 'network' are read here.

    Returns:
        A ``SegNet`` (network == 'segnet') or ``SegNetViT``
        (network == 'segnetvit') instance.

    Raises:
        ValueError: if ``config['network']`` names neither supported network.

    Side effects:
        Sets ``config['with_cls'] = False`` for both supported networks.
    """
    desc_dim = 256 if config['feature'] == 'spp' else 128
    if config['use_mid_feature']:
        desc_dim = 256
    model_config = {
        'network': {
            'descriptor_dim': desc_dim,
            'n_layers': config['layers'],
            'ac_fn': config['ac_fn'],
            'norm_fn': config['norm_fn'],
            'n_class': config['n_class'],
            'output_dim': config['output_dim'],
            # 'with_cls': config['with_cls'],
            # 'with_sc': config['with_sc'],
            'with_score': config['with_score'],
        }
    }

    if config['network'] == 'segnet':
        model = SegNet(model_config.get('network', {}))
        config['with_cls'] = False
    elif config['network'] == 'segnetvit':
        model = SegNetViT(model_config.get('network', {}))
        config['with_cls'] = False
    else:
        # Raising a bare string is itself a TypeError in Python 3 and would
        # hide this message; raise a real exception instead.
        raise ValueError('ERROR! {} model does not exist'.format(config['network']))

    return model


parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--config', type=str, required=True, help='config of specifications')
# parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks')
parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth')


def setup(rank, world_size):
    """Initialise the default NCCL process group for this process.

    The rendezvous address is fixed to localhost:12355, so all ranks are
    expected to run on the same machine.

    Args:
        rank: index of the calling process within the group.
        world_size: total number of processes in the group.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms):
    """Per-process entry point for distributed training (target of ``mp.spawn``).

    Moves the models to GPU ``rank``, converts batch-norm layers to
    ``SyncBatchNorm``, wraps the model in ``DistributedDataParallel``, builds a
    sharded training loader and runs ``Trainer.train()``.

    Args:
        rank: process index; also used as the CUDA device index.
        world_size: number of participating processes.
        model: network to train.
        config: configuration dict; 'batch_size' and 'workers' are divided by
            ``world_size`` for the per-process loader, and 'local_rank' is
            written with ``rank``.
        train_set: training dataset.
        test_set: evaluation dataset or None; only rank 0 forwards it to the
            trainer. Note that the dataset itself (not a DataLoader) is passed
            as ``eval_loader`` here.
        feat_model: feature extractor or None.
        img_transforms: image transform passed through to the trainer.
    """
    print('In train_DDP..., rank: ', rank)
    torch.cuda.set_device(rank)

    device = torch.device(f'cuda:{rank}')
    if feat_model is not None:
        feat_model.to(device)
    model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    setup(rank=rank, world_size=world_size)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set,
        shuffle=True,
        rank=rank,
        num_replicas=world_size,
        drop_last=True,  # important?
    )

    num_workers = config['workers'] // world_size
    # DataLoader raises ValueError if prefetch_factor is given while
    # num_workers == 0, which happens when config['workers'] < world_size.
    prefetch_kwargs = {'prefetch_factor': 4} if num_workers > 0 else {}
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config['batch_size'] // world_size,
        num_workers=num_workers,
        # num_workers=1,
        pin_memory=True,
        # persistent_workers=True,
        shuffle=False,  # must be False because a sampler is supplied
        drop_last=True,
        collate_fn=collect_batch,
        sampler=train_sampler,
        **prefetch_kwargs,
    )
    config['local_rank'] = rank

    # Only rank 0 receives the evaluation data.
    eval_set = test_set if rank == 0 else None

    trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=eval_set,
                      config=config, img_transforms=img_transforms)
    trainer.train()


def main():
    """Parse the command line, build datasets and model, and start training.

    Runs single-process training when ``config['with_dist']`` is falsy or only
    one GPU is listed in ``config['gpu']``; otherwise spawns one ``train_DDP``
    process per listed GPU.
    """
    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Only use trusted config files, or switch to yaml.safe_load if the
        # configs contain plain data only.
        config = yaml.load(f, Loader=yaml.Loader)
    torch_set_gpu(gpus=config['gpu'])
    # Nothing sets 'local_rank' before this point; treat the launching process
    # as rank 0 when the config file does not provide the key.
    config.setdefault('local_rank', 0)
    if config['local_rank'] == 0:
        print(config)

    img_transforms = []
    img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    img_transforms = tvt.Compose(img_transforms)

    feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval()
    print('Load SFD2 weight from {:s}'.format(args.feat_weight_path))

    dataset = config['dataset']
    train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None)
    if config['do_eval']:
        test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None)
    else:
        test_set = None
    config['n_class'] = train_set.n_class
    # model = get_model(config=config)
    model = load_segnet(network=config['network'],
                        n_class=config['n_class'],
                        desc_dim=256 if config['use_mid_feature'] else 128,
                        n_layers=config['layers'],
                        output_dim=config['output_dim'])
    if config['local_rank'] == 0:
        if config['resume_path'] is not None:  # only for training
            resume_file = osp.join(config['save_path'], config['resume_path'])
            model.load_state_dict(
                torch.load(resume_file, map_location='cpu')['model'],
                strict=True)
            print('Load resume weight from {:s}'.format(resume_file))

    if not config['with_dist'] or len(config['gpu']) == 1:
        # Single-process training on the current CUDA device.
        config['with_dist'] = False
        model = model.cuda()
        train_loader = Data.DataLoader(dataset=train_set,
                                       shuffle=True,
                                       batch_size=config['batch_size'],
                                       drop_last=True,
                                       collate_fn=collect_batch,
                                       num_workers=config['workers'])
        if test_set is not None:
            test_loader = Data.DataLoader(dataset=test_set,
                                          shuffle=False,
                                          batch_size=1,
                                          drop_last=False,
                                          collate_fn=collect_batch,
                                          num_workers=4)
        else:
            test_loader = None
        trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader,
                          config=config, img_transforms=img_transforms)
        trainer.train()
    else:
        # One process per configured GPU; mp.spawn passes the process index
        # as the first argument (rank) of train_DDP.
        n_gpus = len(config['gpu'])
        mp.spawn(train_DDP, nprocs=n_gpus,
                 args=(n_gpus, model, config, train_set, test_set, feat_model, img_transforms),
                 join=True)


if __name__ == '__main__':
    main()