Vincentqyw
fix: roma
8b973ee
raw
history blame
4.21 kB
import torch
import cv2
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from pathlib import Path
class PatchesDataset(Dataset):
    """HPatches dataset of (image, warped image, homography) samples.

    Every sequence sub-folder of ``root_dir`` is read as follows: the
    reference image ``1.ppm``, the target images ``2.ppm`` .. ``6.ppm`` and
    the homographies ``H_1_2`` .. ``H_1_6`` (loaded with ``np.loadtxt``).
    Each sequence therefore contributes five samples, one per target image.

    Images are returned as PyTorch float tensors. If ``output_shape`` is
    given, both images are resized to it and the homography is rescaled to
    match the new resolution.

    Parameters
    ----------
    root_dir : str
        Path to the dataset root containing one sub-folder per sequence.
    use_color : bool
        If True, return 3-channel images as loaded by ``cv2.imread``
        (BGR channel order); otherwise convert them to grayscale.
    data_transform : callable, optional
        Stored as ``self.data_transform``; NOTE: it is not applied by
        ``__getitem__``.
    output_shape : tuple of int, optional
        Target size as ``(width, height)`` (OpenCV ``dsize`` order).
    type : str
        Subset to load from ``["i", "v", "all"]``: ``"i"`` keeps folders whose
        name starts with ``i`` (illumination), ``"v"`` keeps folders starting
        with ``v`` (viewpoint), ``"all"`` keeps every folder. Any other value
        behaves like ``"all"``.
    """

    def __init__(
        self,
        root_dir,
        use_color=True,
        data_transform=None,
        output_shape=None,
        type="all",
    ):
        super().__init__()
        self.type = type
        self.root_dir = root_dir
        self.data_transform = data_transform
        self.output_shape = output_shape
        self.use_color = use_color

        # Path.iterdir() yields entries in arbitrary, filesystem-dependent
        # order; sort so that sample indices are reproducible.
        folder_paths = sorted(x for x in Path(root_dir).iterdir() if x.is_dir())

        num_images = 5
        file_ext = ".ppm"
        image_paths = []
        warped_image_paths = []
        homographies = []
        for path in folder_paths:
            # Check the full folder name (``stem`` would drop any dotted suffix).
            if self.type == "i" and not path.name.startswith("i"):
                continue
            if self.type == "v" and not path.name.startswith("v"):
                continue
            for i in range(2, 2 + num_images):
                image_paths.append(str(Path(path, "1" + file_ext)))
                warped_image_paths.append(str(Path(path, str(i) + file_ext)))
                homographies.append(np.loadtxt(str(Path(path, "H_1_" + str(i)))))

        self.files = {
            "image_paths": image_paths,
            "warped_image_paths": warped_image_paths,
            "homography": homographies,
        }

    def scale_homography(self, homography, original_scale, new_scale, pre):
        """Compose ``homography`` with a diagonal scaling.

        Parameters
        ----------
        homography : array-like, shape (3, 3)
        original_scale, new_scale : sequence of 2 numbers
            Sizes as ``(width, height)``; the scale factor is
            ``new_scale / original_scale`` per axis.
        pre : bool
            If True, return ``S @ homography`` (scaling applied to the output
            side); otherwise return ``homography @ inv(S)`` (undo the
            scaling on the input side).

        Returns
        -------
        numpy.ndarray, shape (3, 3)
        """
        scales = np.divide(new_scale, original_scale)
        if pre:
            s = np.diag(np.append(scales, 1.0))
            return np.matmul(s, homography)
        sinv = np.diag(np.append(1.0 / scales, 1.0))
        return np.matmul(homography, sinv)

    def __len__(self):
        """Return the number of (image, warped image) pairs."""
        return len(self.files["image_paths"])

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns
        -------
        dict
            ``"image"`` and ``"warped_image"`` as float tensors (C, H, W),
            ``"homography"`` as a 3x3 numpy array (rescaled if
            ``output_shape`` is set) and ``"index"`` equal to ``idx``.

        Raises
        ------
        FileNotFoundError
            If either image cannot be read by OpenCV.
        """

        def _read_image(path):
            img = cv2.imread(path, cv2.IMREAD_COLOR)
            if img is None:
                # cv2.imread signals failure by returning None rather than
                # raising; fail here with a clear message instead of later.
                raise FileNotFoundError(f"Could not read image: {path}")
            if self.use_color:
                return img
            return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        image = _read_image(self.files["image_paths"][idx])
        warped_image = _read_image(self.files["warped_image_paths"][idx])
        # Copy so the cached homography is never modified.
        homography = np.array(self.files["homography"][idx])
        sample = {
            "image": image,
            "warped_image": warped_image,
            "homography": homography,
            "index": idx,
        }

        if self.output_shape is not None:
            dsize = tuple(self.output_shape)
            # H' = S_warped @ H @ inv(S_image), where each S maps that image's
            # original (width, height) to output_shape.
            sample["homography"] = self.scale_homography(
                sample["homography"],
                sample["image"].shape[:2][::-1],
                dsize,
                pre=False,
            )
            sample["homography"] = self.scale_homography(
                sample["homography"],
                sample["warped_image"].shape[:2][::-1],
                dsize,
                pre=True,
            )
            for key in ["image", "warped_image"]:
                sample[key] = cv2.resize(sample[key], dsize)
                if self.use_color is False:
                    # Keep an explicit channel axis for grayscale images.
                    sample[key] = np.expand_dims(sample[key], axis=2)

        transform = transforms.ToTensor()
        for key in ["image", "warped_image"]:
            sample[key] = transform(sample[key]).type("torch.FloatTensor")
        return sample