import os
import yaml
import torch
import argparse
import numpy as np
import gradio as gr
from PIL import Image
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider

# local code
from models.smfanet_arch import SMFANet


def dict2namespace(config):
    """Recursively convert a (nested) dict into an ``argparse.Namespace``.

    Nested dicts become nested namespaces, so ``config["model"]["dim"]`` is
    reachable as ``cfg.model.dim``. Non-dict values are stored unchanged.
    """
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            value = dict2namespace(value)
        setattr(namespace, key, value)
    return namespace


def load_img(filename, norm=True):
    """Load an image file as an RGB numpy array of shape (H, W, 3).

    Images wider than 1920 px or taller than 1080 px are downscaled by a
    factor of 4 (bicubic) to keep inference cost bounded.

    Args:
        filename: Anything accepted by ``PIL.Image.open``.
        norm: If True, return float32 values in [0, 1]; otherwise return the
            raw uint8 array.

    Returns:
        The image as a numpy array.
    """
    # Context manager releases the underlying file handle promptly.
    with Image.open(filename) as pil_img:
        img = np.array(pil_img.convert("RGB"))
    h, w = img.shape[:2]
    if w > 1920 or h > 1080:
        # max(1, ...) avoids a zero-sized target for extremely thin images.
        new_h, new_w = max(1, h // 4), max(1, w // 4)
        img = np.array(Image.fromarray(img).resize((new_w, new_h), Image.BICUBIC))
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    return img


def process_img(image):
    """Super-resolve one image for the Gradio interface.

    Uses the module-level ``model`` and ``device``.

    Args:
        image: ``PIL.Image`` from the ``gr.Image(type="pil")`` input. Any mode
            (RGBA, grayscale, palette, ...) is accepted and converted to RGB,
            because the network consumes 3-channel input.

    Returns:
        Tuple ``(input_rgb, restored)`` of PIL images, consumed by the
        before/after ``ImageSlider`` output.

    Raises:
        gr.Error: If no image was provided (e.g. the upload was cleared).
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    rgb = image.convert("RGB")
    img = (np.array(rgb) / 255.).astype(np.float32)
    # HWC -> NCHW with a batch of one.
    y = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        x_hat = model(y)
    # Index the batch dim explicitly: .squeeze() would also drop H or W == 1.
    restored_img = x_hat[0].permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 -> uint8
    return (rgb, Image.fromarray(restored_img))


def load_network(net, load_path, strict=True, param_key='params'):
    """Load checkpoint weights into ``net`` in place.

    Args:
        net: Target module. A ``DataParallel``/``DistributedDataParallel``
            wrapper is unwrapped first.
        load_path: Checkpoint path readable by ``torch.load``.
        strict: Forwarded to ``net.load_state_dict``.
        param_key: Key of the state dict inside the checkpoint. If it is
            absent but ``'params'`` exists, ``'params'`` is used instead.
            ``None`` means the loaded object itself is the state dict.
    """
    if isinstance(net, (DataParallel, DistributedDataParallel)):
        net = net.module
    # Keep tensors on the storage they are deserialized to (CPU), regardless
    # of the device they were saved from.
    load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
    if param_key is not None:
        if param_key not in load_net and 'params' in load_net:
            param_key = 'params'
        load_net = load_net[param_key]
    # Strip the 'module.' prefix left by saving a (Distributed)DataParallel
    # model. Iterate over a snapshot of the items because the dict is mutated
    # in the loop; this is much cheaper than deep-copying every tensor.
    for k, v in list(load_net.items()):
        if k.startswith('module.'):
            load_net[k[7:]] = v
            load_net.pop(k)
    net.load_state_dict(load_net, strict=strict)


CONFIG = "configs/SMFANet_plus_x4SR.yml"
MODEL_NAME = "pth/SMFANet_plus_DF2K_100w_x4SR.pth"

# parse config file
with open(CONFIG, "r") as f:
    config = yaml.safe_load(f)

cfg = dict2namespace(config)
device = torch.device("cpu")

model = SMFANet(
    dim=cfg.model.dim,
    n_blocks=cfg.model.n_blocks,
    ffn_scale=cfg.model.ffn_scale,
    upscaling_factor=cfg.model.upscaling_factor,
)
model = model.to(device)
print("IMAGE MODEL CKPT:", MODEL_NAME)
load_network(model, MODEL_NAME, strict=True, param_key='params')
# The demo only runs inference; put any train/eval-dependent layers into
# inference mode.
model.eval()

title = "[ECCV 2024] SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution"
description = '''
#### [Mingjun Zheng](https://github.com/Zheng-MJ), [Long Sun](https://github.com/sunny2109), [Jiangxin Dong](https://scholar.google.com/citations?user=ruebFVEAAAAJ&hl=zh-CN&oi=ao), and [Jinshan Pan](https://jspan.github.io/)
#### [IMAG Lab](https://imag-njust.net/), Nanjing University of Science and Technology

*Network architecture of the proposed SMFANet. The proposed SMFANet consists of a shallow feature extraction module, feature modulation blocks, and a lightweight image reconstruction module. Each feature modulation block contains one self-modulation feature aggregation (SMFA) module and one partial convolution-based feed-forward network (PCFN).*

#### Drag the slider on the super-resolution image left and right to see the changes in the image details. SMFANet performs x4 upscaling on the input image.

```bibtex
@inproceedings{smfanet,
  title={SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution},
  author={Zheng, Mingjun and Sun, Long and Dong, Jiangxin and Pan, Jinshan},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  year={2024}
}
```
'''

article = "SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution"

# Example inputs shown below the interface.
examples = [
    ['images/0801x4.png'],
    ['images/0840x4.png'],
    ['images/0841x4.png'],
    ['images/0870x4.png'],
    ['images/0878x4.png'],
    ['images/0884x4.png'],
    ['images/0900x4.png'],
    ['images/img002x4.png'],
    ['images/img003x4.png'],
    ['images/img004x4.png'],
    ['images/img035x4.png'],
    ['images/img053x4.png'],
    ['images/img064x4.png'],
    ['images/img083x4.png'],
    ['images/img092x4.png'],
]

css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type="pil", label="Input", value="images/0878x4.png")],
    outputs=ImageSlider(label="Super-Resolved Image", type="pil", show_download_button=True),
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)

if __name__ == "__main__":
    demo.launch()