import argparse
import glob
import math
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F

import gradio as gr
from gradio_imageslider import ImageSlider
from PIL import Image

from models.safmn_arch import SAFMN
from utils.color_fix import wavelet_reconstruction
from utils.download_url import load_file_from_url

pretrain_model_url = {
    'safmn_x2': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x2-v2.pth',
    'safmn_x4': 'https://github.com/sunny2109/SAFMN/releases/download/v0.1.0/SAFMN_L_Real_LSDIR_x4-v2.pth',
}

# download weights if they are not cached locally
if not os.path.exists('pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x2'], model_dir='./pretrained_models/', progress=True, file_name=None)
if not os.path.exists('pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'):
    load_file_from_url(url=pretrain_model_url['safmn_x4'], model_dir='./pretrained_models/', progress=True, file_name=None)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_safmn(upscale):
    """Build a SAFMN model and load the matching x2/x4 checkpoint."""
    model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=upscale)
    if upscale == 2:
        model_path = 'pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'
    elif upscale == 4:
        model_path = 'pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'
    else:
        raise NotImplementedError('Only x2/x4 upscaling is supported!')

    model.load_state_dict(torch.load(model_path)['params'], strict=True)
    model.eval()
    return model.to(device)


def img2patch(lq, scale=4, crop_size=512):
    """Split a low-resolution tensor into overlapping patches for tiled inference."""
    b, c, hl, wl = lq.size()
    h, w = hl * scale, wl * scale
    sr_size = (b, c, h, w)
    assert b == 1

    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    # adaptive step_i, step_j so the patches cover the output evenly
    num_row = (h - 1) // crop_size_h + 1
    num_col = (w - 1) // crop_size_w + 1

    step_j = crop_size_w if num_col == 1 else math.ceil((w - crop_size_w) / (num_col - 1) - 1e-8)
    step_i = crop_size_h if num_row == 1 else math.ceil((h - crop_size_h) / (num_row - 1) - 1e-8)

    # round the steps down to multiples of the scale factor
    step_i = step_i // scale * scale
    step_j = step_j // scale * scale

    parts = []
    idxes = []

    i = 0  # 0 ~ h-1
    last_i = False
    while i < h and not last_i:
        j = 0
        if i + crop_size_h >= h:
            i = h - crop_size_h
            last_i = True

        last_j = False
        while j < w and not last_j:
            if j + crop_size_w >= w:
                j = w - crop_size_w
                last_j = True
            parts.append(lq[:, :, i // scale:(i + crop_size_h) // scale, j // scale:(j + crop_size_w) // scale])
            idxes.append({'i': i, 'j': j})
            j = j + step_j
        i = i + step_i

    return torch.cat(parts, dim=0), idxes, sr_size


def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
    """Stitch patch outputs back into a full image, averaging overlapped regions."""
    preds = torch.zeros(sr_size).to(outs.device)
    b, c, h, w = sr_size
    count_mt = torch.zeros((b, 1, h, w)).to(outs.device)
    crop_size_h, crop_size_w = crop_size // scale * scale, crop_size // scale * scale

    for cnt, each_idx in enumerate(idxes):
        i = each_idx['i']
        j = each_idx['j']
        preds[0, :, i:i + crop_size_h, j:j + crop_size_w] += outs[cnt]
        count_mt[0, 0, i:i + crop_size_h, j:j + crop_size_w] += 1.
    return (preds / count_mt).to(outs.device)


def inference(image, upscale, large_input_flag, color_fix):
    # only x2/x4 checkpoints exist; fall back to x2 for invalid inputs (including x3)
    if upscale is None or not isinstance(upscale, (int, float)) or upscale == 3.:
        upscale = 2
    upscale = int(upscale)

    model = set_safmn(upscale)

    # img2tensor: RGB PIL image -> 1x3xHxW BGR float tensor in [0, 1]
    y = np.array(image).astype(np.float32) / 255.
    y = torch.from_numpy(np.transpose(y[:, :, [2, 1, 0]], (2, 0, 1))).float()
    y = y.unsqueeze(0).to(device)

    # inference
    if large_input_flag:
        patches, idx, size = img2patch(y, scale=upscale)

        with torch.no_grad():
            n = len(patches)
            outs = []
            m = 1  # patches per forward pass
            i = 0
            while i < n:
                j = min(i + m, n)
                pred = model(patches[i:j])
                if isinstance(pred, list):
                    pred = pred[-1]
                outs.append(pred.detach())
                i = j
            output = torch.cat(outs, dim=0)

        output = patch2img(output, idx, size, scale=upscale)
    else:
        with torch.no_grad():
            output = model(y)

    # color fix
    if color_fix:
        y = F.interpolate(y, scale_factor=upscale, mode='bilinear')
        output = wavelet_reconstruction(output, y)

    # tensor2img: BGR tensor -> RGB uint8 array
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    # save results
    save_path = './out.png'
    cv2.imwrite(save_path, output[:, :, ::-1])

    return (image, Image.fromarray(output)), save_path


title = "SAFMN for Real-world SR (running on CPU)"
description = '''
### Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution - ICCV 2023
### [Long Sun](https://github.com/sunny2109), [Jiangxin Dong](https://scholar.google.com/citations?user=ruebFVEAAAAJ&hl=zh-CN&oi=ao), [Jinhui Tang](https://scholar.google.com/citations?user=ByBLlEwAAAAJ&hl=zh-CN), and [Jinshan Pan](https://jspan.github.io/)
### [IMAG Lab](https://imag-njust.net/), Nanjing University of Science and Technology
### Drag the slider on the super-resolved image left and right to see the changes in the image details.
### SAFMN performs x2/x4 upscaling on the input image. If the input image is larger than 720P, memory-efficient inference is recommended.
### If our work is useful for your research, please consider citing:
    @inproceedings{sun2023safmn,
        title={Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution},
        author={Sun, Long and Dong, Jiangxin and Tang, Jinhui and Pan, Jinshan},
        booktitle={ICCV},
        year={2023}
    }
'''

article = "Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution"

#### Image examples
examples = [
    ['real_testdata/060.png'],
    ['real_testdata/004.png'],
    ['real_testdata/013.png'],
    ['real_testdata/014.png'],
    ['real_testdata/015.png'],
    ['real_testdata/021.png'],
    ['real_testdata/032.png'],
    ['real_testdata/045.png'],
    ['real_testdata/036.png'],
    ['real_testdata/058.png'],
]

css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(value="real_testdata/060.png", type="pil", label="Input"),
        gr.Number(minimum=2, maximum=4, label="Upscaling factor (up to 4)"),
        gr.Checkbox(value=False, label="Memory-efficient inference"),
        gr.Checkbox(value=False, label="Color correction"),
    ],
    outputs=[
        ImageSlider(label="Super-Resolved Image", type="pil", show_download_button=True),
        gr.File(label="Download Output"),
        # gr.Image(label="Download Output", type='filepath'),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)

if __name__ == "__main__":
    demo.launch()