|
from __future__ import print_function, division |
|
import os, random, time |
|
import torch |
|
import numpy as np |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms, utils |
|
import rawpy |
|
from glob import glob |
|
from PIL import Image as PILImage |
|
import numbers |
|
from scipy.misc import imread |
|
from .base_dataset import BaseDataset |
|
|
|
|
|
class FiveKDatasetTrain(BaseDataset):
    """Training dataset pairing white-balanced RAW arrays with RGB targets.

    Every item is loaded from disk, randomly cropped / rotated / flipped,
    normalised and converted to float tensors.  File discovery (``load``),
    normalisation (``norm_img``), tensor conversion (``np2tensor``) and the
    ``gamma`` / ``camera_name`` settings come from ``BaseDataset``.
    """

    def __init__(self, opt):
        """Collect the training file lists via ``self.load(is_train=True)``."""
        super().__init__(opt=opt)
        # Side length of the square crop taken from the RAW input.
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=True)
        assert len(input_RAWs_WBs) == len(target_RGBs)
        self.data = {'input_RAWs_WBs': input_RAWs_WBs, 'target_RGBs': target_RGBs}

    def random_flip(self, input_raw, target_rgb):
        """Flip both arrays along the same randomly chosen axis (0 or 1).

        A flip is always applied; only the axis is random.  The results are
        copied so they are contiguous arrays rather than strided views.
        """
        # Python's ``random`` is used for every augmentation draw (as in
        # ``random_crop``): it is re-seeded per DataLoader worker, whereas
        # NumPy's global RNG state may be duplicated across workers.
        axis = random.randrange(2)
        input_raw = np.flip(input_raw, axis=axis).copy()
        target_rgb = np.flip(target_rgb, axis=axis).copy()

        return input_raw, target_rgb

    def random_rotate(self, input_raw, target_rgb):
        """Rotate both arrays by the same random multiple of 90 degrees.

        ``np.rot90`` returns views; ``aug`` relies on the ``copy()`` made in
        ``random_flip`` afterwards to obtain contiguous arrays.
        """
        quarter_turns = random.randrange(4)
        input_raw = np.rot90(input_raw, k=quarter_turns)
        target_rgb = np.rot90(target_rgb, k=quarter_turns)

        return input_raw, target_rgb

    def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Cut matching random patches out of ``input_raw`` and ``target_rgb``.

        ``input_raw`` must be three-dimensional (its shape is unpacked as
        ``H, W, _``).  When ``flow`` or ``demos`` is true the target is
        cropped with the same window as the input; otherwise the window
        offsets and size are doubled for the target (presumably because the
        target then has twice the spatial resolution of the input -- confirm
        against the callers).  If the input is smaller than ``patch_size``
        along an axis, the offset is 0 and the patch is simply smaller.
        """
        H, W, _ = input_raw.shape
        rnd_h = random.randint(0, max(0, H - patch_size))
        rnd_w = random.randint(0, max(0, W - patch_size))

        patch_input_raw = input_raw[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :]
        if flow or demos:
            patch_target_rgb = target_rgb[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :]
        else:
            patch_target_rgb = target_rgb[rnd_h * 2:rnd_h * 2 + patch_size * 2,
                                          rnd_w * 2:rnd_w * 2 + patch_size * 2, :]

        return patch_input_raw, patch_target_rgb

    def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Apply random crop, then random rotation, then random flip."""
        input_raw, target_rgb = self.random_crop(patch_size, input_raw, target_rgb, flow=flow, demos=demos)
        input_raw, target_rgb = self.random_rotate(input_raw, target_rgb)
        input_raw, target_rgb = self.random_flip(input_raw, target_rgb)

        return input_raw, target_rgb

    def __len__(self):
        """Return the number of RAW/RGB pairs."""
        return len(self.data['input_RAWs_WBs'])

    def __getitem__(self, idx):
        """Load, augment and normalise the pair at position ``idx``.

        Returns a dict with float tensors ``input_raw``, ``target_rgb`` and
        ``target_raw`` (a copy of the normalised input) plus ``file_name``,
        the RAW file's base name up to its first dot.
        """
        input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
        target_rgb_path = self.data['target_RGBs'][idx]

        target_rgb_img = imread(target_rgb_path)
        # NOTE(review): assumes the file is an ``.npz`` archive; the ``with``
        # block closes its file handle as soon as both entries are read.
        with np.load(input_raw_wb_path) as archive:
            input_raw_img = archive['raw']
            wb = archive['wb']
        # Scale the white-balance gains so the largest is 1, then apply all
        # but the last gain along the trailing axis of the RAW array
        # (presumably R, G, B with a duplicate green dropped -- confirm).
        wb = wb / wb.max()
        input_raw_img = input_raw_img * wb[:-1]

        # Use the configured patch size instead of re-assigning it per call.
        input_raw_img, target_rgb_img = self.aug(
            self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True)

        # Maximum RAW value used for normalisation, selected by camera name.
        white_level = 4095 if self.camera_name == 'Canon_EOS_5D' else 16383
        if self.gamma:
            # Gamma-encode the data and the normalisation constant alike.
            norm_value = np.power(white_level, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        else:
            norm_value = white_level

        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()

        input_raw_img = self.np2tensor(input_raw_img).float()
        target_rgb_img = self.np2tensor(target_rgb_img).float()
        target_raw_img = self.np2tensor(target_raw_img).float()

        # ``basename`` copes with the platform's path separator.
        file_name = os.path.basename(input_raw_wb_path).split(".")[0]
        sample = {'input_raw': input_raw_img, 'target_rgb': target_rgb_img, 'target_raw': target_raw_img,
                  'file_name': file_name}
        return sample
|
|
|
class FiveKDatasetTest(BaseDataset):
    """Evaluation dataset pairing white-balanced RAW arrays with RGB targets.

    Items are loaded and normalised exactly once, without cropping or any
    other augmentation.  File discovery (``load``), normalisation
    (``norm_img``), tensor conversion (``np2tensor``) and the ``gamma`` /
    ``camera_name`` settings come from ``BaseDataset``.
    """

    def __init__(self, opt):
        """Collect the test file lists via ``self.load(is_train=False)``."""
        super().__init__(opt=opt)
        # Not read by this class's own methods; kept as a public attribute.
        self.patch_size = 256

        input_RAWs_WBs, target_RGBs = self.load(is_train=False)
        assert len(input_RAWs_WBs) == len(target_RGBs)
        self.data = {'input_RAWs_WBs': input_RAWs_WBs, 'target_RGBs': target_RGBs}

    def __len__(self):
        """Return the number of RAW/RGB pairs."""
        return len(self.data['input_RAWs_WBs'])

    def __getitem__(self, idx):
        """Load and normalise the full-size pair at position ``idx``.

        Returns a dict with float tensors ``input_raw``, ``target_rgb`` and
        ``target_raw`` (a copy of the normalised input) plus ``file_name``,
        the RAW file's base name up to its first dot.
        """
        input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
        target_rgb_path = self.data['target_RGBs'][idx]

        target_rgb_img = imread(target_rgb_path)
        # NOTE(review): assumes the file is an ``.npz`` archive; the ``with``
        # block closes its file handle as soon as both entries are read.
        with np.load(input_raw_wb_path) as archive:
            input_raw_img = archive['raw']
            wb = archive['wb']
        # Scale the white-balance gains so the largest is 1, then apply all
        # but the last gain along the trailing axis of the RAW array
        # (presumably R, G, B with a duplicate green dropped -- confirm).
        wb = wb / wb.max()
        input_raw_img = input_raw_img * wb[:-1]

        # Maximum RAW value used for normalisation, selected by camera name.
        white_level = 4095 if self.camera_name == 'Canon_EOS_5D' else 16383
        if self.gamma:
            # Gamma-encode the data and the normalisation constant alike.
            norm_value = np.power(white_level, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        else:
            norm_value = white_level

        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()

        input_raw_img = self.np2tensor(input_raw_img).float()
        target_rgb_img = self.np2tensor(target_rgb_img).float()
        target_raw_img = self.np2tensor(target_raw_img).float()

        # ``basename`` copes with the platform's path separator.
        file_name = os.path.basename(input_raw_wb_path).split(".")[0]
        sample = {'input_raw': input_raw_img, 'target_rgb': target_rgb_img, 'target_raw': target_raw_img,
                  'file_name': file_name}
        return sample
|
|
|
|