Vincentqyw
update: features and matchers
a80d6bb
raw
history blame
2.39 kB
import torch.utils.data
from dataset import Offline_Dataset
import yaml
from sgmnet.match_model import matcher as SGM_Model
from superglue.match_model import matcher as SG_Model
import torch.distributed as dist
import torch
import os
from collections import namedtuple
from train import train
from config import get_config, print_usage
def main(config,model_config):
    """Build the requested matcher, set up distributed training and run it.

    Args:
        config: Run configuration. This function reads ``model_name``,
            ``local_rank`` and ``train_batch_size`` from it and forwards the
            whole object to ``Offline_Dataset`` and ``train``.
        model_config: Model hyper-parameters, passed to the matcher
            constructor and to ``train``.

    Raises:
        NotImplementedError: If ``config.model_name`` is neither ``'SGM'``
            nor ``'SG'``.
    """
    # Select the matcher class by name; any other name is unsupported.
    matcher_classes = {'SGM': SGM_Model, 'SG': SG_Model}
    matcher_cls = matcher_classes.get(config.model_name)
    if matcher_cls is None:
        raise NotImplementedError
    model = matcher_cls(model_config)

    # Bind this process to its GPU, then join the process group and wrap the
    # model for distributed data-parallel training.
    rank = config.local_rank
    torch.cuda.set_device(rank)
    model.to(torch.device(f'cuda:{rank}'))
    dist.init_process_group(backend='nccl', init_method='env://')
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    is_main_process = rank == 0
    if is_main_process:
        os.system('nvidia-smi')

    # The global train batch size and the 8 loader workers are split evenly
    # (integer division) across the participating processes.
    world_size = dist.get_world_size()
    workers_per_process = 8 // world_size

    train_dataset = Offline_Dataset(config, 'train')
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size // world_size,
        num_workers=workers_per_process,
        pin_memory=False,
        sampler=torch.utils.data.distributed.DistributedSampler(
            train_dataset, shuffle=True
        ),
        collate_fn=train_dataset.collate_fn,
    )

    valid_dataset = Offline_Dataset(config, 'valid')
    # NOTE(review): validation uses the full train_batch_size per process
    # (not divided by the world size) -- confirm this is intended.
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.train_batch_size,
        num_workers=workers_per_process,
        pin_memory=False,
        collate_fn=valid_dataset.collate_fn,
        sampler=torch.utils.data.distributed.DistributedSampler(
            valid_dataset, shuffle=False
        ),
    )

    if is_main_process:
        print('start training .....')
    train(model, train_loader, valid_loader, config, model_config)
if __name__ == "__main__":
    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()

    # If we have unparsed arguments, print usage and exit. This is checked
    # before touching the config file so a bad command line fails fast.
    if unparsed:
        print_usage()
        # SystemExit instead of the `exit` builtin, which is injected by the
        # `site` module and is not guaranteed to be available.
        raise SystemExit(1)

    with open(config.config_path, 'r') as f:
        # yaml.load() without a Loader is rejected by PyYAML >= 6 and is
        # unsafe on older versions; safe_load parses plain YAML only.
        model_config = yaml.safe_load(f)

    # Expose the YAML mapping's top-level keys as attributes.
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())

    main(config, model_config)