|
import torch.utils.data |
|
from dataset import Offline_Dataset |
|
import yaml |
|
from sgmnet.match_model import matcher as SGM_Model |
|
from superglue.match_model import matcher as SG_Model |
|
import torch.distributed as dist |
|
import torch |
|
import os |
|
from collections import namedtuple |
|
from train import train |
|
from config import get_config, print_usage |
|
|
|
|
|
def main(config, model_config):
    """Set up distributed training for the selected matcher and run it.

    Builds the matcher named by ``config.model_name``, moves it to this
    process's CUDA device, wraps it in ``DistributedDataParallel``, creates
    the train/valid data loaders (each driven by a ``DistributedSampler``)
    and hands everything to ``train``.

    Args:
        config: Run configuration. This function reads ``model_name``,
            ``local_rank`` and ``train_batch_size`` from it and forwards the
            whole object to ``Offline_Dataset`` and ``train``.
        model_config: Model settings passed to the matcher constructor and
            forwarded to ``train``.

    Raises:
        NotImplementedError: If ``config.model_name`` is neither ``"SGM"``
            nor ``"SG"``.
    """
    # Select the network implementation.
    if config.model_name == "SGM":
        model = SGM_Model(model_config)
    elif config.model_name == "SG":
        model = SG_Model(model_config)
    else:
        raise NotImplementedError(
            f"Unsupported model_name {config.model_name!r}; "
            "expected 'SGM' or 'SG'."
        )

    # Bind this process to its GPU, then join the process group. The
    # "env://" init method takes the rendezvous settings from environment
    # variables, which the launcher is expected to have set.
    torch.cuda.set_device(config.local_rank)
    device = torch.device(f"cuda:{config.local_rank}")
    model.to(device)
    dist.init_process_group(backend="nccl", init_method="env://")
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.local_rank]
    )

    # Only the local-rank-0 process prints, to avoid duplicated output.
    if config.local_rank == 0:
        os.system("nvidia-smi")

    # The process group exists from here on, so the world size can be
    # queried once and reused for every per-process split below.
    world_size = dist.get_world_size()
    total_workers = 8
    workers_per_process = total_workers // world_size

    # Training data: config.train_batch_size is split across the processes.
    train_dataset = Offline_Dataset(config, "train")
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, shuffle=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size // world_size,
        num_workers=workers_per_process,
        pin_memory=False,
        sampler=train_sampler,
        collate_fn=train_dataset.collate_fn,
    )

    # Validation data.
    # NOTE(review): unlike the training loader, the batch size here is NOT
    # divided by the world size - confirm this is intended.
    valid_dataset = Offline_Dataset(config, "valid")
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset, shuffle=False
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.train_batch_size,
        num_workers=workers_per_process,
        pin_memory=False,
        collate_fn=valid_dataset.collate_fn,
        sampler=valid_sampler,
    )

    if config.local_rank == 0:
        print("start training .....")
    train(model, train_loader, valid_loader, config, model_config)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, load the model settings
    # from the YAML file at ``config.config_path`` and start training.
    config, unparsed = get_config()

    # Reject unrecognized command-line arguments before touching any files.
    if unparsed:
        print_usage()
        # ``exit`` is injected by the ``site`` module and is not guaranteed
        # to exist; raising SystemExit directly always works.
        raise SystemExit(1)

    with open(config.config_path, "r") as f:
        # ``yaml.load`` without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError on PyYAML 6; ``safe_load`` works on
        # all versions and refuses to construct arbitrary Python objects.
        # NOTE(review): assumes the model config uses only plain YAML types
        # (mappings, lists, scalars) - confirm no custom tags are needed.
        raw_model_config = yaml.safe_load(f)

    # Expose the top-level YAML keys as attributes of an immutable record.
    model_config = namedtuple("model_config", raw_model_config.keys())(
        *raw_model_config.values()
    )

    main(config, model_config)
|
|