File size: 4,804 Bytes
63f3cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> seven_scenes
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   29/01/2024 14:36
=================================================='''
import os
import os.path as osp
import numpy as np
from colmap_utils.read_write_model import read_model
import torchvision.transforms as tvt
from dataset.basicdataset import BasicDataset


class SevenScenes(BasicDataset):
    """7-Scenes dataset indexed against segmented 3D landmarks.

    Training mode keeps only images of the scene's COLMAP model
    (``<landmark_path>/<scene>/3D-models``) that observe at least one 3D point.
    Evaluation mode keeps only query images that have an entry in the
    per-image 3D-point file ``query_p3d_fn``. In both modes, candidate images
    are enumerated from the sequences listed in the scene's ``TrainSplit.txt``
    or ``TestSplit.txt``. They are then optionally sub-sampled with
    ``sample_ratio``.
    """

    def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='7Scenes',
                 nfeatures=1024,
                 query_p3d_fn=None,
                 train=True,
                 with_aug=False,
                 min_inliers=0,
                 max_inliers=4096,
                 random_inliers=False,
                 jitter_params=None,
                 scale_params=None,
                 image_dim=3,
                 query_info_path=None,
                 sample_ratio=1,
                 ):
        """Index the split, load per-point segment labels and set up augmentation.

        Args:
            landmark_path: root of the per-scene landmark data. It holds the
                COLMAP model, the point3D cluster file and ``feats``.
            scene: scene sub-directory, appended to both ``landmark_path`` and
                ``dataset_path``.
            dataset_path: root of the raw 7-Scenes sequences and split files.
            n_class: number of classes. The cluster file is looked up with
                ``n_class - 1`` in its name.
            seg_mode: part of the cluster file name.
            seg_method: part of the cluster file name. The full name is
                ``point3D_cluster_n{n_class - 1}_{seg_mode}_{seg_method}.npy``.
            dataset: name prefix; ``self.dataset`` becomes ``dataset + '/' + scene``.
            nfeatures: number of features. It also caps ``max_inliers``.
            query_p3d_fn: ``.npy`` file holding a dict keyed by relative image
                name (``seq-XX/<file>``). Required when ``train`` is False.
            train: use the training split (True) or the test split (False).
            with_aug: build colour-jitter / blur transforms. Requires
                ``jitter_params``.
            min_inliers: stored as an attribute.
            max_inliers: stored as an attribute, clipped to at most ``nfeatures``.
            random_inliers: stored as an attribute.
            jitter_params: dict with ``brightness``, ``contrast``,
                ``saturation``, ``hue`` and ``blur`` entries.
            scale_params: stored as an attribute.
            image_dim: stored as an attribute.
            query_info_path: passed to ``read_query_info`` when ``train`` is False.
            sample_ratio: keep every ``sample_ratio``-th valid image of each
                sequence. The counter restarts for every sequence.

        Raises:
            ValueError: if ``sample_ratio`` < 1, if ``with_aug`` is set without
                ``jitter_params``, or if ``train`` is False without ``query_p3d_fn``.
        """
        # Fail early with clear messages instead of ZeroDivisionError /
        # "'NoneType' object is not subscriptable" deep inside the loaders.
        if sample_ratio < 1:
            raise ValueError('sample_ratio must be >= 1, got {}'.format(sample_ratio))
        if with_aug and jitter_params is None:
            raise ValueError('jitter_params is required when with_aug=True')
        if not train and query_p3d_fn is None:
            raise ValueError('query_p3d_fn is required when train=False')

        self.landmark_path = osp.join(landmark_path, scene)
        self.dataset_path = osp.join(dataset_path, scene)
        self.n_class = n_class
        self.dataset = dataset + '/' + scene
        self.nfeatures = nfeatures
        self.with_aug = with_aug
        self.jitter_params = jitter_params
        self.scale_params = scale_params
        self.image_dim = image_dim
        self.train = train
        self.min_inliers = min_inliers
        self.max_inliers = min(max_inliers, nfeatures)
        self.random_inliers = random_inliers
        # NOTE(review): presumably a prefix prepended to image names by the base class; empty for 7-Scenes — confirm.
        self.image_prefix = ''

        train_transforms = []
        if self.with_aug:
            train_transforms.append(tvt.ColorJitter(
                brightness=jitter_params['brightness'],
                contrast=jitter_params['contrast'],
                saturation=jitter_params['saturation'],
                hue=jitter_params['hue']))
            if jitter_params['blur'] > 0:
                train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
        # An empty Compose is used when augmentation is disabled.
        self.train_transforms = tvt.Compose(train_transforms)

        if self.train:
            self.cameras, self.images, _ = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
            # Only images that observe at least one 3D point are usable for training.
            self.name_to_id = {image.name: i for i, image in self.images.items() if len(image.point3D_ids) > 0}
            self.img_p3d = {}
            valid_names = self.name_to_id
            split_fn = osp.join(self.dataset_path, 'TrainSplit.txt')
        else:
            # Only used for testing query images.
            # NOTE: allow_pickle=True can execute code embedded in the file; only load trusted files.
            self.img_p3d = np.load(query_p3d_fn, allow_pickle=True)[()]
            valid_names = self.img_p3d
            split_fn = osp.join(self.dataset_path, 'TestSplit.txt')

        self.img_fns = []
        with open(split_fn, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank (e.g. trailing) lines instead of crashing on int('').
                    continue
                # Each line carries an 8-character prefix (presumably 'sequence')
                # followed by the sequence number.
                seq_dir = 'seq-{:02d}'.format(int(line[8:]))
                nf = 0  # count of valid images seen so far in this sequence
                for fn in sorted(os.listdir(osp.join(self.dataset_path, seq_dir))):
                    if not fn.endswith('.png'):
                        continue
                    name = seq_dir + '/' + fn
                    if name not in valid_names:
                        continue
                    if nf % sample_ratio == 0:
                        self.img_fns.append(name)
                    nf += 1

        print('Load {} images from {} for {}...'.format(len(self.img_fns),
                                                        self.dataset, 'training' if train else 'eval'))

        cluster_fn = osp.join(self.landmark_path,
                              'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method))
        # NOTE: allow_pickle=True can execute code embedded in the file; only load trusted files.
        data = np.load(cluster_fn, allow_pickle=True)[()]
        p3d_id = data['id']
        # Map each 3D point id to its segment label and to its 3D coordinate.
        self.p3d_seg = dict(zip(p3d_id, data['label']))
        self.p3d_xyzs = dict(zip(p3d_id, data['xyz']))

        if not self.train:
            self.query_info = self.read_query_info(path=query_info_path)

        self.feature_dir = osp.join(self.landmark_path, 'feats')
        self.feats = {}