|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
import torch |
|
import numpy as np |
|
import os, time, random |
|
import argparse |
|
from torch.utils.data import Dataset, DataLoader |
|
from PIL import Image as PILImage |
|
from glob import glob |
|
from tqdm import tqdm |
|
|
|
from model.model import InvISPNet |
|
from dataset.FiveK_dataset import FiveKDatasetTest |
|
from config.config import get_arguments |
|
|
|
from utils.JPEG import DiffJPEG |
|
from utils.commons import denorm, preprocess_test_patch |
|
|
|
|
|
os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp") |
|
os.environ["CUDA_VISIBLE_DEVICES"] = str( |
|
np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()]) |
|
) |
|
|
|
os.system("rm tmp") |
|
|
|
# Differentiable JPEG layer (quality 90), moved to the GPU at import time.
# NOTE(review): this rebinds the imported ``DiffJPEG`` class name to an
# instance, and nothing else in this file references it - confirm whether it
# is still needed before removing it.
DiffJPEG = DiffJPEG(quality=90, differentiable=True).cuda()
|
|
|
# Extend the shared argument parser with the options specific to this script.
parser = get_arguments()
# The checkpoint path is dereferenced unconditionally below (its file name
# selects the results folder), so make argparse enforce it with a usage
# message rather than failing later with an AttributeError on None.
parser.add_argument("--ckpt", type=str, required=True, help="Checkpoint path.")
parser.add_argument(
    "--out_path",
    type=str,
    default="./exps/",
    help="Root directory for test results (results go under "
    "<out_path><task>/results_*<ckpt name>/); should end with '/'.",
)
parser.add_argument(
    "--split_to_patch",
    dest="split_to_patch",
    action="store_true",
    help="Test on patch. ",
)
args = parser.parse_args()
print("Parsed arguments: {}".format(args))
|
|
|
|
|
# Checkpoint file name up to its first "." (directories stripped on "/").
# It names the results folder, and main() also reads its pred*jpg inputs from
# that folder, so this naming must stay as-is.
ckpt_name = args.ckpt.split("/")[-1].split(".")[0]

# Patch-wise runs (--split_to_patch) use a "results_metric_" folder; whole-image
# runs use "results_".
_results_dir_prefix = "results_metric_" if args.split_to_patch else "results_"
out_path = args.out_path + "%s/%s%s/" % (args.task, _results_dir_prefix, ckpt_name)
os.makedirs(out_path, exist_ok=True)
|
|
|
|
|
def main(args): |
|
|
|
net = InvISPNet(channel_in=3, channel_out=3, block_num=8) |
|
device = torch.device("cuda:0") |
|
|
|
net.to(device) |
|
net.eval() |
|
|
|
if os.path.isfile(args.ckpt): |
|
net.load_state_dict(torch.load(args.ckpt), strict=False) |
|
print("[INFO] Loaded checkpoint: {}".format(args.ckpt)) |
|
|
|
print("[INFO] Start data load and preprocessing") |
|
RAWDataset = FiveKDatasetTest(opt=args) |
|
dataloader = DataLoader( |
|
RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True |
|
) |
|
|
|
input_RGBs = sorted(glob(out_path + "pred*jpg")) |
|
input_RGBs_names = [path.split("/")[-1].split(".")[0][5:] for path in input_RGBs] |
|
|
|
print("[INFO] Start test...") |
|
for i_batch, sample_batched in enumerate(tqdm(dataloader)): |
|
step_time = time.time() |
|
|
|
input, target_rgb, target_raw = ( |
|
sample_batched["input_raw"].to(device), |
|
sample_batched["target_rgb"].to(device), |
|
sample_batched["target_raw"].to(device), |
|
) |
|
file_name = sample_batched["file_name"][0] |
|
|
|
if args.split_to_patch: |
|
input_list, target_rgb_list, target_raw_list = preprocess_test_patch( |
|
input, target_rgb, target_raw |
|
) |
|
else: |
|
|
|
input_list, target_rgb_list, target_raw_list = ( |
|
[input[:, :, ::2, ::2]], |
|
[target_rgb[:, :, ::2, ::2]], |
|
[target_raw[:, :, ::2, ::2]], |
|
) |
|
|
|
for i_patch in range(len(input_list)): |
|
file_name_patch = file_name + "_%05d" % i_patch |
|
idx = input_RGBs_names.index(file_name_patch) |
|
input_RGB_path = input_RGBs[idx] |
|
input_RGB = ( |
|
torch.from_numpy(np.array(PILImage.open(input_RGB_path)) / 255.0) |
|
.unsqueeze(0) |
|
.permute(0, 3, 1, 2) |
|
.float() |
|
.to(device) |
|
) |
|
|
|
target_raw_patch = target_raw_list[i_patch] |
|
|
|
with torch.no_grad(): |
|
reconstruct_raw = net(input_RGB, rev=True) |
|
|
|
pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1) |
|
pred_raw = torch.clamp(pred_raw, 0, 1) |
|
|
|
target_raw_patch = target_raw_patch.permute(0, 2, 3, 1) |
|
pred_raw = denorm(pred_raw, 255) |
|
target_raw_patch = denorm(target_raw_patch, 255) |
|
|
|
pred_raw = pred_raw.cpu().numpy() |
|
target_raw_patch = target_raw_patch.cpu().numpy().astype(np.float32) |
|
|
|
raw_pred = PILImage.fromarray(np.uint8(pred_raw[0, :, :, 0])) |
|
raw_tar_pred = PILImage.fromarray( |
|
np.hstack( |
|
( |
|
np.uint8(target_raw_patch[0, :, :, 0]), |
|
np.uint8(pred_raw[0, :, :, 0]), |
|
) |
|
) |
|
) |
|
|
|
raw_tar = PILImage.fromarray(np.uint8(target_raw_patch[0, :, :, 0])) |
|
|
|
raw_pred.save(out_path + "raw_pred_%s_%05d.jpg" % (file_name, i_patch)) |
|
raw_tar.save(out_path + "raw_tar_%s_%05d.jpg" % (file_name, i_patch)) |
|
raw_tar_pred.save( |
|
out_path + "raw_gt_pred_%s_%05d.jpg" % (file_name, i_patch) |
|
) |
|
|
|
np.save( |
|
out_path + "raw_pred_%s_%05d.npy" % (file_name, i_patch), |
|
pred_raw[0, :, :, :] / 255.0, |
|
) |
|
np.save( |
|
out_path + "raw_tar_%s_%05d.npy" % (file_name, i_patch), |
|
target_raw_patch[0, :, :, :] / 255.0, |
|
) |
|
|
|
del reconstruct_raw |
|
|
|
|
|
if __name__ == "__main__":
    # Cap the number of CPU threads torch uses for intra-op parallelism.
    _torch_cpu_threads = 4
    torch.set_num_threads(_torch_cpu_threads)
    main(args)
|
|