Vincentqyw
fix: roma
358ab8f
raw
history blame
5.72 kB
from __future__ import print_function, division
import os, random, time
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms, utils
import rawpy
from glob import glob
from PIL import Image as PILImage
import numbers
from scipy.misc import imread
from .base_dataset import BaseDataset
class FiveKDatasetTrain(BaseDataset):
    """Training split of the FiveK RAW-to-RGB dataset.

    Each item pairs a white-balanced RAW array with its target RGB image.
    Both are cropped at the same random location to ``self.patch_size`` and
    then given the same random rotation and optional flip, so all 8
    symmetries of the square (including the unchanged orientation) can be
    sampled.

    NOTE(review): ``load``, ``gamma``, ``camera_name``, ``norm_img`` and
    ``np2tensor`` are inherited from ``BaseDataset``, which is not visible
    here.
    NOTE(review): ``imread`` is the file-level ``scipy.misc.imread`` import,
    which SciPy removed in 1.2; this module needs an older SciPy (or a
    different image reader) to import.
    """

    def __init__(self, opt):
        super().__init__(opt=opt)
        # Side length of the square training crop; read by __getitem__.
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=True)
        # Inputs and targets are paired by index. Raise instead of assert so
        # the check is not stripped under ``python -O``.
        if len(input_RAWs_WBs) != len(target_RGBs):
            raise ValueError(
                "FiveK train split has {} RAW inputs but {} RGB targets".format(
                    len(input_RAWs_WBs), len(target_RGBs)
                )
            )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def random_flip(self, input_raw, target_rgb):
        """With probability 1/2, flip both arrays along the same random axis.

        The axis is 0 or 1. Always returns fresh C-contiguous copies, because
        ``np.flip`` and ``np.rot90`` return views (possibly with negative
        strides).
        """
        # This runs after a random rot90. Flipping unconditionally would only
        # ever produce the 4 mirror images (rotation followed by reflection is
        # always a reflection), so the untouched orientation would never occur.
        if np.random.randint(2):
            axis = np.random.randint(2)
            input_raw = np.flip(input_raw, axis=axis)
            target_rgb = np.flip(target_rgb, axis=axis)
        return input_raw.copy(), target_rgb.copy()

    def random_rotate(self, input_raw, target_rgb):
        """Rotate both arrays by the same random multiple of 90 degrees.

        The rotation is in the plane of the first two axes. The results are
        views of the inputs, not copies.
        """
        k = np.random.randint(4)
        return np.rot90(input_raw, k=k), np.rot90(target_rgb, k=k)

    def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Crop spatially aligned patches from the input and the target.

        ``input_raw`` must be 3-D ``(H, W, C)``. A ``patch_size`` square is cut
        at a random offset; if the image is smaller than ``patch_size`` along
        an axis, the offset is 0 and the crop is clipped to the image.

        When ``flow`` or ``demos`` is set, the target is cropped at the same
        coordinates. Otherwise it is cropped at doubled coordinates and size.
        NOTE(review): the doubling presumably means the target has twice the
        RAW's spatial resolution (e.g. packed Bayer input) -- confirm.
        """
        H, W, _ = input_raw.shape
        rnd_h = random.randint(0, max(0, H - patch_size))
        rnd_w = random.randint(0, max(0, W - patch_size))
        patch_input_raw = input_raw[
            rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, :
        ]
        scale = 1 if (flow or demos) else 2
        patch_target_rgb = target_rgb[
            rnd_h * scale : (rnd_h + patch_size) * scale,
            rnd_w * scale : (rnd_w + patch_size) * scale,
            :,
        ]
        return patch_input_raw, patch_target_rgb

    def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Apply random crop, then random rotation, then random flip to both arrays."""
        input_raw, target_rgb = self.random_crop(
            patch_size, input_raw, target_rgb, flow=flow, demos=demos
        )
        input_raw, target_rgb = self.random_rotate(input_raw, target_rgb)
        input_raw, target_rgb = self.random_flip(input_raw, target_rgb)
        return input_raw, target_rgb

    def __len__(self):
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Return the augmented, normalized training sample at ``idx``.

        The crop uses ``self.patch_size`` as configured; it is no longer
        reset on every call.
        """
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]
        target_rgb_img = imread(target_rgb_path)
        input_raw_img = self._load_white_balanced_raw(input_raw_wb_path)
        # flow=demos=True: input and target share crop coordinates and size.
        input_raw_img, target_rgb_img = self.aug(
            self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True
        )
        return self._build_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_white_balanced_raw(path):
        """Load the ``raw`` and ``wb`` entries of *path* and apply the gains.

        The gains are divided by their maximum. All but the last gain are
        then multiplied into the RAW along its last axis.
        NOTE(review): presumably ``wb`` has one more entry than the RAW has
        channels (e.g. a duplicated green gain) -- confirm.
        """
        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # ``with`` block closes it once both arrays have been read.
        with np.load(path) as raw_wb:
            raw = raw_wb["raw"]
            wb = raw_wb["wb"]
        wb = wb / wb.max()
        return raw * wb[:-1]

    def _build_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Normalize the arrays, convert them to float tensors and pack a dict."""
        # 4095 == 2**12 - 1 and 16383 == 2**14 - 1: presumably the 12-bit and
        # 14-bit RAW white levels of the respective cameras -- confirm.
        norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            # Apply the same 1/2.2 power to the data and to its normalizer.
            norm_value = np.power(norm_value, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        # ``target_raw`` is a copy of the normalized input.
        target_raw_img = input_raw_img.copy()
        return {
            "input_raw": self.np2tensor(input_raw_img).float(),
            "target_rgb": self.np2tensor(target_rgb_img).float(),
            "target_raw": self.np2tensor(target_raw_img).float(),
            # Basename of the input path, up to its first '.'.
            "file_name": input_raw_wb_path.split("/")[-1].split(".")[0],
        }
class FiveKDatasetTest(BaseDataset):
    """Evaluation split of the FiveK RAW-to-RGB dataset.

    Each item pairs a full-size white-balanced RAW array with its target RGB
    image. No cropping or augmentation is applied.

    NOTE(review): ``load``, ``gamma``, ``camera_name``, ``norm_img`` and
    ``np2tensor`` are inherited from ``BaseDataset``, which is not visible
    here.
    NOTE(review): ``imread`` is the file-level ``scipy.misc.imread`` import,
    which SciPy removed in 1.2; this module needs an older SciPy (or a
    different image reader) to import.
    """

    def __init__(self, opt):
        super().__init__(opt=opt)
        # Not read by this class (no cropping at test time); kept as a public
        # attribute for callers that may inspect it.
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=False)
        # Inputs and targets are paired by index. Raise instead of assert so
        # the check is not stripped under ``python -O``.
        if len(input_RAWs_WBs) != len(target_RGBs):
            raise ValueError(
                "FiveK test split has {} RAW inputs but {} RGB targets".format(
                    len(input_RAWs_WBs), len(target_RGBs)
                )
            )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def __len__(self):
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Return the full-size, normalized evaluation sample at ``idx``."""
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]
        target_rgb_img = imread(target_rgb_path)
        input_raw_img = self._load_white_balanced_raw(input_raw_wb_path)
        return self._build_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_white_balanced_raw(path):
        """Load the ``raw`` and ``wb`` entries of *path* and apply the gains.

        The gains are divided by their maximum. All but the last gain are
        then multiplied into the RAW along its last axis.
        NOTE(review): presumably ``wb`` has one more entry than the RAW has
        channels (e.g. a duplicated green gain) -- confirm.
        """
        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # ``with`` block closes it once both arrays have been read.
        with np.load(path) as raw_wb:
            raw = raw_wb["raw"]
            wb = raw_wb["wb"]
        wb = wb / wb.max()
        return raw * wb[:-1]

    def _build_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Normalize the arrays, convert them to float tensors and pack a dict."""
        # 4095 == 2**12 - 1 and 16383 == 2**14 - 1: presumably the 12-bit and
        # 14-bit RAW white levels of the respective cameras -- confirm.
        norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            # Apply the same 1/2.2 power to the data and to its normalizer.
            norm_value = np.power(norm_value, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        # ``target_raw`` is a copy of the normalized input.
        target_raw_img = input_raw_img.copy()
        return {
            "input_raw": self.np2tensor(input_raw_img).float(),
            "target_rgb": self.np2tensor(target_rgb_img).float(),
            "target_raw": self.np2tensor(target_raw_img).float(),
            # Basename of the input path, up to its first '.'.
            "file_name": input_raw_wb_path.split("/")[-1].split(".")[0],
        }