Spaces:
Running
Running
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
import torch | |
import numpy as np | |
import os, time, random | |
import argparse | |
from torch.utils.data import Dataset, DataLoader | |
from PIL import Image as PILImage | |
from glob import glob | |
from tqdm import tqdm | |
import rawpy | |
import colour_demosaicing | |
from .InvISP.model.model import InvISPNet | |
from .utils.common import Notify | |
from datasets.noise import camera_params, addGStarNoise, addPStarNoise, addQuantNoise, addRowNoise, sampleK | |
class NoiseSimulator:
    """Convert between RGB and Bayer RAW with an invertible ISP net and add synthetic sensor noise.

    The typical pipeline is ``rgb2raw`` -> ``raw2noisyRaw`` -> one of
    ``raw2rgb`` / ``raw2packedRaw`` / ``raw2demosaicRaw``.  Every method takes a
    ``batched`` flag: when true, the input carries a leading batch axis.

    The Bayer layout used throughout is RGGB: R at (even row, even col),
    G at (even, odd) and (odd, even), B at (odd, odd).
    """

    def __init__(self, device, ckpt_path='./datasets/InvISP/pretrained/canon.pth'):
        """Build the InvISP network on ``device`` and load weights from ``ckpt_path``.

        Args:
            device: torch device (or device string) the network runs on.
            ckpt_path: path to the InvISP checkpoint (a state dict).
        """
        self.device = device
        # load Invertible ISP Network
        self.net = InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
        # map_location makes a checkpoint saved on another device (e.g. a GPU)
        # loadable on the device requested here (e.g. a CPU-only host).
        state_dict = torch.load(ckpt_path, map_location=self.device)
        self.net.load_state_dict(state_dict, strict=False)
        print(Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC)
        # white balance parameters; only the first three entries are used (as R, G, B gains)
        self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])
        # use Canon EOS 5D4 noise parameters provided by ELD
        self.camera_params = camera_params
        # random specify exposure time ratio from 50 to 150
        self.ratio_min = 50
        self.ratio_max = 150

    # inverse demosaic
    # input: [H, W, 3]
    # output: [H, W]
    def invDemosaic(self, img):
        """Re-mosaic an RGB image into a single-channel RGGB Bayer image.

        Args:
            img: array of shape [H, W, 3].

        Returns:
            Float array of shape [H, W]; each pixel keeps only the channel
            that the RGGB pattern assigns to its position.
        """
        raw_img = np.ones(img.shape[:2])
        raw_img[::2, ::2] = img[::2, ::2, 0]    # R
        raw_img[::2, 1::2] = img[::2, 1::2, 1]  # G on even rows
        raw_img[1::2, ::2] = img[1::2, ::2, 1]  # G on odd rows
        raw_img[1::2, 1::2] = img[1::2, 1::2, 2]  # B
        return raw_img

    # demosaic - nearest ver
    # input: [H, W]
    # output: [H, W, 3]
    def demosaicNearest(self, img):
        """Demosaic an RGGB Bayer image by replicating samples inside each 2x2 cell.

        Args:
            img: array of shape [H, W]; H and W must be even, otherwise the
                slice assignments below have mismatching shapes and raise.

        Returns:
            Float array of shape [H, W, 3].
        """
        raw = np.ones((img.shape[0], img.shape[1], 3))
        red = img[::2, ::2]
        blue = img[1::2, 1::2]
        green_even_rows = img[::2, 1::2]
        green_odd_rows = img[1::2, ::2]
        for dx in (0, 1):
            # green is replicated horizontally only: each row keeps its own sample
            raw[::2, dx::2, 1] = green_even_rows
            raw[1::2, dx::2, 1] = green_odd_rows
            for dy in (0, 1):
                # red / blue have one sample per 2x2 cell, copied to all four pixels
                raw[dy::2, dx::2, 0] = red
                raw[dy::2, dx::2, 2] = blue
        return raw

    # demosaic
    # input: [H, W]
    # output: [H, W, 3]
    def demosaic(self, img):
        """Bilinear demosaic of an RGGB Bayer image via ``colour_demosaicing``."""
        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, 'RGGB')

    def _demosaicBatch(self, raws):
        """Demosaic every frame along the leading axis and stack the results."""
        return np.stack([self.demosaic(frame) for frame in raws], axis=0)

    def _subtractBlackLevel(self, raw):
        """Return a new array: ``raw`` minus the black level, clipped to the valid range."""
        black_level = self.camera_params['black_level']
        return np.clip(raw - black_level, 0, self.camera_params['max_value'] - black_level)

    # load rgb image
    def path2rgb(self, path):
        """Load the image at ``path`` and return its pixels divided by 255 as a torch tensor."""
        # the context manager releases the file handle deterministically
        with PILImage.open(path) as image:
            pixels = np.array(image)
        return torch.from_numpy(pixels / 255.0)

    # InvISP
    # input: rgb image [H, W, 3]
    # output: raw image [H, W]
    def rgb2raw(self, rgb, batched=False):
        """Run the ISP network in reverse to turn RGB into an RGGB Bayer RAW image.

        Args:
            rgb: torch tensor of shape [H, W, 3], or [N, H, W, 3] when ``batched``.
            batched: whether ``rgb`` has a leading batch axis.

        Returns:
            numpy array of shape [H, W] (or [N, H, W]) with the black level added.
        """
        # 1. rgb -> invnet
        if not batched:
            rgb = rgb.unsqueeze(0)
        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
        with torch.no_grad():
            reconstruct_raw = self.net(rgb, rev=True)
        pred_raw = torch.clamp(reconstruct_raw.detach().permute(0, 2, 3, 1), 0, 1)
        pred_raw = pred_raw.cpu().numpy()
        # 2. -> inv gamma
        norm_value = np.power(16383, 1/2.2)
        pred_raw *= norm_value
        pred_raw = np.power(pred_raw, 2.2)
        # 3. -> inv white balance (gains broadcast over the trailing channel axis)
        wb = self.wb / self.wb.max()
        pred_raw = pred_raw / wb[:-1]
        # 4. -> add black level
        pred_raw += self.camera_params['black_level']
        # 5. -> inv demosaic
        if not batched:
            return self.invDemosaic(pred_raw[0])
        return np.stack([self.invDemosaic(frame) for frame in pred_raw], axis=0)

    def _sampleRatio(self, ratio_dec):
        """Draw a random exposure ratio; ``ratio_dec`` scales its distance from 1."""
        return (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1

    def _addNoise(self, raw, ratio):
        """Darken one RAW frame by ``ratio``, add noise, and scale it back up.

        The input is not modified.  The noise models come from ``datasets.noise``;
        presumably shot, read, row and quantization noise - confirm there.
        """
        params = self.camera_params
        profile = params['Profile-1']
        noisy = raw / ratio
        K = sampleK(params['Kmin'], params['Kmax'])
        q = 1 / (params['max_value'] - params['black_level'])
        noisy = addPStarNoise(noisy, K)
        noisy = addGStarNoise(noisy, K, params['G_shape'], profile['G_scale'])
        noisy = addRowNoise(noisy, K, profile['R_scale'])
        noisy = addQuantNoise(noisy, q)
        return noisy * ratio

    def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
        """Add synthetic sensor noise to a RAW image.

        Args:
            raw: numpy array of shape [H, W], or [N, H, W] when ``batched``.
            ratio_dec: scales the sampled exposure ratio towards 1
                (``ratio_dec=1`` keeps the sampled ratio, ``0`` forces ratio 1).
                Applied in both the single-image and the batched path.
            batched: whether ``raw`` has a leading batch axis; each frame then
                gets its own ratio and noise parameters.

        Returns:
            A new noisy array with the same shape as ``raw``; ``raw`` is untouched.
        """
        if not batched:
            return self._addNoise(raw, self._sampleRatio(ratio_dec))
        # floating output buffer so integer inputs are not truncated on assignment
        noisy = raw.astype(np.result_type(raw.dtype, np.float32))
        for i, frame in enumerate(raw):
            noisy[i] = self._addNoise(frame, self._sampleRatio(ratio_dec))
        return noisy

    def raw2rgb(self, raw, batched=False):
        """Render a Bayer RAW image to RGB with the forward ISP network.

        Args:
            raw: numpy array of shape [H, W], or [N, H, W] when ``batched``.
            batched: whether ``raw`` has a leading batch axis.

        Returns:
            numpy array of shape [H, W, 3] (or [N, H, W, 3]) clamped to [0, 1].
        """
        # 1. -> demosaic
        raw = self._demosaicBatch(raw) if batched else self.demosaic(raw)
        # 2. -> subtract black level
        raw = self._subtractBlackLevel(raw)
        # 3. -> white balance
        wb = self.wb / self.wb.max()
        raw = raw * wb[:-1]
        # 4. -> gamma
        norm_value = np.power(16383, 1/2.2)
        raw = np.power(raw, 1/2.2)
        raw /= norm_value
        # 5. -> ispnet (channels-first layout, always with a batch axis)
        if batched:
            net_input = torch.Tensor(raw).permute(0, 3, 1, 2).float().to(self.device)
        else:
            net_input = torch.Tensor(raw).permute(2, 0, 1).float().to(self.device).unsqueeze(0)
        with torch.no_grad():
            reconstruct_rgb = self.net(net_input)
        reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
        pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1)
        if not batched:
            pred_rgb = pred_rgb[0, ...]
        return pred_rgb.cpu().numpy()

    def raw2packedRaw(self, raw, batched=False):
        """Pack a Bayer RAW image into four half-resolution planes.

        Args:
            raw: numpy array of shape [H, W], or [N, H, W] when ``batched``;
                H and W must be even.  The array is not modified.
            batched: kept for interface compatibility; the packing below works
                on the last two axes, so both layouts share one code path.

        Returns:
            Array of shape [H/2, W/2, 4] (or [N, H/2, W/2, 4]) whose last axis
            is ordered R, G (even rows), B, G (odd rows).
        """
        # 1. -> subtract black level (out of place: the caller's array stays intact)
        raw = self._subtractBlackLevel(raw)
        # NOTE(review): normalizes by max_value, not (max_value - black_level) - confirm intended
        raw = raw / self.camera_params['max_value']
        # 2. pack
        planes = (
            raw[..., 0::2, 0::2],
            raw[..., 0::2, 1::2],
            raw[..., 1::2, 1::2],
            raw[..., 1::2, 0::2],
        )
        return np.stack(planes, axis=-1)

    def raw2demosaicRaw(self, raw, batched=False):
        """Demosaic a Bayer RAW image and normalize it, without running the ISP network.

        Args:
            raw: numpy array of shape [H, W], or [N, H, W] when ``batched``.
            batched: whether ``raw`` has a leading batch axis.

        Returns:
            Array of shape [H, W, 3] (or [N, H, W, 3]): black level subtracted,
            clipped, then divided by ``max_value``.
        """
        # 1. -> demosaic
        raw = self._demosaicBatch(raw) if batched else self.demosaic(raw)
        # 2. -> subtract black level
        raw = self._subtractBlackLevel(raw)
        return raw / self.camera_params['max_value']