|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
import torch |
|
import numpy as np |
|
import os, time, random |
|
import argparse |
|
from torch.utils.data import Dataset, DataLoader |
|
from PIL import Image as PILImage |
|
from glob import glob |
|
from tqdm import tqdm |
|
import rawpy |
|
import colour_demosaicing |
|
|
|
from .InvISP.model.model import InvISPNet |
|
from .utils.common import Notify |
|
from datasets.noise import camera_params, addGStarNoise, addPStarNoise, addQuantNoise, addRowNoise, sampleK |
|
|
|
|
|
class NoiseSimulator:
    """Camera noise simulator built on an invertible ISP network.

    RGB images are mapped back to (approximate) sensor raw space with a
    pretrained InvISPNet, physics-based noise (shot, read, row, quantization)
    is injected there, and the result can be rendered back to RGB.

    Typical pipeline: path2rgb -> rgb2raw -> raw2noisyRaw -> raw2rgb.
    The Bayer pattern used throughout is RGGB.
    """

    def __init__(self, device, ckpt_path='./datasets/InvISP/pretrained/canon.pth'):
        """
        Args:
            device: torch device the InvISPNet is placed on.
            ckpt_path: path to the pretrained InvISPNet checkpoint.
        """
        self.device = device

        # Invertible ISP: net(x) renders raw->rgb, net(x, rev=True) inverts rgb->raw.
        self.net = InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
        self.net.load_state_dict(torch.load(ckpt_path), strict=False)
        print(Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC)

        # White-balance gains in RGGB order; normalized by their max before use.
        self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])

        # Sensor calibration constants and noise-model parameters
        # (black_level, max_value, Kmin/Kmax, per-profile scales, ...).
        self.camera_params = camera_params

        # Range from which the low-light exposure ratio is sampled.
        self.ratio_min = 50
        self.ratio_max = 150

    def invDemosaic(self, img):
        """Mosaic an H x W x 3 RGB image into a single-channel RGGB Bayer plane.

        Only the channel the RGGB pattern keeps at each pixel survives; the
        other two channels of that pixel are discarded.

        Args:
            img: array of shape (H, W, 3); H and W assumed even — TODO confirm.

        Returns:
            (H, W) float array laid out as an RGGB mosaic.
        """
        img_R = img[::2, ::2, 0]
        img_G1 = img[::2, 1::2, 1]
        img_G2 = img[1::2, ::2, 1]
        img_B = img[1::2, 1::2, 2]
        # np.ones is fully overwritten below; it only fixes shape and dtype.
        raw_img = np.ones(img.shape[:2])
        raw_img[::2, ::2] = img_R
        raw_img[::2, 1::2] = img_G1
        raw_img[1::2, ::2] = img_G2
        raw_img[1::2, 1::2] = img_B
        return raw_img

    def demosaicNearest(self, img):
        """Nearest-neighbour demosaic of an RGGB Bayer plane.

        Every pixel of a 2x2 Bayer cell copies the cell's single R and B
        sample; green pixels copy the nearer green sample (G1 for the top row
        of the cell, G2 for the bottom row).

        Args:
            img: (H, W) RGGB mosaic; H and W assumed even — TODO confirm.

        Returns:
            (H, W, 3) float array.
        """
        raw = np.ones((img.shape[0], img.shape[1], 3))
        # Red: replicate the single R sample across the 2x2 cell.
        raw[::2, ::2, 0] = img[::2, ::2]
        raw[::2, 1::2, 0] = img[::2, ::2]
        raw[1::2, ::2, 0] = img[::2, ::2]
        raw[1::2, 1::2, 0] = img[::2, ::2]
        # Blue: replicate the single B sample across the 2x2 cell.
        raw[::2, ::2, 2] = img[1::2, 1::2]
        raw[::2, 1::2, 2] = img[1::2, 1::2]
        raw[1::2, ::2, 2] = img[1::2, 1::2]
        raw[1::2, 1::2, 2] = img[1::2, 1::2]
        # Green: top row of each cell takes G1, bottom row takes G2.
        raw[::2, ::2, 1] = img[::2, 1::2]
        raw[::2, 1::2, 1] = img[::2, 1::2]
        raw[1::2, ::2, 1] = img[1::2, ::2]
        raw[1::2, 1::2, 1] = img[1::2, ::2]
        return raw

    def demosaic(self, img):
        """Bilinear demosaic of an (H, W) RGGB mosaic to (H, W, 3)."""
        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, 'RGGB')

    def path2rgb(self, path):
        """Load an image file as a float64 tensor in [0, 1], shape (H, W, 3)."""
        return torch.from_numpy(np.array(PILImage.open(path))/255.0)

    def rgb2raw(self, rgb, batched=False):
        """Invert the ISP: map an sRGB image back to a Bayer raw mosaic.

        Args:
            rgb: float tensor in [0, 1], shape (H, W, 3), or (B, H, W, 3)
                when `batched` is True.
            batched: whether a leading batch dimension is present.

        Returns:
            numpy array: (H, W) mosaic if not batched, else (B, H, W),
            in sensor counts including black level.
        """
        if not batched:
            rgb = rgb.unsqueeze(0)

        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
        with torch.no_grad():
            reconstruct_raw = self.net(rgb, rev=True)

        pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1)
        pred_raw = torch.clamp(pred_raw, 0, 1)

        if not batched:
            pred_raw = pred_raw[0, ...]

        pred_raw = pred_raw.cpu().numpy()

        # Undo the gamma/normalization the net was trained with:
        # back to linear sensor counts in [0, 16383].
        norm_value = np.power(16383, 1/2.2)
        pred_raw *= norm_value
        pred_raw = np.power(pred_raw, 2.2)

        # Undo white balance; wb[:-1] are the per-channel R, G, B gains.
        wb = self.wb / self.wb.max()
        pred_raw = pred_raw / wb[:-1]

        # Re-add the sensor black level.
        pred_raw += self.camera_params['black_level']

        # Collapse the 3-channel image to a single-channel RGGB mosaic.
        if not batched:
            pred_raw = self.invDemosaic(pred_raw)
        else:
            preds = []
            for i in range(pred_raw.shape[0]):
                preds.append(self.invDemosaic(pred_raw[i]))
            pred_raw = np.stack(preds, axis=0)

        return pred_raw

    def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
        """Inject low-light sensor noise into a clean raw mosaic.

        The signal is divided by a random exposure ratio, shot/read/row/
        quantization noise is added at that low level, then the result is
        scaled back up — emulating an underexposed capture brightened by
        `ratio`.

        Args:
            raw: clean raw mosaic, (H, W), or (B, H, W) when `batched`.
            ratio_dec: shrinks the sampled ratio toward 1:
                ratio = (u - 1) * ratio_dec + 1, u ~ U[ratio_min, ratio_max].
            batched: whether a leading batch dimension is present.

        Returns:
            Noisy copy of `raw`; the input array is not modified.
        """
        if not batched:
            ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1
            raw = raw.copy() / ratio

            # K: sampled overall system gain; q: quantization step of the ADC
            # relative to the usable signal range.
            K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax'])
            q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level'])

            raw = addPStarNoise(raw, K)
            raw = addGStarNoise(raw, K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale'])
            raw = addRowNoise(raw, K, self.camera_params['Profile-1']['R_scale'])
            raw = addQuantNoise(raw, q)
            raw *= ratio
            return raw

        else:
            raw = raw.copy()
            for i in range(raw.shape[0]):
                # BUGFIX: apply ratio_dec here too, matching the non-batched
                # path (previously the batched branch ignored it; identical
                # for the default ratio_dec=1).
                ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1
                raw[i] /= ratio

                # Fresh gain / quantization step per batch element.
                K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax'])
                q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level'])

                raw[i] = addPStarNoise(raw[i], K)
                raw[i] = addGStarNoise(raw[i], K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale'])
                raw[i] = addRowNoise(raw[i], K, self.camera_params['Profile-1']['R_scale'])
                raw[i] = addQuantNoise(raw[i], q)
                raw[i] *= ratio
            return raw

    def raw2rgb(self, raw, batched=False):
        """Render a raw Bayer mosaic to sRGB via bilinear demosaic + InvISPNet.

        Args:
            raw: (H, W) mosaic in sensor counts, or (B, H, W) when `batched`.
            batched: whether a leading batch dimension is present.

        Returns:
            numpy array in [0, 1], shape (H, W, 3) or (B, H, W, 3).
        """
        if not batched:
            raw = self.demosaic(raw)
        else:
            raws = []
            for i in range(raw.shape[0]):
                raws.append(self.demosaic(raw[i]))
            raw = np.stack(raws, axis=0)

        # Subtract black level and clip to the valid signal range.
        # `raw` is a fresh array from demosaic, so in-place ops are safe.
        raw -= self.camera_params['black_level']
        raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])

        # Apply white balance (per-channel R, G, B gains).
        wb = self.wb / self.wb.max()
        raw = raw * wb[:-1]

        # Gamma-compress and normalize to [0, 1], mirroring rgb2raw.
        norm_value = np.power(16383, 1/2.2)
        raw = np.power(raw, 1/2.2)
        raw /= norm_value

        if not batched:
            input_raw_img = torch.Tensor(raw).permute(2,0,1).float().to(self.device)[np.newaxis, ...]
        else:
            input_raw_img = torch.Tensor(raw).permute(0,3,1,2).float().to(self.device)

        with torch.no_grad():
            reconstruct_rgb = self.net(input_raw_img)
            reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)

        pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1)

        if not batched:
            pred_rgb = pred_rgb[0, ...]
        pred_rgb = pred_rgb.cpu().numpy()

        return pred_rgb

    def raw2packedRaw(self, raw, batched=False):
        """Normalize a raw mosaic and pack each 2x2 Bayer cell into 4 channels.

        Channel order of the packed output is R, G1, B, G2 (one sample per
        Bayer position, half spatial resolution).

        Args:
            raw: (H, W) mosaic in sensor counts, or (B, H, W) when `batched`.
            batched: whether a leading batch dimension is present.

        Returns:
            (H/2, W/2, 4) or (B, H/2, W/2, 4) float array; the input array
            is left unmodified.
        """
        # BUGFIX: work on a new array so the caller's buffer is not clobbered
        # (the previous in-place `raw -= ...` mutated the argument, unlike
        # every other raw2* method in this class).
        raw = raw - self.camera_params['black_level']
        raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
        # NOTE(review): normalizes by max_value while the clip range is
        # max_value - black_level — confirm this asymmetry is intended.
        raw /= self.camera_params['max_value']

        if not batched:
            im = np.expand_dims(raw, axis=2)
            img_shape = im.shape
            H = img_shape[0]
            W = img_shape[1]

            out = np.concatenate((im[0:H:2, 0:W:2, :],
                                  im[0:H:2, 1:W:2, :],
                                  im[1:H:2, 1:W:2, :],
                                  im[1:H:2, 0:W:2, :]), axis=2)
        else:
            im = np.expand_dims(raw, axis=3)
            img_shape = im.shape
            H = img_shape[1]
            W = img_shape[2]

            out = np.concatenate((im[:, 0:H:2, 0:W:2, :],
                                  im[:, 0:H:2, 1:W:2, :],
                                  im[:, 1:H:2, 1:W:2, :],
                                  im[:, 1:H:2, 0:W:2, :]), axis=3)
        return out

    def raw2demosaicRaw(self, raw, batched=False):
        """Bilinear-demosaic a raw mosaic and normalize it toward [0, 1].

        Args:
            raw: (H, W) mosaic in sensor counts, or (B, H, W) when `batched`.
            batched: whether a leading batch dimension is present.

        Returns:
            (H, W, 3) or (B, H, W, 3) float array.
        """
        if not batched:
            raw = self.demosaic(raw)
        else:
            raws = []
            for i in range(raw.shape[0]):
                raws.append(self.demosaic(raw[i]))
            raw = np.stack(raws, axis=0)

        # `raw` is a fresh array from demosaic, so in-place ops do not touch
        # the caller's input.
        raw -= self.camera_params['black_level']
        raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
        raw /= self.camera_params['max_value']
        return raw
|
|