File size: 3,631 Bytes
63f3cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> recdataset
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   29/01/2024 14:42
=================================================='''
import numpy as np
from torch.utils.data import Dataset


class RecDataset(Dataset):
    """Concatenate several segmentation sub-datasets into one shared label space.

    Every sub-dataset uses segment label 0 for "invalid" and labels
    ``1 .. n_class - 1`` for valid segments. This wrapper keeps a single shared
    invalid label 0. It shifts each sub-dataset's valid labels into its own
    disjoint range of global labels, so the combined label space has
    ``1 + sum(sub.n_class - 1)`` entries (stored in ``self.n_class``).

    Each sub-dataset must provide:
    - ``.dataset``: a name;
    - ``.n_class``;
    - ``__len__``;
    - ``__getitem__``, returning a mapping with the keys ``'gt_seg'``,
      ``'gt_cls'``, ``'gt_cls_dist'`` and ``'gt_n_seg'``.
    """

    def __init__(self, sub_sets=None):
        """Build the global sample index over all sub-datasets.

        Args:
            sub_sets: non-empty sequence of sub-datasets (see class docstring).

        Raises:
            ValueError: if ``sub_sets`` is None or empty.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        # `None` replaces the former shared mutable `[]` default.
        if sub_sets is None or len(sub_sets) == 0:
            raise ValueError('RecDataset requires at least one sub-dataset')

        self.sub_sets = sub_sets
        self.names = []

        # Per-sample lookup tables, all indexed by the global sample index.
        self.sub_set_index = []  # index of the owning sub-dataset
        self.seg_offsets = []  # global label assigned to the sub-dataset's local label 1
        self.sub_set_item_index = []  # sample index inside its sub-dataset
        self.dataset_names = []  # not populated here
        self.scene_names = []  # `.dataset` name of the owning sub-dataset
        start_index_valid_seg = 1  # start from 1, 0 is for invalid

        for subset_idx, scene_set in enumerate(sub_sets):
            name = scene_set.dataset
            self.names.append(name)
            n_samples = len(scene_set)

            # Extend in place; the former `lst = lst + [...]` rebuilt each list per subset.
            self.seg_offsets.extend([start_index_valid_seg] * n_samples)
            self.sub_set_index.extend([subset_idx] * n_samples)
            self.sub_set_item_index.extend(range(n_samples))
            self.scene_names.extend([name] * n_samples)

            # Local labels 1..n_class-1 take n_class-1 global slots; local 0 is shared.
            start_index_valid_seg += scene_set.n_class - 1

        # Total number of global labels, including the shared invalid label 0.
        self.n_class = start_index_valid_seg

        print('Load {} images {} segs from {} subsets from {}'.format(len(self.sub_set_item_index), self.n_class,
                                                                      len(sub_sets), self.names))

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return len(self.sub_set_item_index)

    def __getitem__(self, idx):
        """Fetch one sample and remap its segment labels into the global label space.

        The sample's ``gt_seg``, ``gt_cls``, ``gt_cls_dist`` and ``gt_n_seg``
        entries are replaced with remapped arrays, and ``scene_name`` is added.
        The mapping returned by the sub-dataset is updated in place and returned.

        Args:
            idx: global sample index in ``[0, len(self))``.

        Returns:
            The updated sample mapping. Its shapes are:
            - ``gt_seg``: same shape as the original ``gt_seg``;
            - ``gt_n_seg``, ``gt_cls``, ``gt_cls_dist``: length ``self.n_class``.
        """
        subset_idx = self.sub_set_index[idx]
        item_idx = self.sub_set_item_index[idx]
        scene_name = self.scene_names[idx]
        sub_set = self.sub_sets[subset_idx]

        out = sub_set[item_idx]

        org_gt_seg = out['gt_seg']
        org_gt_cls = out['gt_cls']
        org_gt_cls_dist = out['gt_cls_dist']
        org_gt_n_seg = out['gt_n_seg']
        offset = self.seg_offsets[idx]
        org_n_class = sub_set.n_class

        # Allocate with the full input shape, so segment maps of any rank are supported.
        gt_seg = np.zeros(shape=tuple(org_gt_seg.shape), dtype=int)
        gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
        gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
        gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)

        # The invalid entry (label 0) maps to the shared global label 0.
        gt_n_seg[0] = org_gt_n_seg[0]
        gt_cls[0] = org_gt_cls[0]
        gt_cls_dist[0] = org_gt_cls_dist[0]

        # Valid local label k (k >= 1) maps to global label offset + k - 1.
        valid = org_gt_seg > 0
        gt_seg[valid] = org_gt_seg[valid] + offset - 1
        end = offset + org_n_class - 1
        gt_n_seg[offset:end] = org_gt_n_seg[1:]
        gt_cls[offset:end] = org_gt_cls[1:]
        gt_cls_dist[offset:end] = org_gt_cls_dist[1:]

        out['gt_seg'] = gt_seg
        out['gt_cls'] = gt_cls
        out['gt_cls_dist'] = gt_cls_dist
        out['gt_n_seg'] = gt_n_seg
        out['scene_name'] = scene_name

        return out