File size: 5,571 Bytes
b9029bc dfe3452 b9029bc 7140d5d c2cb492 b9029bc dfe3452 b9029bc feba69f b9029bc c2cb492 b9029bc c2cb492 b9029bc cd93723 67f2615 b9029bc feba69f b9029bc 5824344 b9029bc c2cb492 b9029bc c2cb492 b9029bc 969e4dd b9029bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
import yaml
import torch
import argparse
import numpy as np
import gradio as gr
from PIL import Image
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider
## local code
from models.smfanet_arch import SMFANet
def dict2namespace(config):
    """Recursively convert a nested dict into an ``argparse.Namespace``.

    Every key of ``config`` becomes an attribute of the returned namespace.
    Values that are themselves dicts are converted the same way, so nested
    settings can be read with dotted access. Any other value (lists
    included) is stored unchanged.
    """
    ns = argparse.Namespace()
    for name, entry in config.items():
        # Only plain dict values are descended into; everything else is kept as-is.
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(ns, name, converted)
    return ns
def load_img(filename, norm=True):
    """Load an image file as an RGB float32 array of shape (H, W, 3).

    Images wider than 1920 px or taller than 1080 px are shrunk to a quarter
    of their size along each side using bicubic resampling.

    Args:
        filename: Path (or file object) accepted by ``PIL.Image.open``.
        norm: When true, scale pixel values from [0, 255] down to [0, 1].

    Returns:
        A ``numpy.ndarray`` of dtype float32. Values lie in [0, 1] when
        ``norm`` is true and in [0, 255] otherwise.
    """
    # The context manager releases the file handle as soon as decoding is done.
    with Image.open(filename) as handle:
        picture = handle.convert("RGB")
    width, height = picture.size
    if width > 1920 or height > 1080:
        # Clamp to 1 so a very thin oversized image never ends up with a
        # zero-length side, which resize() cannot produce.
        target = (max(1, width // 4), max(1, height // 4))
        picture = picture.resize(target, Image.BICUBIC)
    img = np.array(picture)
    if norm:
        img = img / 255.
    return img.astype(np.float32)
def process_img(image):
    """Super-resolve one image with the module-level ``model`` on ``device``.

    Args:
        image: Input picture. A ``PIL.Image`` is converted to RGB first; any
            other value is handed to ``numpy.array`` as-is and must already
            be an H x W x 3 array with values in [0, 255].

    Returns:
        Tuple ``(image, restored)``: the untouched input followed by the
        model output as an 8-bit ``PIL.Image``.
    """
    # Grayscale or RGBA inputs would break the 3-channel permute below.
    source = image.convert("RGB") if isinstance(image, Image.Image) else image
    pixels = (np.array(source) / 255.).astype(np.float32)
    # HWC -> 1 x C x H x W, the layout fed to the model.
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
    # squeeze(0) removes only the batch axis; a bare squeeze() would also
    # collapse a height or width of 1.
    restored = output.squeeze(0).permute(1, 2, 0).clamp_(0, 1)
    restored = restored.detach().cpu().numpy()
    restored = (restored * 255.0).round().astype(np.uint8)  # float32 to uint8
    return (image, Image.fromarray(restored))
def load_network(net, load_path, strict=True, param_key='params'):
    """Load checkpoint weights from ``load_path`` into ``net``.

    Args:
        net: Target module. A ``DataParallel`` / ``DistributedDataParallel``
            wrapper is unwrapped so the weights go into the inner module.
        load_path: Checkpoint location handed to ``torch.load``.
        strict: Forwarded to ``load_state_dict``.
        param_key: Key under which the state dict sits in the checkpoint.
            When that key is absent but ``'params'`` is present, ``'params'``
            is used instead. ``None`` means the checkpoint itself is the
            state dict.

    Raises:
        KeyError: If neither ``param_key`` nor ``'params'`` is in the
            checkpoint.
    """
    if isinstance(net, (DataParallel, DistributedDataParallel)):
        net = net.module
    # Returning the storage unchanged keeps every tensor on the CPU.
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints (or pass weights_only=True where the installed torch
    # supports it).
    state = torch.load(load_path, map_location=lambda storage, loc: storage)
    if param_key is not None:
        if param_key not in state and 'params' in state:
            param_key = 'params'
        state = state[param_key]
    # Strip the 'module.' prefix left by (Distributed)DataParallel wrappers.
    # Renaming in place over a snapshot of the keys avoids deep-copying every
    # tensor and keeps any metadata attached to the loaded dict.
    prefix = 'module.'
    for key in list(state):
        if key.startswith(prefix):
            state[key[len(prefix):]] = state.pop(key)
    net.load_state_dict(state, strict=strict)
CONFIG = "configs/SMFANet_plus_x4SR.yml"
MODEL_NAME = "pth/SMFANet_plus_DF2K_100w_x4SR.pth"
# Parse the YAML config into a namespace so settings read as cfg.model.<name>.
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)
# Inference runs on the CPU.
device = torch.device("cpu")
model = SMFANet(
    dim=cfg.model.dim,
    n_blocks=cfg.model.n_blocks,
    ffn_scale=cfg.model.ffn_scale,
    upscaling_factor=cfg.model.upscaling_factor,
)
model = model.to(device)
print ("IMAGE MODEL CKPT:", MODEL_NAME)
load_network(model, MODEL_NAME, strict=True, param_key='params')
# Put the network in inference mode; this matters if the architecture has any
# train-only layers (e.g. dropout or batch norm) and is harmless otherwise.
model.eval()
# Text shown on the demo page: heading, markdown description, and footer link.
title = "[ECCV 2024] SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution"
description = '''
#### [Mingjun Zheng](https://github.com/Zheng-MJ), [Long Sun](https://github.com/sunny2109), [Jiangxin Dong](https://scholar.google.com/citations?user=ruebFVEAAAAJ&hl=zh-CN&oi=ao), and [Jinshan Pan](https://jspan.github.io/)
#### [IMAG Lab](https://imag-njust.net/), Nanjing University of Science and Technology
<center><img src='assets/smfanet_arch.png'></center>
### Network architecture of the proposed SMFANet. The proposed SMFANet consists of a shallow feature extraction module, feature modulation blocks, and a lightweight image reconstruction module. Feature modulation block contains one self-modulation feature aggregation (SMFA) module and one partial convolution-based feed-forward network (PCFN).
#### Drag the slider on the super-resolution image left and right to see the changes in the image details. SMFANet performs x4 upscaling on the input image.
<br>
<code>
@inproceedings{smfanet,
title={SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution},
author={Zheng, Mingjun and Sun, Long and Dong, Jiangxin and Pan, Jinshan},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
year={2024}
}
</code>
<br>
'''
# Footer link to the project repository (a bare raw.githubusercontent.com/<user>/<repo> URL is not a valid page).
article = "<p style='text-align: center'><a href='https://github.com/Zheng-MJ/SMFANet' target='_blank'>SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution </a></p>"
# Example inputs passed to gr.Interface: one inner list per example, each
# holding a single image path.
# The commented-out rows are disabled examples, presumably kept so they can
# be re-enabled later.
examples = [
['images/0801x4.png'],
# ['images/0840x4.png'],
# ['images/0841x4.png'],
# ['images/0870x4.png'],
# ['images/0878x4.png'],
# ['images/0884x4.png'],
# ['images/0900x4.png'],
# ['images/img002x4.png'],
# ['images/img003x4.png'],
# ['images/img004x4.png'],
# ['images/img035x4.png'],
# ['images/img053x4.png'],
# ['images/img064x4.png'],
# ['images/img083x4.png'],
# ['images/img092x4.png'],
]
# Extra stylesheet passed to gr.Interface: images inside .image-frame and
# .image-container get automatic width/height and no max-width cap.
# NOTE(review): presumably so the upscaled output is shown at its native size
# rather than shrunk to fit the container - confirm in the rendered UI.
css = """
.image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}
"""
# Gradio UI: a single PIL image goes in, a before/after comparison slider
# comes out (process_img returns the input/output pair the slider shows).
_input_image = gr.Image(type="pil", label="Input", value="images/0878x4.png")
_output_slider = ImageSlider(
    label="Super-Resolved Image",
    type="pil",
    show_download_button=True,
)
demo = gr.Interface(
    fn=process_img,
    inputs=[_input_image],
    outputs=_output_slider,
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()