File size: 4,625 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
HPatches sequences dataset: pairs the reference image of each sequence with its
query images and the homography relating them (does not have any split).
"""
import argparse
import logging
import tarfile

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf

from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .base_dataset import BaseDataset

# Module-level logger; rebound to the package logger when this file runs as a script.
logger = logging.getLogger(__name__)


def read_homography(path):
    """Parse a whitespace-separated text matrix into a float array.

    Each non-empty line of the file is one row. Values may be separated by any
    run of whitespace (spaces or tabs); leading/trailing whitespace and blank
    lines are ignored.

    Args:
        path: location of the text file to read.

    Returns:
        A ``float64`` numpy array with one row per non-empty line (an empty
        1-D array when the file holds no values).

    Raises:
        ValueError: if a token is not a valid float or rows differ in length.
    """
    with open(path) as f:
        # str.split() with no argument collapses any whitespace run and drops
        # empty tokens, so no manual normalisation of the line is needed.
        rows = [[float(token) for token in line.split()] for line in f]
    return np.array([row for row in rows if row], dtype=float)


class HPatches(BaseDataset, torch.utils.data.Dataset):
    """HPatches sequences exposed as (reference image, query image, homography).

    Every kept sequence contributes five items: image ``1`` paired with each of
    the images ``2`` to ``6``, together with the matrix read from ``H_1_<k>``.
    The same object is returned for the "val" and "test" splits.
    """

    default_conf = {
        "preprocessing": ImagePreprocessor.default_conf,
        "data_dir": "hpatches-sequences-release",  # relative to DATA_PATH
        "subset": None,  # first letter of the sequences to keep; None keeps all
        "ignore_large_images": True,  # skip the scenes listed in ignored_scenes
        "grayscale": False,  # forwarded to load_image
    }

    # Large images that were ignored in previous papers
    ignored_scenes = (
        "i_contruction",
        "i_crownnight",
        "i_dc",
        "i_pencils",
        "i_whitebuilding",
        "v_artisans",
        "v_astronautis",
        "v_talent",
    )
    url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz"

    def _init(self, conf):
        """Locate (or download) the dataset and enumerate the image pairs.

        Raises:
            ValueError: if the dataset directory contains no sequence folder.
        """
        assert conf.batch_size == 1
        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        self.root = DATA_PATH / conf.data_dir
        if not self.root.exists():
            logger.info("Downloading the HPatches dataset.")
            self.download()
        # Only directories are sequences: a stray file in the root would
        # otherwise be turned into items that cannot be loaded.
        self.sequences = sorted(x.name for x in self.root.iterdir() if x.is_dir())
        if not self.sequences:
            raise ValueError("No image found!")
        self.items = []  # (seq, q_idx, is_illu)
        for seq in self.sequences:
            if conf.ignore_large_images and seq in self.ignored_scenes:
                continue
            # The first letter of the folder name identifies the subset.
            if conf.subset is not None and conf.subset != seq[0]:
                continue
            for i in range(2, 7):
                self.items.append((seq, i, seq[0] == "i"))

    def download(self):
        """Download the release archive next to ``self.root`` and unpack it.

        The archive file is deleted once it has been extracted.
        """
        data_dir = self.root.parent
        data_dir.mkdir(exist_ok=True, parents=True)
        tar_path = data_dir / self.url.rsplit("/", 1)[-1]
        torch.hub.download_url_to_file(self.url, tar_path)
        with tarfile.open(tar_path) as tar:
            # The archive is fetched over plain HTTP, so restrict extraction to
            # regular data members whenever the interpreter supports extraction
            # filters; older interpreters keep the unfiltered behaviour.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(data_dir, filter="data")
            else:
                tar.extractall(data_dir)
        tar_path.unlink()

    def get_dataset(self, split):
        """Return this dataset for the "val" or "test" split (both identical)."""
        assert split in ["val", "test"]
        return self

    def _read_image(self, seq: str, idx: int) -> dict:
        """Load ``<seq>/<idx>.ppm`` and run it through the preprocessor."""
        img = load_image(self.root / seq / f"{idx}.ppm", self.conf.grayscale)
        return self.preprocessor(img)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair with its homography.

        The result holds the preprocessed reference image (``view0``), the
        preprocessed query image (``view1``), the float32 matrix ``H_0to1``,
        the sequence name, the dataset index and the illumination flag.
        """
        seq, q_idx, is_illu = self.items[idx]
        data0 = self._read_image(seq, 1)
        data1 = self._read_image(seq, q_idx)
        H = read_homography(self.root / seq / f"H_1_{q_idx}")
        # Conjugate with the preprocessing transforms of both views; presumably
        # these map original to preprocessed pixel coordinates - TODO confirm.
        H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
        return {
            "H_0to1": H.astype(np.float32),
            "scene": seq,
            "idx": idx,
            "is_illu": is_illu,
            # NOTE(review): built from the dataset index, not from q_idx, so it
            # does not name an image file on disk - confirm this is intended.
            "name": f"{seq}/{idx}.ppm",
            "view0": data0,
            "view1": data1,
        }

    def __len__(self):
        """Number of image pairs."""
        return len(self.items)


def visualize(args):
    """Plot the first ``args.num_items`` image pairs of the dataset in a grid.

    Args:
        args: parsed command-line namespace providing ``num_items``, ``dpi``
            and ``dotlist`` (configuration overrides parsed by
            ``OmegaConf.from_cli``).
    """
    conf = {
        "batch_size": 1,
        "num_workers": 8,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = HPatches(conf)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            # Build each row eagerly. A generator expression would only look up
            # `data` when consumed, i.e. after this loop has finished, so every
            # row would show the last batch (and the row would have no len()).
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    # Command line: two integer options, then free-form configuration
    # overrides collected into `dotlist`.
    cli = argparse.ArgumentParser()
    for flag, default in (("--num_items", 8), ("--dpi", 100)):
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument("dotlist", nargs="*")
    # Intermixed parsing lets options and overrides appear in any order.
    visualize(cli.parse_intermixed_args())