Spaces:
Runtime error
Runtime error
import argparse | |
import cv2 | |
import glob | |
import numpy as np | |
from collections import OrderedDict | |
from skimage import img_as_ubyte | |
import os | |
import torch | |
import requests | |
from PIL import Image | |
import torchvision.transforms.functional as TF | |
import torch.nn.functional as F | |
from natsort import natsorted | |
from model.CMFNet import CMFNet | |
def main():
    """Dehaze every image in ``--input_dir`` with CMFNet and save PNGs to ``--result_dir``.

    Each regular file in the input directory is opened as RGB. The image is
    padded on the bottom/right so both sides are multiples of ``mul`` (8) and
    run through the model. The output is clamped to [0, 1], cropped back to
    the original size and written as ``<stem>.png`` in the result directory.

    Raises:
        FileNotFoundError: if ``--input_dir`` contains no regular files.
    """
    parser = argparse.ArgumentParser(description='Demo Image Dehaze')
    parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
    parser.add_argument('--result_dir', default='results/', type=str, help='Directory for results')
    parser.add_argument('--weights',
                        default='experiments/pretrained_models/dehaze_model.pth', type=str,
                        help='Path to weights')
    args = parser.parse_args()

    inp_dir = args.input_dir
    out_dir = args.result_dir
    os.makedirs(out_dir, exist_ok=True)

    # glob('*') also matches sub-directories, which Image.open cannot read.
    files = natsorted([p for p in glob.glob(os.path.join(inp_dir, '*')) if os.path.isfile(p)])
    if not files:
        raise FileNotFoundError(f"No files found at {inp_dir}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load corresponding models architecture and weights
    model = CMFNet()
    model = model.to(device)
    model.eval()
    load_checkpoint(model, args.weights)

    mul = 8
    for file_ in files:
        # Context manager closes the underlying file handle; convert() returns
        # an independent in-memory copy.
        with Image.open(file_) as src:
            img = src.convert('RGB')
        input_ = TF.to_tensor(img).unsqueeze(0).to(device)

        # Pad bottom/right so H and W become multiples of `mul`.
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % mul
        padw = (-w) % mul
        # 'reflect' requires each pad to be smaller than that dimension;
        # very small images fall back to 'replicate' instead of crashing.
        mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), mode)

        with torch.no_grad():
            restored = model(input_)
        restored = torch.clamp(restored, 0, 1)

        # Crop away the padding and convert NCHW float -> HWC uint8.
        restored = restored[:, :, :h, :w]
        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored = img_as_ubyte(restored[0])

        f = os.path.splitext(os.path.split(file_)[-1])[0]
        save_img(os.path.join(out_dir, f + '.png'), restored)
def save_img(filepath, img):
    """Write an image array to ``filepath`` with OpenCV.

    OpenCV expects BGR channel order, so a 3-D (H, W, C) array is converted
    from RGB first. A 2-D (grayscale) array is written unchanged.

    Args:
        filepath: destination path; its extension selects the encoder.
        img: image array, RGB when 3-D (as produced by ``main``).

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file was not written.
    """
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # imwrite signals many failures (e.g. unwritable path) only via its
    # boolean return value, so check it instead of failing silently.
    if not cv2.imwrite(filepath, img):
        raise OSError(f"Failed to write image to {filepath}")
def load_checkpoint(model, weights):
    """Load parameters from the checkpoint file ``weights`` into ``model`` in place.

    The file is loaded onto the CPU. Parameters are read from its
    ``"state_dict"`` entry when present; otherwise the loaded object itself is
    used as the state dict. If the strict load fails, a leading ``module.``
    prefix is stripped from the keys and the load is retried. That prefix is
    typically left by ``nn.DataParallel``. Keys without the prefix are kept
    unchanged.

    Args:
        model: the ``torch.nn.Module`` to populate.
        weights: path to the checkpoint file.

    Raises:
        RuntimeError: if the (prefix-stripped) state dict still does not match ``model``.
    """
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    state_dict = checkpoint.get("state_dict", checkpoint)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Strip the `module.` prefix only where it is actually present.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
if __name__ == "__main__":
    # Run the batch dehazing demo only when executed as a script, not on import.
    main()