Spaces:
Running
Running
from __future__ import print_function, division | |
import os, random, time | |
import torch | |
import numpy as np | |
from torch.utils.data import Dataset | |
from torchvision import transforms, utils | |
import rawpy | |
from glob import glob | |
from PIL import Image as PILImage | |
import numbers | |
from scipy.misc import imread | |
from .base_dataset import BaseDataset | |
class FiveKDatasetTrain(BaseDataset):
    """Training split of the FiveK RAW-to-RGB dataset.

    Each item pairs a white-balanced RAW array with its target RGB image.
    Both are cropped with the same random window, then given the same random
    rotation and (optional) flip.

    NOTE(review): ``imread`` is imported from ``scipy.misc``, which was
    removed in SciPy 1.2; this module needs an older SciPy or another reader.
    """

    def __init__(self, opt):
        """Collect the training file pairs returned by ``self.load``.

        Raises:
            ValueError: if ``self.load`` returns different numbers of RAW/WB
                files and target RGB files.
        """
        super().__init__(opt=opt)
        # Side length (pixels) of the square crop taken in __getitem__.
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=True)
        # Explicit raise instead of assert so the check survives `python -O`.
        if len(input_RAWs_WBs) != len(target_RGBs):
            raise ValueError(
                "FiveKDatasetTrain: {} RAW/WB files but {} target RGB files".format(
                    len(input_RAWs_WBs), len(target_RGBs)
                )
            )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def random_flip(self, input_raw, target_rgb):
        """With probability 1/2, mirror both images along the same random axis.

        The flip must be optional: ``aug`` applies it after ``random_rotate``,
        and a rotation followed by a mandatory reflection is always a
        reflection, so un-mirrored orientations would never be produced.

        Returns:
            ``(input_raw, target_rgb)``, flipped copies or the inputs unchanged.
        """
        if np.random.randint(2):
            axis = np.random.randint(2)  # 0: vertical flip, 1: horizontal flip
            input_raw = np.flip(input_raw, axis=axis).copy()
            target_rgb = np.flip(target_rgb, axis=axis).copy()
        return input_raw, target_rgb

    def random_rotate(self, input_raw, target_rgb):
        """Rotate both images by the same random multiple of 90 degrees.

        Returns:
            ``(input_raw, target_rgb)`` as contiguous copies. ``np.rot90``
            returns a strided view, and since the following flip is optional
            the view could otherwise reach tensor conversion unchanged.
        """
        k = np.random.randint(4)
        input_raw = np.rot90(input_raw, k=k).copy()
        target_rgb = np.rot90(target_rgb, k=k).copy()
        return input_raw, target_rgb

    def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Take an aligned random crop from the input and target images.

        Args:
            patch_size: side length of the crop taken from ``input_raw``.
            input_raw: array indexed as (H, W, C).
            target_rgb: target array. When ``flow`` or ``demos`` is true it is
                cropped at the same coordinates as ``input_raw``. Otherwise it
                is treated as twice the input resolution, and a
                ``2 * patch_size`` crop at doubled coordinates is taken.
            flow, demos: either flag selects same-resolution cropping.

        Returns:
            ``(patch_input_raw, patch_target_rgb)``. If a side of ``input_raw``
            is shorter than ``patch_size``, the crop is truncated to it.
        """
        H, W, _ = input_raw.shape
        rnd_h = random.randint(0, max(0, H - patch_size))
        rnd_w = random.randint(0, max(0, W - patch_size))
        patch_input_raw = input_raw[
            rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, :
        ]
        if flow or demos:
            patch_target_rgb = target_rgb[
                rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, :
            ]
        else:
            patch_target_rgb = target_rgb[
                rnd_h * 2 : rnd_h * 2 + patch_size * 2,
                rnd_w * 2 : rnd_w * 2 + patch_size * 2,
                :,
            ]
        return patch_input_raw, patch_target_rgb

    def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Apply crop, then rotation, then optional flip, identically to both images."""
        input_raw, target_rgb = self.random_crop(
            patch_size, input_raw, target_rgb, flow=flow, demos=demos
        )
        input_raw, target_rgb = self.random_rotate(input_raw, target_rgb)
        input_raw, target_rgb = self.random_flip(input_raw, target_rgb)
        return input_raw, target_rgb

    def __len__(self):
        """Number of (RAW/WB, target RGB) pairs."""
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Load, augment and normalise the pair at ``idx``.

        Returns:
            dict with tensors ``"input_raw"``, ``"target_rgb"`` and
            ``"target_raw"`` (a copy of the normalised input), plus
            ``"file_name"``: the RAW file's basename up to its first dot.
        """
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]
        target_rgb_img = imread(target_rgb_path)
        input_raw_img = self._load_white_balanced_raw(input_raw_wb_path)
        input_raw_img, target_rgb_img = self.aug(
            self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True
        )
        return self._build_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_white_balanced_raw(path):
        """Read the ``"raw"``/``"wb"`` entries at ``path`` and apply the gains."""
        archive = np.load(path)
        try:
            raw = archive["raw"]
            wb = archive["wb"]
        finally:
            # An .npz archive (NpzFile) keeps its file handle open until closed.
            close = getattr(archive, "close", None)
            if close is not None:
                close()
        wb = wb / wb.max()
        # NOTE(review): assumes wb has one more entry than raw has channels
        # (e.g. 4 gains for a 3-channel raw); the last gain is dropped. Confirm.
        return raw * wb[:-1]

    def _build_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Optionally gamma-encode, normalise and convert one pair to tensors."""
        # 4095 == 2**12 - 1 for Canon_EOS_5D, 16383 == 2**14 - 1 otherwise
        # (presumably the sensors' white levels). Under gamma, the divisor gets
        # the same 1/2.2 power as the data.
        max_raw = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            norm_value = np.power(max_raw, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        else:
            norm_value = max_raw
        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()
        return {
            "input_raw": self.np2tensor(input_raw_img).float(),
            "target_rgb": self.np2tensor(target_rgb_img).float(),
            "target_raw": self.np2tensor(target_raw_img).float(),
            # basename handles OS-native separators, unlike split("/").
            "file_name": os.path.basename(input_raw_wb_path).split(".")[0],
        }
class FiveKDatasetTest(BaseDataset):
    """Evaluation split of the FiveK RAW-to-RGB dataset.

    Each item pairs a white-balanced RAW array with its target RGB image at
    full size, with no cropping or augmentation.

    NOTE(review): ``imread`` is imported from ``scipy.misc``, which was
    removed in SciPy 1.2; this module needs an older SciPy or another reader.
    """

    def __init__(self, opt):
        """Collect the test file pairs returned by ``self.load``.

        Raises:
            ValueError: if ``self.load`` returns different numbers of RAW/WB
                files and target RGB files.
        """
        super().__init__(opt=opt)
        # Not used by __getitem__ here; kept for interface parity with training.
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=False)
        # Explicit raise instead of assert so the check survives `python -O`.
        if len(input_RAWs_WBs) != len(target_RGBs):
            raise ValueError(
                "FiveKDatasetTest: {} RAW/WB files but {} target RGB files".format(
                    len(input_RAWs_WBs), len(target_RGBs)
                )
            )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def __len__(self):
        """Number of (RAW/WB, target RGB) pairs."""
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Load and normalise the full-size pair at ``idx``.

        Returns:
            dict with tensors ``"input_raw"``, ``"target_rgb"`` and
            ``"target_raw"`` (a copy of the normalised input), plus
            ``"file_name"``: the RAW file's basename up to its first dot.
        """
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]
        target_rgb_img = imread(target_rgb_path)
        input_raw_img = self._load_white_balanced_raw(input_raw_wb_path)
        return self._build_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_white_balanced_raw(path):
        """Read the ``"raw"``/``"wb"`` entries at ``path`` and apply the gains."""
        archive = np.load(path)
        try:
            raw = archive["raw"]
            wb = archive["wb"]
        finally:
            # An .npz archive (NpzFile) keeps its file handle open until closed.
            close = getattr(archive, "close", None)
            if close is not None:
                close()
        wb = wb / wb.max()
        # NOTE(review): assumes wb has one more entry than raw has channels
        # (e.g. 4 gains for a 3-channel raw); the last gain is dropped. Confirm.
        return raw * wb[:-1]

    def _build_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Optionally gamma-encode, normalise and convert one pair to tensors."""
        # 4095 == 2**12 - 1 for Canon_EOS_5D, 16383 == 2**14 - 1 otherwise
        # (presumably the sensors' white levels). Under gamma, the divisor gets
        # the same 1/2.2 power as the data.
        max_raw = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            norm_value = np.power(max_raw, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        else:
            norm_value = max_raw
        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()
        return {
            "input_raw": self.np2tensor(input_raw_img).float(),
            "target_rgb": self.np2tensor(target_rgb_img).float(),
            "target_raw": self.np2tensor(target_raw_img).float(),
            # basename handles OS-native separators, unlike split("/").
            "file_name": os.path.basename(input_raw_wb_path).split(".")[0],
        }