import sys
import torch
from ..utils.base_model import BaseModel
from ..utils import do_system
from pathlib import Path
import subprocess
import logging

logger = logging.getLogger(__name__)

sys.path.append(str(Path(__file__).parent / "../../third_party"))
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
from ASpanFormer.src.config.default import get_cfg_defaults
from ASpanFormer.src.utils.misc import lower_config
from ASpanFormer.demo import demo_utils

# Root of the vendored ASpanFormer checkout, located relative to this file.
# The path is left unresolved (it keeps its literal ".." components).
aspanformer_path = Path(__file__).parent.joinpath("..", "..", "third_party", "ASpanFormer")


class ASpanFormer(BaseModel):
    """Image matcher wrapping the third-party ASpanFormer network.

    On first use, the pretrained-weights archive is fetched from Google Drive
    with ``gdown``. If the direct download fails, it is retried through
    ``proxy``. The archive is then unpacked inside ``aspanformer_path``.
    """

    default_conf = {
        # Checkpoint stem; weights are read from
        # ``<aspanformer_path>/weights/<weights>.ckpt``.
        "weights": "outdoor",
        # NOTE(review): not read anywhere in this class. The matching threshold
        # presumably comes from ``config_path`` instead; confirm.
        "match_threshold": 0.2,
        # yacs config file merged on top of ASpanFormer's default config.
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
        # Key into ``aspanformer_models``; also the archive's file name on disk.
        "model_name": "weights_aspanformer.tar",
    }
    required_inputs = ["image0", "image1"]
    # Proxy used only for the retry after a failed direct download.
    proxy = "http://localhost:1080"
    aspanformer_models = {
        "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
    }

    def _init(self, conf):
        """Fetch the weights if needed, then build and load the network.

        Args:
            conf: merged configuration; reads ``weights``, ``model_name`` and
                ``config_path``.

        Raises:
            subprocess.CalledProcessError: if both download attempts fail, or
                if extracting the archive fails.
            FileNotFoundError: if the checkpoint is still missing after the
                archive has been extracted.
        """
        model_path = aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
        if not model_path.exists():
            tar_path = aspanformer_path / conf["model_name"]
            if not tar_path.exists():
                self._download_archive(
                    self.aspanformer_models[conf["model_name"]], tar_path
                )
            self._extract_archive(tar_path)
            if not model_path.exists():
                raise FileNotFoundError(
                    f"Checkpoint {model_path} not found after extracting {tar_path}."
                )

        logger.info("Loading Aspanformer model...")

        config = get_cfg_defaults()
        config.merge_from_file(conf["config_path"])
        _config = lower_config(config)
        self.net = _ASpanFormer(config=_config["aspan"])
        state_dict = torch.load(str(model_path), map_location="cpu")["state_dict"]
        # strict=False: missing/unexpected keys are ignored instead of raising.
        self.net.load_state_dict(state_dict, strict=False)

    def _download_archive(self, link, tar_path):
        """Download ``link`` to ``tar_path`` with gdown, retrying via ``self.proxy``."""
        cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
        cmd = cmd_wo_proxy + ["--proxy", self.proxy]
        logger.info(f"Downloading the Aspanformer model with `{cmd_wo_proxy}`.")
        try:
            subprocess.run(cmd_wo_proxy, check=True)
        except subprocess.CalledProcessError:
            logger.info(f"Downloading the Aspanformer model with `{cmd}`.")
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError:
                logger.error("Failed to download the Aspanformer model.")
                raise

    @staticmethod
    def _extract_archive(tar_path):
        """Unpack ``tar_path`` into ``aspanformer_path`` using the ``tar`` CLI."""
        # Pass the target directory via ``cwd`` instead of a "cd DIR & tar ..."
        # shell string. In POSIX shells ``&`` backgrounds ``cd``, which would
        # leave tar unpacking into the caller's working directory. An argv
        # list also avoids shell-quoting problems with unusual paths.
        subprocess.run(
            ["tar", "-xvf", str(tar_path)], cwd=str(aspanformer_path), check=True
        )

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: dict holding at least ``image0`` and ``image1`` tensors.
                Their layout is whatever ASpanFormer expects (presumably
                grayscale ``(B, 1, H, W)``; TODO confirm).

        Returns:
            dict with ``keypoints0`` / ``keypoints1``. These are the network's
            ``mkpts0_f`` / ``mkpts1_f`` matched coordinates.
        """
        # The network writes its outputs into the dict it is given, so give
        # it a fresh dict rather than mutating the caller's ``data``.
        batch = {"image0": data["image0"], "image1": data["image1"]}
        self.net(batch, online_resize=True)
        return {"keypoints0": batch["mkpts0_f"], "keypoints1": batch["mkpts1_f"]}