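# NOTE: header residue from the web file viewer, preserved as comments so the
# module parses: Vincentqyw | fix: roma | 358ab8f | raw | history blame | 9.01 kB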
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch
import numpy as np
import os, time, random
import argparse
from torch.utils.data import Dataset, DataLoader
from PIL import Image as PILImage
from glob import glob
from tqdm import tqdm
import rawpy
import colour_demosaicing
from .InvISP.model.model import InvISPNet
from .utils.common import Notify
from datasets.noise import (
camera_params,
addGStarNoise,
addPStarNoise,
addQuantNoise,
addRowNoise,
sampleK,
)
class NoiseSimulator:
    """Convert between RGB and Bayer raw with an invertible ISP network and
    inject synthetic sensor noise into the raw data.

    The forward direction of ``self.net`` maps gamma-encoded, white-balanced
    raw (3 channels) to RGB; calling it with ``rev=True`` maps RGB back.
    All raw mosaics handled here use the RGGB layout.
    """

    # Exponent shared by the gamma / inverse-gamma steps.
    _GAMMA = 2.2
    # Ceiling used to normalise gamma-encoded values (2**14 - 1; presumably the
    # 14-bit raw white level -- confirm against the camera profile).
    _RAW_MAX = 16383

    def __init__(self, device, ckpt_path="./datasets/InvISP/pretrained/canon.pth"):
        """Build the ISP network on ``device`` and load weights from ``ckpt_path``."""
        self.device = device

        # load Invertible ISP Network
        self.net = (
            InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
        )
        # Load onto CPU first so a checkpoint saved from a GPU can still be
        # read on a CPU-only host; load_state_dict copies the values into the
        # parameters that already live on ``self.device``.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        self.net.load_state_dict(state_dict, strict=False)
        print(
            Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC
        )

        # white balance parameters (only the first three entries are applied
        # to the channel axis, see rgb2raw / raw2rgb)
        self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])

        # use Canon EOS 5D4 noise parameters provided by ELD
        self.camera_params = camera_params

        # random specify exposure time ratio from 50 to 150
        self.ratio_min = 50
        self.ratio_max = 150

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _as_float_copy(arr):
        """Return a floating-point copy of ``arr``.

        Float dtypes are preserved; any other dtype is promoted to float64 so
        that later in-place arithmetic neither fails nor truncates. The
        caller's array is never modified.
        """
        arr = np.asarray(arr)
        if np.issubdtype(arr.dtype, np.floating):
            return arr.copy()
        return arr.astype(np.float64)

    def _remove_black_level(self, raw):
        """Subtract the black level from a copy of ``raw`` and clip the result
        to ``[0, max_value - black_level]``."""
        black = self.camera_params["black_level"]
        white = self.camera_params["max_value"]
        raw = self._as_float_copy(raw)
        raw -= black
        return np.clip(raw, 0, white - black)

    def _demosaic_frames(self, raw, batched):
        """Demosaic one mosaic ``[H, W]`` or, if ``batched``, each mosaic of a
        stack ``[B, H, W]``."""
        if not batched:
            return self.demosaic(raw)
        return np.stack([self.demosaic(frame) for frame in raw], axis=0)

    def _sample_ratio(self, ratio_dec):
        """Draw a ratio uniformly from ``[ratio_min, ratio_max]`` and shrink
        its distance from 1 by the factor ``ratio_dec``."""
        return (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1

    def _add_noise(self, raw, ratio):
        """Scale ``raw`` down by ``ratio``, apply the noise functions, and
        scale the result back up by ``ratio``."""
        params = self.camera_params
        raw = raw / ratio
        K = sampleK(params["Kmin"], params["Kmax"])
        q = 1 / (params["max_value"] - params["black_level"])
        raw = addPStarNoise(raw, K)
        raw = addGStarNoise(
            raw, K, params["G_shape"], params["Profile-1"]["G_scale"]
        )
        raw = addRowNoise(raw, K, params["Profile-1"]["R_scale"])
        raw = addQuantNoise(raw, q)
        raw *= ratio
        return raw

    # ------------------------------------------------------------------ #
    # mosaic <-> RGB utilities
    # ------------------------------------------------------------------ #
    # inverse demosaic
    # input: [H, W, 3]
    # output: [H, W]
    def invDemosaic(self, img):
        """Sample an RGB image ``[H, W, 3]`` on an RGGB grid, returning a
        float64 mosaic ``[H, W]``."""
        raw_img = np.ones(img.shape[:2])
        raw_img[::2, ::2] = img[::2, ::2, 0]  # R
        raw_img[::2, 1::2] = img[::2, 1::2, 1]  # G (even rows)
        raw_img[1::2, ::2] = img[1::2, ::2, 1]  # G (odd rows)
        raw_img[1::2, 1::2] = img[1::2, 1::2, 2]  # B
        return raw_img

    # demosaic - nearest ver
    # input: [H, W]
    # output: [H, W, 3]
    def demosaicNearest(self, img):
        """Expand an RGGB mosaic ``[H, W]`` to ``[H, W, 3]`` by copying each
        2x2 cell's samples across the cell (H and W are expected to be even)."""
        rgb = np.ones((img.shape[0], img.shape[1], 3))
        red = img[::2, ::2]
        blue = img[1::2, 1::2]
        green_even_rows = img[::2, 1::2]
        green_odd_rows = img[1::2, ::2]
        # Red and blue: the single sample of a cell fills all four positions.
        for dy in (0, 1):
            for dx in (0, 1):
                rgb[dy::2, dx::2, 0] = red
                rgb[dy::2, dx::2, 2] = blue
        # Green: each row of the cell reuses the green sample from that row.
        for dx in (0, 1):
            rgb[::2, dx::2, 1] = green_even_rows
            rgb[1::2, dx::2, 1] = green_odd_rows
        return rgb

    # demosaic
    # input: [H, W]
    # output: [H, W, 3]
    def demosaic(self, img):
        """Bilinearly demosaic an RGGB mosaic via ``colour_demosaicing``."""
        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, "RGGB")

    # load rgb image
    def path2rgb(self, path):
        """Read the image at ``path`` and return it as a tensor scaled by 1/255."""
        # The context manager releases the file handle once pixels are read.
        with PILImage.open(path) as image:
            pixels = np.array(image)
        return torch.from_numpy(pixels / 255.0)

    # ------------------------------------------------------------------ #
    # pipeline stages
    # ------------------------------------------------------------------ #
    # InvISP
    # input: rgb image [H, W, 3]
    # output: raw image [H, W]
    def rgb2raw(self, rgb, batched=False):
        """Convert RGB to an RGGB raw mosaic that includes the black level.

        Args:
            rgb: tensor ``[H, W, 3]``, or ``[B, H, W, 3]`` when ``batched``.
            batched: whether ``rgb`` carries a leading batch axis.

        Returns:
            numpy array ``[H, W]``, or ``[B, H, W]`` when ``batched``.
        """
        # 1. rgb -> invnet
        if not batched:
            rgb = rgb.unsqueeze(0)
        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
        with torch.no_grad():
            reconstruct_raw = self.net(rgb, rev=True)
        pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1)
        pred_raw = torch.clamp(pred_raw, 0, 1)
        if not batched:
            pred_raw = pred_raw[0, ...]
        pred_raw = pred_raw.cpu().numpy()

        # 2. -> inv gamma
        norm_value = np.power(self._RAW_MAX, 1 / self._GAMMA)
        pred_raw *= norm_value
        pred_raw = np.power(pred_raw, self._GAMMA)

        # 3. -> inv white balance
        wb = self.wb / self.wb.max()
        pred_raw = pred_raw / wb[:-1]

        # 4. -> add black level
        pred_raw += self.camera_params["black_level"]

        # 5. -> inv demosaic
        if not batched:
            return self.invDemosaic(pred_raw)
        return np.stack([self.invDemosaic(frame) for frame in pred_raw], axis=0)

    def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
        """Return a noisy copy of ``raw``; the input array is left untouched.

        A ratio is drawn per image (see ``_sample_ratio``); the image is
        divided by it, passed through the noise functions, and multiplied
        back.

        Args:
            raw: mosaic ``[H, W]``, or ``[B, H, W]`` when ``batched``.
            ratio_dec: factor applied to ``ratio - 1``. It is honoured in both
                the single-image and the batched path.
            batched: whether ``raw`` carries a leading batch axis.
        """
        noisy = self._as_float_copy(raw)
        if not batched:
            return self._add_noise(noisy, self._sample_ratio(ratio_dec))
        for i in range(noisy.shape[0]):
            # Each frame gets its own independently drawn ratio.
            noisy[i] = self._add_noise(noisy[i], self._sample_ratio(ratio_dec))
        return noisy

    def raw2rgb(self, raw, batched=False):
        """Render a raw mosaic to RGB with the ISP network.

        Args:
            raw: mosaic ``[H, W]``, or ``[B, H, W]`` when ``batched``.
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            numpy array ``[H, W, 3]`` (``[B, H, W, 3]`` when ``batched``)
            with values clamped to ``[0, 1]``.
        """
        # 1. -> demosaic, 2. -> substract black level
        raw = self._remove_black_level(self._demosaic_frames(raw, batched))

        # 3. -> white balance
        wb = self.wb / self.wb.max()
        raw = raw * wb[:-1]

        # 4. -> gamma
        norm_value = np.power(self._RAW_MAX, 1 / self._GAMMA)
        raw = np.power(raw, 1 / self._GAMMA)
        raw /= norm_value

        # 5. -> ispnet (expects channels-first input with a batch axis)
        net_input = torch.as_tensor(raw, dtype=torch.float32)
        if not batched:
            net_input = net_input.unsqueeze(0)
        net_input = net_input.permute(0, 3, 1, 2).to(self.device)
        with torch.no_grad():
            reconstruct_rgb = self.net(net_input)
        reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
        pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1)
        if not batched:
            pred_rgb = pred_rgb[0, ...]
        return pred_rgb.cpu().numpy()

    def raw2packedRaw(self, raw, batched=False):
        """Pack an RGGB mosaic into four half-resolution planes.

        The black level is subtracted, values are clipped and divided by
        ``max_value``, then every 2x2 cell is flattened along a new last axis
        in the order top-left, top-right, bottom-right, bottom-left.

        The input array is not modified, and non-float inputs are promoted to
        float64 before the arithmetic.

        Args:
            raw: mosaic ``[H, W]``, or ``[B, H, W]`` when ``batched``.
            batched: kept for interface compatibility; the packing indexes
                the trailing two axes, so both layouts share one code path.

        Returns:
            array ``[H/2, W/2, 4]`` or ``[B, H/2, W/2, 4]``.
        """
        # 1. -> substract black level
        raw = self._remove_black_level(raw)
        # NOTE(review): scales by max_value, not (max_value - black_level);
        # kept as in the original -- confirm this is intended.
        raw /= self.camera_params["max_value"]

        # 2. pack
        im = raw[..., np.newaxis]
        return np.concatenate(
            (
                im[..., 0::2, 0::2, :],
                im[..., 0::2, 1::2, :],
                im[..., 1::2, 1::2, :],
                im[..., 1::2, 0::2, :],
            ),
            axis=-1,
        )

    def raw2demosaicRaw(self, raw, batched=False):
        """Demosaic a raw mosaic, subtract the black level, clip, and divide
        by ``max_value``.

        Args:
            raw: mosaic ``[H, W]``, or ``[B, H, W]`` when ``batched``.
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            array ``[H, W, 3]`` or ``[B, H, W, 3]``.
        """
        # 1. -> demosaic, 2. -> substract black level
        raw = self._remove_black_level(self._demosaic_frames(raw, batched))
        raw /= self.camera_params["max_value"]
        return raw