|
|
|
|
|
|
|
|
|
import os, pdb |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from .dataset import Dataset |
|
from .pair_dataset import PairDataset, StillPairDataset |
|
|
|
|
|
class AachenImages(Dataset):
    """Loads all images from the Aachen Day-Night dataset.

    The tree ``<root>/images_upright`` is walked recursively. A ``.jpg`` file
    is indexed when at least one component of its directory path (relative to
    that tree) is one of the selected names.

    Args:
        select: whitespace-separated directory names to keep
            (default: ``"db day night"``).
        root: dataset folder that contains the ``images_upright`` directory.

    Attributes set by the constructor:
        root: the ``root`` argument, unchanged.
        img_dir: always ``"images_upright"``.
        select: the selected names, as a set.
        imgs: image paths relative to ``<root>/images_upright``.
        nimg: number of indexed images.

    Raises:
        AssertionError: if ``select`` is empty or no image was found.
    """

    def __init__(self, select="db day night", root="data/aachen"):
        Dataset.__init__(self)
        self.root = root
        self.img_dir = "images_upright"
        self.select = set(select.split())
        assert self.select, "Nothing was selected"

        self.imgs = []
        root = os.path.join(root, self.img_dir)
        for dirpath, _, filenames in os.walk(root):
            # Directory path relative to the image tree ("" for the tree itself).
            r = dirpath[len(root) + 1 :]
            # os.walk builds paths with the native separator, so normalise it
            # before splitting; splitting on a literal "/" alone would miss
            # nested folders on platforms where os.sep is not "/".
            components = set(r.replace(os.sep, "/").split("/"))
            if not (self.select & components):
                continue
            self.imgs += [os.path.join(r, f) for f in filenames if f.endswith(".jpg")]

        self.nimg = len(self.imgs)
        assert self.nimg, "Empty Aachen dataset"

    def get_key(self, idx):
        """Return the path of image ``idx`` relative to the image tree."""
        return self.imgs[idx]
|
|
|
|
|
class AachenImages_DB(AachenImages):
    """Aachen images restricted to the database (db) selection.

    On top of the parent's attributes, ``db_image_idxs`` maps every image tag
    (see ``get_tag``) to the position of that image in ``self.imgs``.
    """

    def __init__(self, **kw):
        AachenImages.__init__(self, select="db", **kw)

        # Reverse lookup table: tag -> index into self.imgs. When two images
        # share a tag, the one with the larger index wins.
        tag_to_index = {}
        for position in range(len(self.imgs)):
            tag_to_index[self.get_tag(position)] = position
        self.db_image_idxs = tag_to_index

    def get_tag(self, idx):
        """Return the file name of image ``idx`` without its directory part
        and without its last four characters (the ``.jpg`` suffix for images
        indexed by the parent class).
        """
        without_suffix = self.imgs[idx][:-4]
        return os.path.basename(without_suffix)
|
|
|
|
|
class AachenPairs_StyleTransferDayNight(AachenImages_DB, StillPairDataset):
    """Synthetic day-night pairs of images.

    (night images obtained using automatic style transfer from web night images)

    Every file found in ``root`` is paired with the database image whose tag
    is the part of the file name located before ``".jpg.st_"``.

    Args:
        root: folder holding the style-transferred images.
        **kw: forwarded to ``AachenImages_DB`` (e.g. its own ``root``).

    Attributes set by the constructor:
        root: directory prefix shared by the db image tree and ``root``
            (empty when they share none).
        img_dir: reset to ``""``; image keys are relative to ``self.root``.
        imgs: db images followed by the style-transferred images.
        image_pairs: list of ``(db_index, style_index)`` into ``imgs``.
        nimg, npairs: sizes of ``imgs`` and ``image_pairs``.

    Raises:
        KeyError: if a file in ``root`` does not correspond to a db image tag.
        AssertionError: if there are no images or no pairs.
    """

    def __init__(self, root="data/aachen/style_transfer", **kw):
        StillPairDataset.__init__(self)
        AachenImages_DB.__init__(self, **kw)
        old_root = os.path.join(self.root, self.img_dir)

        # os.path.commonprefix compares character by character, so its result
        # may stop in the middle of a directory name (".../images_" for
        # ".../images_a" and ".../images_b") or lack a trailing separator.
        # Trim it back to the last separator so that self.root is always a
        # real directory prefix and the stripped keys never start with a
        # separator.
        prefix = os.path.commonprefix((old_root, root))
        last_sep = max(prefix.rfind(sep) for sep in (os.sep, os.altsep) if sep)
        self.root = prefix[: last_sep + 1]
        self.img_dir = ""

        def newpath(folder, f):
            """Path of ``f`` inside ``folder``, made relative to self.root."""
            return os.path.join(folder, f)[len(self.root) :]

        # Re-express the db images relative to the new shared root.
        self.imgs = [newpath(old_root, f) for f in self.imgs]

        self.image_pairs = []
        for fname in os.listdir(root):
            tag = fname.split(".jpg.st_")[0]
            # The style image is appended right after, hence index len(self.imgs).
            self.image_pairs.append((self.db_image_idxs[tag], len(self.imgs)))
            self.imgs.append(newpath(root, fname))

        self.nimg = len(self.imgs)
        self.npairs = len(self.image_pairs)
        assert self.nimg and self.npairs
|
|
|
|
|
class AachenPairs_OpticalFlow(AachenImages_DB, PairDataset):
    """Image pairs from Aachen db with optical flow.

    ``<root>/flow`` and ``<root>/mask`` must contain the same set of ``.png``
    file names, each named ``<tagA>_<tagB>.png`` where both tags are db image
    tags.

    Args:
        root: folder containing the ``flow`` and ``mask`` sub-folders.
        **kw: forwarded to ``AachenImages_DB``.

    Attributes set by the constructor:
        root_flow: the ``root`` argument.
        image_pairs: list of tuples of indices into ``self.imgs``, in sorted
            file-name order.
        npairs: number of pairs.

    Raises:
        AssertionError: if flow and mask file names differ, or if there are
            no images or no pairs.
        KeyError: if a file name refers to a tag that is not a db image.
    """

    def __init__(self, root="data/aachen/optical_flow", **kw):
        PairDataset.__init__(self)
        AachenImages_DB.__init__(self, **kw)
        self.root_flow = root

        flows = {
            f for f in os.listdir(os.path.join(root, "flow")) if f.endswith(".png")
        }
        masks = {
            f for f in os.listdir(os.path.join(root, "mask")) if f.endswith(".png")
        }
        assert flows == masks, "Missing flow or mask pairs"

        def make_pair(f):
            """Turn ``<tagA>_<tagB>.png`` into the tuple of db image indices."""
            return tuple(self.db_image_idxs[v] for v in f[:-4].split("_"))

        # Iterate in sorted order: the iteration order of a set of strings
        # varies between interpreter runs (hash randomisation), which would
        # make pair indices non-reproducible.
        self.image_pairs = [make_pair(f) for f in sorted(flows)]
        self.npairs = len(self.image_pairs)
        assert self.nimg and self.npairs

    def _pair_filename(self, pair_idx, folder):
        """Return ``<root_flow>/<folder>/<tagA>_<tagB>.png`` for a pair."""
        tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
        return os.path.join(self.root_flow, folder, f"{tag_a}_{tag_b}.png")

    def get_mask_filename(self, pair_idx):
        """Return the path of the mask image of pair ``pair_idx``."""
        return self._pair_filename(pair_idx, "mask")

    def get_mask(self, pair_idx):
        """Load the mask image of pair ``pair_idx`` as a numpy array."""
        return np.asarray(Image.open(self.get_mask_filename(pair_idx)))

    def get_flow_filename(self, pair_idx):
        """Return the path of the flow image of pair ``pair_idx``."""
        return self._pair_filename(pair_idx, "flow")

    def get_flow(self, pair_idx):
        """Return the optical flow of pair ``pair_idx``.

        The ``.png`` file is decoded with ``self._png2flow``. If that raises
        ``IOError``, the raw file with the same name minus ``.png`` is parsed
        instead and handed to ``self._flow2png`` together with the png path;
        the value returned by ``_flow2png`` is returned.
        NOTE(review): ``_png2flow`` / ``_flow2png`` are not defined in this
        file (presumably inherited from ``PairDataset``) - confirm that
        ``_flow2png`` returns the flow array.
        """
        fname = self.get_flow_filename(pair_idx)
        try:
            return self._png2flow(fname)
        except IOError:
            # Raw layout read below: one float32 magic number, two int32
            # (width, height), then H*W*2 float32 values. This looks like the
            # Middlebury ".flo" format - TODO confirm.
            with open(fname[:-4], "rb") as raw:
                magic = np.fromfile(raw, np.float32, 1)
                assert magic == 202021.25
                W, H = np.fromfile(raw, np.int32, 2)
                flow = np.fromfile(raw, np.float32).reshape((H, W, 2))
            return self._flow2png(flow, fname)

    def get_pair(self, idx, output=()):
        """Return ``(img1, img2, meta)`` for pair ``idx``.

        Args:
            idx: index into ``self.image_pairs``.
            output: names of the extra data to load, either as an iterable or
                as a whitespace-separated string. Recognised names:
                ``"flow"`` / ``"aflow"`` (either one loads both) and ``"mask"``.

        Returns:
            The two images (as returned by ``self.get_image``) and a dict that
            may contain:
                ``"flow"``: the flow array, shape ``(H, W, 2)``.
                ``"aflow"``: the flow plus the ``(x, y)`` pixel coordinate of
                    each location (absolute target coordinates).
                ``"mask"``: the mask array.

        Raises:
            AssertionError: if the flow or mask size does not match ``img1``.
        """
        if isinstance(output, str):
            output = output.split()

        img1, img2 = map(self.get_image, self.image_pairs[idx])
        meta = {}

        if "flow" in output or "aflow" in output:
            flow = self.get_flow(idx)
            # img1.size is reversed before comparing with (rows, cols).
            assert flow.shape[:2] == img1.size[::-1]
            meta["flow"] = flow
            H, W = flow.shape[:2]
            # mgrid yields (row, col) planes; reversing gives (x, y) and the
            # transpose moves them to the last axis, matching flow's layout.
            meta["aflow"] = flow + np.mgrid[:H, :W][::-1].transpose(1, 2, 0)

        if "mask" in output:
            mask = self.get_mask(idx)
            assert mask.shape[:2] == img1.size[::-1]
            meta["mask"] = mask

        return img1, img2, meta
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build each dataset with its default paths, print it,
    # then drop into the debugger for interactive inspection.
    # The three names below were previously printed without ever being
    # defined, which raised NameError.
    aachen_db_images = AachenImages_DB()
    aachen_style_transfer_pairs = AachenPairs_StyleTransferDayNight()
    aachen_flow_pairs = AachenPairs_OpticalFlow()

    print(aachen_db_images)
    print(aachen_style_transfer_pairs)
    print(aachen_flow_pairs)
    pdb.set_trace()
|
|