import torch.utils.data
from dataset import Offline_Dataset
import yaml
from sgmnet.match_model import matcher as SGM_Model
from superglue.match_model import matcher as SG_Model
import torch.distributed as dist
import torch
import os
from collections import namedtuple
from train import train
from config import get_config, print_usage


def _build_loader(config, split, batch_size, num_workers, shuffle):
    """Return a DataLoader over ``Offline_Dataset(config, split)`` sharded with a DistributedSampler."""
    dataset = Offline_Dataset(config, split)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )


def main(config, model_config):
    """Build the matcher, wrap it in DistributedDataParallel, and launch training.

    Intended to run once per process; the process group is initialised with
    ``init_method="env://"``, so the rendezvous environment variables must be
    provided by the launcher.

    Args:
        config: parsed command-line configuration. Reads ``model_name``,
            ``local_rank`` and ``train_batch_size`` here; it is also passed to
            ``Offline_Dataset`` and ``train``.
        model_config: model hyper-parameters, passed to the matcher constructor
            and to ``train``.

    Raises:
        NotImplementedError: if ``config.model_name`` is neither "SGM" nor "SG".
        ValueError: if ``config.train_batch_size`` is smaller than the world
            size, which would give every rank an empty training batch.
    """
    # Initialize network
    model_classes = {"SGM": SGM_Model, "SG": SG_Model}
    model_cls = model_classes.get(config.model_name)
    if model_cls is None:
        raise NotImplementedError(f"unsupported model_name: {config.model_name!r}")
    model = model_cls(model_config)

    # initialize ddp: pin this process to its GPU before creating the NCCL group
    torch.cuda.set_device(config.local_rank)
    device = torch.device(f"cuda:{config.local_rank}")
    model.to(device)
    dist.init_process_group(backend="nccl", init_method="env://")
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.local_rank]
    )

    if config.local_rank == 0:
        os.system("nvidia-smi")

    world_size = dist.get_world_size()
    # train_batch_size is split evenly across ranks for training.
    train_batch_per_rank = config.train_batch_size // world_size
    if train_batch_per_rank < 1:
        raise ValueError(
            f"train_batch_size ({config.train_batch_size}) must be at least the "
            f"world size ({world_size}) so every rank gets a non-empty batch"
        )
    # A budget of 8 loader workers divided by the world size (0 => load in-process).
    num_workers = 8 // world_size

    # initialize dataset
    train_loader = _build_loader(
        config, "train", train_batch_per_rank, num_workers, shuffle=True
    )
    # NOTE(review): validation uses the full train_batch_size per rank, unlike
    # training, which divides it by the world size — confirm this is intended.
    valid_loader = _build_loader(
        config, "valid", config.train_batch_size, num_workers, shuffle=False
    )

    if config.local_rank == 0:
        print("start training .....")
    train(model, train_loader, valid_loader, config, model_config)


if __name__ == "__main__":
    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()
    # If we have unparsed arguments, print usage and exit before touching
    # the model config file.
    if unparsed:
        print_usage()
        raise SystemExit(1)
    with open(config.config_path, "r", encoding="utf-8") as f:
        # safe_load: yaml.load(f) without an explicit Loader raises TypeError
        # on PyYAML >= 6 and could construct arbitrary objects on older ones.
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        raise SystemExit(
            f"model config {config.config_path!r} must contain a YAML mapping"
        )
    # Expose the YAML keys as read-only attributes (model_config.<key>);
    # every top-level key must therefore be a valid Python identifier.
    model_config = namedtuple("model_config", model_config.keys())(
        *model_config.values()
    )

    main(config, model_config)