Spaces:
Running
Running
import torch.utils.data | |
from dataset import Offline_Dataset | |
import yaml | |
from sgmnet.match_model import matcher as SGM_Model | |
from superglue.match_model import matcher as SG_Model | |
import torch.distributed as dist | |
import torch | |
import os | |
from collections import namedtuple | |
from train import train | |
from config import get_config, print_usage | |
def _build_model(config, model_config):
    """Return the matcher selected by ``config.model_name``.

    Raises:
        NotImplementedError: if ``config.model_name`` is neither 'SGM' nor 'SG'.
    """
    if config.model_name == 'SGM':
        return SGM_Model(model_config)
    if config.model_name == 'SG':
        return SG_Model(model_config)
    raise NotImplementedError(
        f"unsupported model_name {config.model_name!r}; expected 'SGM' or 'SG'"
    )


def _make_loader(dataset, batch_size, num_workers, shuffle):
    """Wrap ``dataset`` in a DataLoader driven by a DistributedSampler for this rank."""
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )


def main(config, model_config):
    """Set up one distributed-training process and run ``train``.

    Pins this process to ``cuda:{config.local_rank}``, joins an NCCL process
    group (rendezvous settings come from environment variables because of
    ``init_method='env://'``), wraps the model in DistributedDataParallel and
    builds DistributedSampler-backed train/valid loaders before calling
    ``train``.

    Args:
        config: parsed command-line configuration. This function reads
            ``model_name``, ``local_rank`` and ``train_batch_size``, and also
            passes the object to ``Offline_Dataset`` and ``train``.
        model_config: model hyper-parameters, forwarded to the matcher
            constructor and to ``train``.

    Raises:
        NotImplementedError: for an unknown ``config.model_name``.
        ValueError: if ``train_batch_size`` is smaller than the world size,
            which would give each rank a training batch size of 0.
    """
    # Initialize network.
    model = _build_model(config, model_config)

    # Initialize DDP: pin this process to its GPU, then join the NCCL group.
    torch.cuda.set_device(config.local_rank)
    device = torch.device(f'cuda:{config.local_rank}')
    model.to(device)
    dist.init_process_group(backend='nccl', init_method='env://')
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.local_rank])
    if config.local_rank == 0:
        os.system('nvidia-smi')  # show GPU status once, from local rank 0 only

    world_size = dist.get_world_size()
    # The configured training batch size is split evenly across ranks.
    per_rank_batch = config.train_batch_size // world_size
    if per_rank_batch < 1:
        raise ValueError(
            f"train_batch_size ({config.train_batch_size}) must be at least "
            f"the world size ({world_size})"
        )
    # A total of 8 loader workers is divided among the ranks; 0 means the
    # data is loaded in the main process.
    num_workers = 8 // world_size

    # Initialize datasets.
    train_dataset = Offline_Dataset(config, 'train')
    train_loader = _make_loader(train_dataset, per_rank_batch, num_workers, shuffle=True)
    valid_dataset = Offline_Dataset(config, 'valid')
    # NOTE(review): unlike the training loader, the validation batch size is
    # NOT divided by the world size -- confirm this is intended.
    valid_loader = _make_loader(
        valid_dataset, config.train_batch_size, num_workers, shuffle=False
    )

    if config.local_rank == 0:
        print('start training .....')
    train(model, train_loader, valid_loader, config, model_config)
if __name__ == "__main__":
    # ----------------------------------------
    # Parse configuration. Reject unknown arguments before doing any work.
    config, unparsed = get_config()
    if len(unparsed) > 0:
        print_usage()
        raise SystemExit(1)

    # Load model hyper-parameters from YAML. safe_load is used because
    # yaml.load without an explicit Loader is an error in PyYAML >= 6 and can
    # construct arbitrary Python objects on older versions.
    with open(config.config_path, 'r') as f:
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        raise ValueError(
            f"{config.config_path} must contain a YAML mapping, "
            f"got {type(model_config).__name__}"
        )
    # Expose the mapping as an immutable namedtuple with attribute access.
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())
    main(config, model_config)