File size: 5,724 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
a80d6bb
 
c74a070
 
a80d6bb
 
 
c74a070
 
 
a80d6bb
 
 
 
c74a070
 
a80d6bb
 
 
c74a070
a80d6bb
 
 
 
c74a070
 
 
a80d6bb
c74a070
 
 
a80d6bb
c74a070
 
 
 
 
a80d6bb
 
c74a070
a80d6bb
c74a070
 
 
 
 
 
a80d6bb
 
 
c74a070
 
 
 
 
a80d6bb
 
 
c74a070
 
 
 
a80d6bb
 
c74a070
 
 
 
 
 
 
 
 
 
 
a80d6bb
c74a070
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
c74a070
 
 
 
 
 
 
a80d6bb
 
c74a070
a80d6bb
 
 
 
c74a070
a80d6bb
c74a070
 
a80d6bb
 
c74a070
 
 
 
 
a80d6bb
 
 
c74a070
 
 
 
 
 
 
 
 
 
 
 
a80d6bb
c74a070
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
 
c74a070
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from __future__ import print_function, division
import os, random, time
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms, utils
import rawpy
from glob import glob
from PIL import Image as PILImage
import numbers
from scipy.misc import imread
from .base_dataset import BaseDataset


class FiveKDatasetTrain(BaseDataset):
    """Training split of the FiveK RAW-to-RGB dataset.

    Each item pairs a white-balanced RAW array (read from an ``.npz`` archive
    holding ``raw`` and ``wb`` entries) with its target RGB image. The pair is
    randomly cropped, rotated and possibly mirrored, normalized, and returned
    as float tensors.

    ``load``, ``gamma``, ``camera_name``, ``norm_img`` and ``np2tensor`` come
    from ``BaseDataset`` (not visible here).
    """

    def __init__(self, opt):
        super().__init__(opt=opt)
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=True)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(input_RAWs_WBs) != len(target_RGBs):
            raise ValueError(
                "Mismatched dataset: %d RAW files vs %d RGB targets"
                % (len(input_RAWs_WBs), len(target_RGBs))
            )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def random_flip(self, input_raw, target_rgb):
        """Mirror both arrays, with probability 0.5, along the same random axis.

        Combined with ``random_rotate`` this samples all 8 rotation/flip
        orientations uniformly (always flipping, as before, only produced
        mirrored images and never the original orientation).

        Returns contiguous copies of both arrays.
        """
        # Python's `random` is re-seeded per DataLoader worker by PyTorch,
        # unlike NumPy's global RNG in older PyTorch versions, so workers do
        # not repeat each other's augmentations.
        if random.random() < 0.5:
            axis = random.randint(0, 1)
            input_raw = np.flip(input_raw, axis=axis)
            target_rgb = np.flip(target_rgb, axis=axis)
        # np.flip / np.rot90 return negative-stride views; copying materializes
        # them (torch.from_numpy cannot wrap negative strides — presumably
        # relevant to np2tensor).
        return input_raw.copy(), target_rgb.copy()

    def random_rotate(self, input_raw, target_rgb):
        """Rotate both arrays by the same random multiple of 90 degrees.

        The rotation acts on the first two axes. The results are NumPy views
        of the inputs.
        """
        quarter_turns = random.randint(0, 3)
        return (
            np.rot90(input_raw, k=quarter_turns),
            np.rot90(target_rgb, k=quarter_turns),
        )

    def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Cut an aligned random patch from the RAW input and its RGB target.

        The RAW patch is ``patch_size`` x ``patch_size``, or smaller if the
        image is smaller. When ``flow`` or ``demos`` is set, the target is
        cropped at the same coordinates. Otherwise the target is assumed to be
        twice the RAW resolution, and a ``2 * patch_size`` patch is cut at
        doubled coordinates. Works for both 2-D and 3-D (H, W, C) arrays.
        """
        height, width = input_raw.shape[:2]
        top = random.randint(0, max(0, height - patch_size))
        left = random.randint(0, max(0, width - patch_size))

        patch_input_raw = input_raw[top : top + patch_size, left : left + patch_size]
        scale = 1 if (flow or demos) else 2
        patch_target_rgb = target_rgb[
            top * scale : (top + patch_size) * scale,
            left * scale : (left + patch_size) * scale,
        ]
        return patch_input_raw, patch_target_rgb

    def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Apply random crop, then random rotation, then random flip to the pair."""
        input_raw, target_rgb = self.random_crop(
            patch_size, input_raw, target_rgb, flow=flow, demos=demos
        )
        input_raw, target_rgb = self.random_rotate(input_raw, target_rgb)
        return self.random_flip(input_raw, target_rgb)

    def __len__(self):
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Load, augment and normalize the idx-th RAW/RGB pair.

        Returns a dict with float tensors ``input_raw``, ``target_rgb``,
        ``target_raw`` (a copy of the normalized input) and the source
        ``file_name`` without directory or extension.
        """
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]

        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this needs
        # an old SciPy pin (or a switch to PIL) — confirm the environment.
        target_rgb_img = imread(target_rgb_path)
        input_raw_img = self._load_white_balanced_raw(input_raw_wb_path)

        input_raw_img, target_rgb_img = self.aug(
            self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True
        )
        return self._normalize_to_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_white_balanced_raw(path):
        """Read ``raw`` from an ``.npz`` archive and scale it by the normalized ``wb`` gains."""
        # The context manager closes the archive's file handle deterministically.
        with np.load(path) as archive:
            raw = archive["raw"]
            wb = archive["wb"]
        wb = wb / wb.max()
        # NOTE(review): the last gain is dropped; presumably wb holds 4 gains
        # and raw has 3 channels — confirm against the data files.
        return raw * wb[:-1]

    def _normalize_to_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Normalize the pair, convert it to float tensors, and build the sample dict."""
        # 4095 = 2**12 - 1 and 16383 = 2**14 - 1: the Canon EOS 5D is treated
        # as 12-bit, every other camera as 14-bit.
        white_level = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            # Apply the same 1/2.2 gamma to the data and to its normalizer.
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
            white_level = np.power(white_level, 1 / 2.2)

        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=white_level)
        target_raw_img = input_raw_img.copy()

        return {
            "input_raw": self.np2tensor(input_raw_img).float(),
            "target_rgb": self.np2tensor(target_rgb_img).float(),
            "target_raw": self.np2tensor(target_raw_img).float(),
            # basename + splitext: portable and keeps dots inside the stem.
            "file_name": os.path.splitext(os.path.basename(input_raw_wb_path))[0],
        }


class FiveKDatasetTest(BaseDataset):
    """Evaluation split of the FiveK RAW-to-RGB dataset.

    Each item pairs a white-balanced RAW array (read from an ``.npz`` archive
    holding ``raw`` and ``wb`` entries) with its target RGB image. The pair is
    normalized at full resolution, without augmentation, and returned as
    float tensors.

    ``load``, ``gamma``, ``camera_name``, ``norm_img`` and ``np2tensor`` come
    from ``BaseDataset`` (not visible here).
    """

    def __init__(self, opt):
        super().__init__(opt=opt)
        # Not used by this class; kept for attribute parity with the train split.
        self.patch_size = 256

        input_RAWs_WBs, target_RGBs = self.load(is_train=False)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(input_RAWs_WBs) != len(target_RGBs):
            raise ValueError(
                "Mismatched dataset: %d RAW files vs %d RGB targets"
                % (len(input_RAWs_WBs), len(target_RGBs))
            )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def __len__(self):
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Load and normalize the idx-th RAW/RGB pair.

        Returns a dict with float tensors ``input_raw``, ``target_rgb``,
        ``target_raw`` (a copy of the normalized input) and the source
        ``file_name`` without directory or extension.
        """
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]

        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this needs
        # an old SciPy pin (or a switch to PIL) — confirm the environment.
        target_rgb_img = imread(target_rgb_path)
        input_raw_img = self._load_white_balanced_raw(input_raw_wb_path)

        return self._normalize_to_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_white_balanced_raw(path):
        """Read ``raw`` from an ``.npz`` archive and scale it by the normalized ``wb`` gains."""
        # The context manager closes the archive's file handle deterministically.
        with np.load(path) as archive:
            raw = archive["raw"]
            wb = archive["wb"]
        wb = wb / wb.max()
        # NOTE(review): the last gain is dropped; presumably wb holds 4 gains
        # and raw has 3 channels — confirm against the data files.
        return raw * wb[:-1]

    def _normalize_to_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Normalize the pair, convert it to float tensors, and build the sample dict."""
        # 4095 = 2**12 - 1 and 16383 = 2**14 - 1: the Canon EOS 5D is treated
        # as 12-bit, every other camera as 14-bit.
        white_level = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            # Apply the same 1/2.2 gamma to the data and to its normalizer.
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
            white_level = np.power(white_level, 1 / 2.2)

        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=white_level)
        target_raw_img = input_raw_img.copy()

        return {
            "input_raw": self.np2tensor(input_raw_img).float(),
            "target_rgb": self.np2tensor(target_rgb_img).float(),
            "target_raw": self.np2tensor(target_raw_img).float(),
            # basename + splitext: portable and keeps dots inside the stem.
            "file_name": os.path.splitext(os.path.basename(input_raw_wb_path))[0],
        }