Spaces:
Running
Running
File size: 3,631 Bytes
63f3cf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> recdataset
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 14:42
=================================================='''
import numpy as np
from torch.utils.data import Dataset
class RecDataset(Dataset):
    """Concatenate segmentation sub-datasets into one dataset with a shared label space.

    Each sub-dataset labels its own valid segments ``1 .. n_class - 1``; label 0
    is reserved for invalid segments. ``RecDataset`` shifts those labels so that
    the valid segments of all sub-datasets occupy disjoint, consecutive ranges of
    one global label space ``1 .. self.n_class - 1``, while label 0 stays "invalid".

    Every element of ``sub_sets`` must provide:
      * ``dataset``     -- a name (stored in ``names`` / ``scene_names``),
      * ``n_class``     -- number of labels, including the invalid label 0,
      * ``__len__``     -- number of samples,
      * ``__getitem__`` -- returning a dict with the keys ``'gt_seg'``,
        ``'gt_cls'``, ``'gt_cls_dist'`` and ``'gt_n_seg'``.
    """

    def __init__(self, sub_sets=None):
        """Build the global sample index and label offsets over ``sub_sets``.

        Args:
            sub_sets: non-empty sequence of sub-datasets (see the class docstring).

        Raises:
            ValueError: if ``sub_sets`` is ``None`` or empty.
        """
        # ``None`` instead of a shared mutable ``[]`` default, and an explicit
        # check instead of ``assert`` so the validation survives ``python -O``.
        if sub_sets is None or len(sub_sets) < 1:
            raise ValueError('RecDataset requires at least one sub-dataset')
        self.sub_sets = sub_sets
        self.names = []  # name of each sub-dataset, in order
        # The following lists are indexed by the global sample index.
        self.sub_set_index = []  # which sub-dataset the sample comes from
        self.seg_offsets = []  # global label assigned to the sub-dataset's label 1
        self.sub_set_item_index = []  # index of the sample inside its sub-dataset
        self.dataset_names = []  # kept for backward compatibility; never populated
        self.scene_names = []  # the owning sub-dataset's ``dataset`` name

        start_index_valid_seg = 1  # start from 1, 0 is for invalid
        for subset_idx, scene_set in enumerate(sub_sets):
            name = scene_set.dataset
            self.names.append(name)
            n_samples = len(scene_set)
            # ``extend`` grows the lists in place; rebuilding them with ``+`` on
            # every iteration re-copied all previously accumulated entries.
            self.seg_offsets.extend([start_index_valid_seg] * n_samples)
            self.sub_set_index.extend([subset_idx] * n_samples)
            self.sub_set_item_index.extend(range(n_samples))
            self.scene_names.extend([name] * n_samples)
            # This sub-dataset's valid labels 1..n_class-1 take n_class-1 global slots.
            start_index_valid_seg += scene_set.n_class - 1
        # Global label count: the shared invalid label 0 plus all valid labels.
        self.n_class = start_index_valid_seg
        print('Load {} images {} segs from {} subsets from {}'.format(len(self.sub_set_item_index), self.n_class,
                                                                     len(sub_sets), self.names))

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return len(self.sub_set_item_index)

    def __getitem__(self, idx):
        """Return sample ``idx`` with its labels remapped into the global label space.

        The sample is fetched from its sub-dataset. Its ``'gt_seg'``, ``'gt_cls'``,
        ``'gt_cls_dist'`` and ``'gt_n_seg'`` entries are replaced as follows:

          * ``'gt_seg'`` -- label 0 kept as 0; valid label ``k`` becomes
            ``k + offset - 1``;
          * the other three -- re-laid out as arrays of length ``self.n_class``;
            entry 0 is copied from the sub-dataset's entry 0, and its entries
            ``1..`` are placed at ``offset .. offset + n_class - 2``; all other
            entries are zero.

        A ``'scene_name'`` entry is also added.

        Note: the dict returned by the sub-dataset is modified in place and
        returned.
        """
        subset_idx = self.sub_set_index[idx]
        item_idx = self.sub_set_item_index[idx]
        scene_name = self.scene_names[idx]
        sub_set = self.sub_sets[subset_idx]
        out = sub_set[item_idx]

        org_gt_seg = out['gt_seg']
        org_gt_cls = out['gt_cls']
        org_gt_cls_dist = out['gt_cls_dist']
        org_gt_n_seg = out['gt_n_seg']

        offset = self.seg_offsets[idx]
        org_n_class = sub_set.n_class
        # Global slots owned by this sub-dataset's valid labels 1..org_n_class-1.
        valid_slots = slice(offset, offset + org_n_class - 1)

        # NOTE(review): assumes 'gt_seg' is 1-D (one label per feature); only
        # shape[0] is used to size the output.
        gt_seg = np.zeros(shape=(org_gt_seg.shape[0],), dtype=int)
        valid_mask = org_gt_seg > 0
        gt_seg[valid_mask] = org_gt_seg[valid_mask] + offset - 1

        gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
        gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
        gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)

        # copy invalid segments (label 0 is shared by every sub-dataset)
        gt_n_seg[0] = org_gt_n_seg[0]
        gt_cls[0] = org_gt_cls[0]
        gt_cls_dist[0] = org_gt_cls_dist[0]

        # copy valid segments into this sub-dataset's global slots
        gt_n_seg[valid_slots] = org_gt_n_seg[1:]
        gt_cls[valid_slots] = org_gt_cls[1:]
        gt_cls_dist[valid_slots] = org_gt_cls_dist[1:]

        out['gt_seg'] = gt_seg
        out['gt_cls'] = gt_cls
        out['gt_cls_dist'] = gt_cls_dist
        out['gt_n_seg'] = gt_n_seg
        out['scene_name'] = scene_name
        return out
|