File size: 4,916 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
@date: 2021/6/16
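@description: Mixed Pano/S2D3D dataset. In 'train' mode it combines every split (train/val/test) of the
              other subset with the train split of the selected subset; other modes load the selected subset only.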
"""

import os

from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
from utils.logger import get_logger


class PanoS2D3DMixDataset(PanoS2D3DDataset):
    """PanoS2D3D dataset whose training data mixes the 'pano' and 's2d3d' subsets.

    When ``mode == 'train'`` the sample list is built as::

        <other subset: train + val + test>  +  <selected subset: train>

    where "other subset" is 's2d3d' if ``subset == 'pano'`` and 'pano' if
    ``subset == 's2d3d'``. For any other ``mode`` the sample list is simply the
    ``.data`` of a plain ``PanoS2D3DDataset`` built for that mode and subset.
    """

    def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
                 split_list=None, patch_num=256, keys=None, for_test_index=None, subset=None):
        """Build the (possibly mixed) sample list.

        :param mode: 'train' triggers the mixing described on the class; any other value is
            forwarded unchanged as the split name to ``PanoS2D3DDataset``.
        :param subset: must be 's2d3d' or 'pano' (enforced with an ``assert``).
        :param logger: if None, ``get_logger()`` is used for the inner dataset loads and the
            final summary line; the base-class initialiser still receives the value as passed.
        :param for_test_index: if not None, ``self.data`` is truncated to its first
            ``for_test_index`` entries after mixing. The base class and the inner datasets
            always receive None in this position.
        :param root_dir, shape, max_wall_num, aug, camera_height, split_list, patch_num, keys:
            forwarded positionally to ``PanoS2D3DDataset`` without modification; their
            meaning is defined there (not visible in this file).
        """
        assert subset == 's2d3d' or subset == 'pano', 'error subset'
        # The base initialiser is invoked with mode=None; what it does in that case is defined
        # in PanoS2D3DDataset. self.data is (re)assigned below for every valid subset.
        super().__init__(root_dir, None, shape, max_wall_num, aug, camera_height, logger,
                         split_list, patch_num, keys, None, subset)
        if logger is None:
            logger = get_logger()
        self.mode = mode

        def load_samples(split, which_subset):
            # One plain dataset per (split, subset); only its sample list is kept.
            return PanoS2D3DDataset(root_dir, split, shape, max_wall_num, aug, camera_height, logger,
                                    split_list, patch_num, keys, None, which_subset).data

        if mode != 'train':
            self.data = load_samples(mode, subset)
        else:
            if subset == 'pano':
                counterpart = 's2d3d'
            elif subset == 's2d3d':
                counterpart = 'pano'
            else:
                # Only reachable when asserts are disabled; self.data is then left as the
                # base initialiser set it.
                counterpart = None

            if counterpart is not None:
                # Counterpart splits are loaded first (train, val, test), then the selected
                # subset's train split, and concatenated in that same order.
                counterpart_train = load_samples('train', counterpart)
                counterpart_val = load_samples('val', counterpart)
                counterpart_test = load_samples('test', counterpart)
                counterpart_all = counterpart_train + counterpart_val + counterpart_test

                own_train = load_samples('train', subset)
                self.data = counterpart_all + own_train

        if for_test_index is not None:
            self.data = self.data[:for_test_index]
        logger.info(f"Build dataset mode: {self.mode}  valid: {len(self.data)}")


if __name__ == '__main__':
    # Manual smoke test: build the mixed dataset for every mode and, optionally, dump a
    # boundary overlay and a floorplan image per sample.
    import numpy as np
    from PIL import Image

    from tqdm import tqdm
    from visualization.boundary import draw_boundaries
    from visualization.floorplan import draw_floorplan
    from utils.boundary import depth2boundaries
    from utils.conversion import uv2xyz

    # The visualization stage used to be skipped by a bare `continue`, which left the code
    # below it unreachable. The switch is now explicit; False keeps the previous behaviour
    # (datasets are only constructed, nothing is written to disk).
    visualize = False

    modes = ['test', 'val', 'train']
    for mode in modes:
        print(mode)
        dataset = PanoS2D3DMixDataset(root_dir='../src/dataset/pano_s2d3d', mode=mode, aug={
            # Augmentation switches left disabled by the original author:
            # 'STRETCH': True,
            # 'ROTATE': True,
            # 'FLIP': True,
            # 'GAMMA': True
        }, subset='pano')
        if not visualize:
            continue

        save_dir = f'../src/dataset/pano_s2d3d/visualization1/{mode}'
        # exist_ok avoids the check-then-create race of isdir() followed by makedirs().
        os.makedirs(save_dir, exist_ok=True)

        bar = tqdm(dataset, ncols=100)
        for data in bar:
            bar.set_description(f"Processing {data['id']}")
            boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
            # NOTE(review): the transpose presumably converts CHW to HWC, and the * 255 cast
            # assumes float images in [0, 1] -- confirm against the dataset/visualization code.
            pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=False)
            Image.fromarray((pano_img * 255).astype(np.uint8)).save(
                os.path.join(save_dir, f"{data['id']}_boundary.png"))

            # [..., ::2] keeps components 0 and 2 of the last axis of the first boundary's
            # xyz coordinates (presumably the horizontal-plane x/z pair -- TODO confirm).
            floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=False,
                                       marker_color=None, center_color=0.8, show_radius=None)
            Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
                os.path.join(save_dir, f"{data['id']}_floorplan.png"))