# Web file-viewer residue (not part of the module): Realcat | fix: eloftr | 63f3cf2 | raw | history blame | 4.8 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> seven_scenes
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 14:36
=================================================='''
import os
import os.path as osp
import numpy as np
from colmap_utils.read_write_model import read_model
import torchvision.transforms as tvt
from dataset.basicdataset import BasicDataset
class SevenScenes(BasicDataset):
    """Per-scene dataset index for data laid out as ``seq-XX`` folders of PNG frames.

    The constructor does all the work: it resolves the scene-specific paths,
    builds the (optional) colour augmentation pipeline, collects the list of
    usable image names from the train/test split file, and loads the 3D point
    cluster labels and coordinates for the scene.
    """

    def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='7Scenes',
                 nfeatures=1024,
                 query_p3d_fn=None,
                 train=True,
                 with_aug=False,
                 min_inliers=0,
                 max_inliers=4096,
                 random_inliers=False,
                 jitter_params=None,
                 scale_params=None,
                 image_dim=3,
                 query_info_path=None,
                 sample_ratio=1,
                 ):
        """Index one scene for training or evaluation.

        Args:
            landmark_path: Root directory of landmark data; ``scene`` is appended to it.
            scene: Scene name, used as a sub-directory of both roots.
            dataset_path: Root directory of the image data; ``scene`` is appended to it.
            n_class: Class count; the cluster file name is built with ``n_class - 1``.
            seg_mode: String inserted into the cluster file name.
            seg_method: String inserted into the cluster file name.
            dataset: Dataset label; stored as ``dataset + '/' + scene``.
            nfeatures: Stored on the instance; also the upper bound for ``max_inliers``.
            query_p3d_fn: Path of the ``.npy`` file loaded into ``self.img_p3d``
                when ``train`` is False (unused when training).
            train: Selects ``TrainSplit.txt`` plus the 3D model when True,
                ``TestSplit.txt`` plus the query files otherwise.
            with_aug: When True, a ``ColorJitter`` (and optionally a
                ``GaussianBlur``) transform is built from ``jitter_params``.
            min_inliers: Stored on the instance.
            max_inliers: Stored on the instance, clamped to ``nfeatures``.
            random_inliers: Stored on the instance.
            jitter_params: Mapping with keys ``'brightness'``, ``'contrast'``,
                ``'saturation'``, ``'hue'`` and ``'blur'``; required when
                ``with_aug`` is True.
            scale_params: Stored on the instance.
            image_dim: Stored on the instance.
            query_info_path: Passed to ``self.read_query_info`` when ``train`` is False.
            sample_ratio: Keep every ``sample_ratio``-th accepted image of each sequence.
        """
        self.landmark_path = osp.join(landmark_path, scene)
        self.dataset_path = osp.join(dataset_path, scene)
        self.n_class = n_class
        self.dataset = dataset + '/' + scene
        self.nfeatures = nfeatures
        self.with_aug = with_aug
        self.jitter_params = jitter_params
        self.scale_params = scale_params
        self.image_dim = image_dim
        self.train = train
        self.min_inliers = min_inliers
        self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
        self.random_inliers = random_inliers
        self.image_prefix = ''

        # Photometric augmentation; the pipeline stays empty when with_aug is False.
        train_transforms = []
        if self.with_aug:
            train_transforms.append(tvt.ColorJitter(
                brightness=jitter_params['brightness'],
                contrast=jitter_params['contrast'],
                saturation=jitter_params['saturation'],
                hue=jitter_params['hue']))
            if jitter_params['blur'] > 0:
                train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
        self.train_transforms = tvt.Compose(train_transforms)

        if train:
            # The third element returned by read_model is not used here.
            self.cameras, self.images, _ = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
            # Only images that have at least one entry in point3D_ids are usable for training.
            self.name_to_id = {image.name: image_id for image_id, image in self.images.items()
                               if len(image.point3D_ids) > 0}
            self.img_p3d = {}
            split_fn = osp.join(self.dataset_path, 'TrainSplit.txt')
            known_names = self.name_to_id
        else:
            # Only for testing of query images.
            # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load trusted files.
            self.img_p3d = np.load(query_p3d_fn, allow_pickle=True)[()]
            split_fn = osp.join(self.dataset_path, 'TestSplit.txt')
            known_names = self.img_p3d

        with open(split_fn, 'r') as f:
            split_lines = f.readlines()

        self.img_fns = []
        for line in split_lines:
            line = line.strip()
            if not line:
                # Blank lines would otherwise make int('') raise ValueError.
                continue
            # The sequence number is whatever follows the first 8 characters of the line.
            seq_dir = 'seq-{:02d}'.format(int(line[8:]))
            n_accepted = 0  # per-sequence counter driving the sub-sampling
            for fn in sorted(os.listdir(osp.join(self.dataset_path, seq_dir))):
                if 'png' not in fn:
                    continue
                rel_name = seq_dir + '/' + fn
                if rel_name not in known_names:
                    continue
                if n_accepted % sample_ratio == 0:
                    self.img_fns.append(rel_name)
                n_accepted += 1

        print('Load {} images from {} for {}...'.format(len(self.img_fns),
                                                        self.dataset, 'training' if train else 'eval'))

        # Per-point cluster labels and coordinates, keyed by 3D point id.
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load trusted files.
        cluster_fn = osp.join(self.landmark_path,
                              'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method))
        data = np.load(cluster_fn, allow_pickle=True)[()]
        p3d_id = data['id']
        seg_id = data['label']
        xyzs = data['xyz']
        n_points = p3d_id.shape[0]
        self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(n_points)}
        self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(n_points)}

        # NOTE: reading 'sc_mean_scale.txt' into mean_xyz / scale_xyz was commented
        # out in the original and remains disabled.

        if not train:
            self.query_info = self.read_query_info(path=query_info_path)

        # Re-assigned after read_query_info exactly as the original did; that
        # method is defined outside this file, so the assignment is kept.
        self.nfeatures = nfeatures
        self.feature_dir = osp.join(self.landmark_path, 'feats')
        self.feats = {}