|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
import torch |
|
import numpy as np |
|
import os, time, random |
|
import argparse |
|
from torch.utils.data import Dataset, DataLoader |
|
from PIL import Image as PILImage |
|
from glob import glob |
|
from tqdm import tqdm |
|
import rawpy |
|
import colour_demosaicing |
|
|
|
from .InvISP.model.model import InvISPNet |
|
from .utils.common import Notify |
|
from datasets.noise import ( |
|
camera_params, |
|
addGStarNoise, |
|
addPStarNoise, |
|
addQuantNoise, |
|
addRowNoise, |
|
sampleK, |
|
) |
|
|
|
|
|
class NoiseSimulator:
    """Simulate realistic low-light sensor noise on RGB images.

    Pipeline: an invertible ISP network (InvISPNet) maps RGB -> RAW
    (``rgb2raw``), a physics-based noise model is injected in the RAW
    domain (``raw2noisyRaw``), and the forward ISP renders RAW -> RGB
    (``raw2rgb``). Helper methods pack or demosaic RAW planes for
    downstream consumers.

    The noise model combines photon shot noise, signal-independent read
    noise, row noise, and quantization noise, parameterized by
    ``datasets.noise.camera_params``.
    """

    def __init__(self, device, ckpt_path="./datasets/InvISP/pretrained/canon.pth"):
        """
        Args:
            device: torch device the ISP network runs on.
            ckpt_path: path to the pretrained InvISPNet checkpoint.
        """
        self.device = device

        # Invertible ISP: forward call renders RAW -> RGB, and the same
        # network with rev=True maps RGB -> RAW.
        self.net = (
            InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
        )
        # map_location keeps loading working when the checkpoint was saved on
        # a different device (e.g. a CUDA checkpoint on a CPU-only machine).
        self.net.load_state_dict(
            torch.load(ckpt_path, map_location=self.device), strict=False
        )
        print(
            Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC
        )

        # Fixed white-balance gains, RGGB order; only [R, G, B] are applied
        # after normalization (see rgb2raw / raw2rgb, which use wb[:-1]).
        self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])

        # Sensor calibration: black level, saturation value, noise profile.
        self.camera_params = camera_params

        # Range of low-light exposure ratios (brightness amplification).
        self.ratio_min = 50
        self.ratio_max = 150

    def invDemosaic(self, img):
        """Mosaic an H x W x 3 RGB image back into an H x W RGGB Bayer plane.

        Inverse of a demosaic: for each 2x2 cell, keep R at the top-left,
        G at top-right and bottom-left, and B at bottom-right.
        """
        raw_img = np.ones(img.shape[:2])
        raw_img[::2, ::2] = img[::2, ::2, 0]      # R
        raw_img[::2, 1::2] = img[::2, 1::2, 1]    # G1
        raw_img[1::2, ::2] = img[1::2, ::2, 1]    # G2
        raw_img[1::2, 1::2] = img[1::2, 1::2, 2]  # B
        return raw_img

    def demosaicNearest(self, img):
        """Nearest-neighbor demosaic of an H x W RGGB Bayer plane to H x W x 3.

        Each 2x2 cell's single R and B samples are replicated across the
        whole cell; green is filled row-wise from the nearer green sample
        (G1 for the cell's top row, G2 for its bottom row).
        """
        raw = np.ones((img.shape[0], img.shape[1], 3))
        # R: replicate the top-left sample over the 2x2 cell.
        raw[::2, ::2, 0] = img[::2, ::2]
        raw[::2, 1::2, 0] = img[::2, ::2]
        raw[1::2, ::2, 0] = img[::2, ::2]
        raw[1::2, 1::2, 0] = img[::2, ::2]
        # B: replicate the bottom-right sample over the 2x2 cell.
        raw[::2, ::2, 2] = img[1::2, 1::2]
        raw[::2, 1::2, 2] = img[1::2, 1::2]
        raw[1::2, ::2, 2] = img[1::2, 1::2]
        raw[1::2, 1::2, 2] = img[1::2, 1::2]
        # G: top row of the cell from G1, bottom row from G2.
        raw[::2, ::2, 1] = img[::2, 1::2]
        raw[::2, 1::2, 1] = img[::2, 1::2]
        raw[1::2, ::2, 1] = img[1::2, ::2]
        raw[1::2, 1::2, 1] = img[1::2, ::2]
        return raw

    def demosaic(self, img):
        """Bilinear demosaic of an RGGB Bayer plane to an H x W x 3 image."""
        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, "RGGB")

    def path2rgb(self, path):
        """Load an image file as an H x W x 3 float tensor in [0, 1]."""
        return torch.from_numpy(np.array(PILImage.open(path)) / 255.0)

    def rgb2raw(self, rgb, batched=False):
        """Map RGB image(s) to simulated Bayer RAW via the inverse ISP.

        Args:
            rgb: float tensor in [0, 1]; H x W x 3, or B x H x W x 3 when
                ``batched`` is True.
            batched: whether ``rgb`` carries a leading batch dimension.

        Returns:
            numpy array of mosaiced RAW values in sensor units (black level
            included); H x W, or B x H x W when ``batched``.
        """
        if not batched:
            rgb = rgb.unsqueeze(0)

        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
        with torch.no_grad():
            # rev=True runs the invertible ISP backwards: RGB -> RAW.
            reconstruct_raw = self.net(rgb, rev=True)

        pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1)
        pred_raw = torch.clamp(pred_raw, 0, 1)

        if not batched:
            pred_raw = pred_raw[0, ...]

        pred_raw = pred_raw.cpu().numpy()

        # Undo gamma and rescale [0, 1] to 14-bit sensor units
        # (16383 = 2**14 - 1).
        norm_value = np.power(16383, 1 / 2.2)
        pred_raw *= norm_value
        pred_raw = np.power(pred_raw, 2.2)

        # Undo white balance (wb is RGGB; drop the trailing duplicate green).
        wb = self.wb / self.wb.max()
        pred_raw = pred_raw / wb[:-1]

        # Re-add the sensor black level.
        pred_raw += self.camera_params["black_level"]

        # Mosaic back down to single-channel Bayer plane(s).
        if not batched:
            pred_raw = self.invDemosaic(pred_raw)
        else:
            pred_raw = np.stack(
                [self.invDemosaic(pred_raw[i]) for i in range(pred_raw.shape[0])],
                axis=0,
            )

        return pred_raw

    def _corruptFrame(self, raw):
        """Apply the sampled noise chain to one (already darkened) frame.

        Samples a fresh overall system gain K per call, so every frame of a
        batch gets an independent noise realization.
        """
        K = sampleK(self.camera_params["Kmin"], self.camera_params["Kmax"])
        # Quantization step of the normalized signal range.
        q = 1 / (
            self.camera_params["max_value"] - self.camera_params["black_level"]
        )

        raw = addPStarNoise(raw, K)  # photon shot noise
        raw = addGStarNoise(         # signal-independent read noise
            raw,
            K,
            self.camera_params["G_shape"],
            self.camera_params["Profile-1"]["G_scale"],
        )
        raw = addRowNoise(raw, K, self.camera_params["Profile-1"]["R_scale"])
        raw = addQuantNoise(raw, q)
        return raw

    def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
        """Add simulated low-light noise to clean RAW frame(s).

        Each frame is darkened by a random exposure ratio, corrupted with
        the noise model, then scaled back up — emulating a dark capture
        that is brightness-amplified by ``ratio``.

        Args:
            raw: clean Bayer plane(s); H x W, or B x H x W when ``batched``.
            ratio_dec: multiplier in [0, 1] shrinking the sampled ratio
                toward 1 (``ratio = (u - 1) * ratio_dec + 1``).
            batched: whether a leading batch dimension is present.

        Returns:
            Noisy copy of ``raw`` with the same shape; the input itself is
            never modified.
        """
        if not batched:
            ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1
            noisy = self._corruptFrame(raw.copy() / ratio)
            noisy *= ratio
            return noisy

        raw = raw.copy()
        for i in range(raw.shape[0]):
            # BUGFIX: the batched path previously ignored ratio_dec,
            # diverging from the non-batched path.
            ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1
            raw[i] /= ratio
            raw[i] = self._corruptFrame(raw[i])
            raw[i] *= ratio
        return raw

    def raw2rgb(self, raw, batched=False):
        """Render Bayer RAW back to RGB through the forward ISP network.

        Args:
            raw: Bayer plane(s) in sensor units; H x W, or B x H x W when
                ``batched``.
            batched: whether a leading batch dimension is present.

        Returns:
            numpy RGB array in [0, 1]; H x W x 3, or B x H x W x 3.
        """
        if not batched:
            raw = self.demosaic(raw)
        else:
            raw = np.stack(
                [self.demosaic(raw[i]) for i in range(raw.shape[0])], axis=0
            )

        # Remove black level and clip to the valid signal range.
        raw -= self.camera_params["black_level"]
        raw = np.clip(
            raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"]
        )

        # Apply white balance (RGGB gains; drop the trailing duplicate green).
        wb = self.wb / self.wb.max()
        raw = raw * wb[:-1]

        # Gamma-compress from 14-bit units back into [0, 1].
        norm_value = np.power(16383, 1 / 2.2)
        raw = np.power(raw, 1 / 2.2)
        raw /= norm_value

        if not batched:
            input_raw_img = (
                torch.Tensor(raw)
                .permute(2, 0, 1)
                .float()
                .to(self.device)[np.newaxis, ...]
            )
        else:
            input_raw_img = (
                torch.Tensor(raw).permute(0, 3, 1, 2).float().to(self.device)
            )

        with torch.no_grad():
            reconstruct_rgb = self.net(input_raw_img)
            reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)

        pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1)

        if not batched:
            pred_rgb = pred_rgb[0, ...]
        pred_rgb = pred_rgb.cpu().numpy()

        return pred_rgb

    def raw2packedRaw(self, raw, batched=False):
        """Pack a Bayer plane into 4 half-resolution channels (R, G1, B, G2).

        Args:
            raw: Bayer plane(s) in sensor units; H x W, or B x H x W when
                ``batched``.
            batched: whether a leading batch dimension is present.

        Returns:
            Normalized array of shape H/2 x W/2 x 4 (or B x H/2 x W/2 x 4).
        """
        # BUGFIX: build a new array instead of mutating the caller's input in
        # place (the other raw2* methods never modify their argument).
        raw = raw - self.camera_params["black_level"]
        raw = np.clip(
            raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"]
        )
        # NOTE(review): divides by max_value rather than
        # (max_value - black_level); kept as-is to preserve behavior — confirm.
        raw /= self.camera_params["max_value"]

        if not batched:
            im = np.expand_dims(raw, axis=2)
            H, W = im.shape[0], im.shape[1]
            out = np.concatenate(
                (
                    im[0:H:2, 0:W:2, :],  # R
                    im[0:H:2, 1:W:2, :],  # G1
                    im[1:H:2, 1:W:2, :],  # B
                    im[1:H:2, 0:W:2, :],  # G2
                ),
                axis=2,
            )
        else:
            im = np.expand_dims(raw, axis=3)
            H, W = im.shape[1], im.shape[2]
            out = np.concatenate(
                (
                    im[:, 0:H:2, 0:W:2, :],
                    im[:, 0:H:2, 1:W:2, :],
                    im[:, 1:H:2, 1:W:2, :],
                    im[:, 1:H:2, 0:W:2, :],
                ),
                axis=3,
            )
        return out

    def raw2demosaicRaw(self, raw, batched=False):
        """Demosaic RAW and normalize to [0, 1] without ISP rendering.

        Args:
            raw: Bayer plane(s) in sensor units; H x W, or B x H x W when
                ``batched``.
            batched: whether a leading batch dimension is present.

        Returns:
            Bilinearly demosaiced H x W x 3 (or B x H x W x 3) array,
            black-level subtracted and scaled by 1 / max_value.
        """
        if not batched:
            raw = self.demosaic(raw)
        else:
            raw = np.stack(
                [self.demosaic(raw[i]) for i in range(raw.shape[0])], axis=0
            )

        # self.demosaic returned a fresh array, so in-place ops are safe here.
        raw -= self.camera_params["black_level"]
        raw = np.clip(
            raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"]
        )
        # NOTE(review): divides by max_value rather than
        # (max_value - black_level); kept as-is to preserve behavior — confirm.
        raw /= self.camera_params["max_value"]
        return raw
|
|