Spaces:
Running
Running
File size: 4,625 Bytes
4dfb78b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
"""
Simply load images from a folder or nested folders (does not have any split).
"""
import argparse
import logging
import tarfile
import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .base_dataset import BaseDataset
logger = logging.getLogger(__name__)
def read_homography(path):
with open(path) as f:
result = []
for line in f.readlines():
while " " in line: # Remove double spaces
line = line.replace(" ", " ")
line = line.replace(" \n", "").replace("\n", "")
# Split and discard empty strings
elements = list(filter(lambda s: s, line.split(" ")))
if elements:
result.append(elements)
return np.array(result).astype(float)
class HPatches(BaseDataset, torch.utils.data.Dataset):
default_conf = {
"preprocessing": ImagePreprocessor.default_conf,
"data_dir": "hpatches-sequences-release",
"subset": None,
"ignore_large_images": True,
"grayscale": False,
}
# Large images that were ignored in previous papers
ignored_scenes = (
"i_contruction",
"i_crownnight",
"i_dc",
"i_pencils",
"i_whitebuilding",
"v_artisans",
"v_astronautis",
"v_talent",
)
url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz"
def _init(self, conf):
assert conf.batch_size == 1
self.preprocessor = ImagePreprocessor(conf.preprocessing)
self.root = DATA_PATH / conf.data_dir
if not self.root.exists():
logger.info("Downloading the HPatches dataset.")
self.download()
self.sequences = sorted([x.name for x in self.root.iterdir()])
if not self.sequences:
raise ValueError("No image found!")
self.items = [] # (seq, q_idx, is_illu)
for seq in self.sequences:
if conf.ignore_large_images and seq in self.ignored_scenes:
continue
if conf.subset is not None and conf.subset != seq[0]:
continue
for i in range(2, 7):
self.items.append((seq, i, seq[0] == "i"))
def download(self):
data_dir = self.root.parent
data_dir.mkdir(exist_ok=True, parents=True)
tar_path = data_dir / self.url.rsplit("/", 1)[-1]
torch.hub.download_url_to_file(self.url, tar_path)
with tarfile.open(tar_path) as tar:
tar.extractall(data_dir)
tar_path.unlink()
def get_dataset(self, split):
assert split in ["val", "test"]
return self
def _read_image(self, seq: str, idx: int) -> dict:
img = load_image(self.root / seq / f"{idx}.ppm", self.conf.grayscale)
return self.preprocessor(img)
def __getitem__(self, idx):
seq, q_idx, is_illu = self.items[idx]
data0 = self._read_image(seq, 1)
data1 = self._read_image(seq, q_idx)
H = read_homography(self.root / seq / f"H_1_{q_idx}")
H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
return {
"H_0to1": H.astype(np.float32),
"scene": seq,
"idx": idx,
"is_illu": is_illu,
"name": f"{seq}/{idx}.ppm",
"view0": data0,
"view1": data1,
}
def __len__(self):
return len(self.items)
def visualize(args):
conf = {
"batch_size": 1,
"num_workers": 8,
"prefetch_factor": 1,
}
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
dataset = HPatches(conf)
loader = dataset.get_data_loader("test")
logger.info("The dataset has %d elements.", len(loader))
with fork_rng(seed=dataset.conf.seed):
images = []
for _, data in zip(range(args.num_items), loader):
images.append(
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
)
plot_image_grid(images, dpi=args.dpi)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
from .. import logger # overwrite the logger
parser = argparse.ArgumentParser()
parser.add_argument("--num_items", type=int, default=8)
parser.add_argument("--dpi", type=int, default=100)
parser.add_argument("dotlist", nargs="*")
args = parser.parse_intermixed_args()
visualize(args)
|