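# Gradio demo: image denoising with the deepinv library (https://deepinv.github.io/).
# The user uploads an image, picks a noise level and a denoiser, and the app
# returns the noisy image together with the denoised estimate.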
import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image
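
# Convert a PIL image to a (1, C, H, W) float tensor in [0, 1], resizing it so
# that its longer side matches ref_size (a square ref_size x ref_size resize is
# forced when ref_size == 256, as used for DiffUNet).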
def pil_to_torch(image, ref_size=512):
    image = np.array(image)
    image = image.transpose((2, 0, 1))
    image = torch.tensor(image).float() / 255
    image = image.unsqueeze(0)
    if ref_size == 256:
        size = (ref_size, ref_size)
    elif image.shape[2] > image.shape[3]:
        size = (ref_size, ref_size * image.shape[3] // image.shape[2])
    else:
        size = (ref_size * image.shape[2] // image.shape[3], ref_size)
    image = torch.nn.functional.interpolate(image, size=size, mode='bilinear')
    return image
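
# Convert a (1, C, H, W) tensor back to a uint8 PIL image, clipping values to [0, 1].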
def torch_to_pil(image):
    image = image.squeeze(0).cpu().detach().numpy()
    image = image.transpose((1, 2, 0))
    image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    image = PIL.Image.fromarray(image)
    return image
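
# Gradio callback: add Gaussian noise with standard deviation noise_level to the
# input image and denoise it with the selected model from deepinv.models.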
def image_mod(image, noise_level, denoiser):
    image = pil_to_torch(image, ref_size=256 if denoiser == 'DiffUNet' else 512)
    if denoiser == 'DnCNN':
        # DnCNN is applied at a fixed reference noise level sigma0: rescale the
        # input to sigma0 and scale the output back to handle the requested sigma.
        den = dinv.models.DnCNN()
        sigma0 = 2 / 255
        denoiser = lambda x, sigma: den(x * sigma0 / sigma) * sigma / sigma0
    elif denoiser == 'MedianFilter':
        denoiser = dinv.models.MedianFilter(kernel_size=5)
    elif denoiser == 'BM3D':
        denoiser = dinv.models.BM3D()
    elif denoiser == 'TV':
        denoiser = dinv.models.TVDenoiser()
    elif denoiser == 'TGV':
        denoiser = dinv.models.TGVDenoiser()
    elif denoiser == 'Wavelets':
        denoiser = dinv.models.WaveletPrior()
    elif denoiser == 'DiffUNet':
        denoiser = dinv.models.DiffUNet()
    elif denoiser == 'DRUNet':
        denoiser = dinv.models.DRUNet()
    else:
        raise ValueError("Invalid denoiser")
    noisy = image + torch.randn_like(image) * noise_level
    estimated = denoiser(noisy, noise_level)
    return torch_to_pil(noisy), torch_to_pil(estimated)
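
# UI components: input/output images plus dropdowns for the noise level and the denoiser.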
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')
input_image_output = gr.Image(label='Input Image')
noise_levels = gr.Dropdown(choices=[0.05, 0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')
denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'], value='DRUNet', label='Denoiser')
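
# Assemble the interface with a bundled example (image URL, noise level, denoiser).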
demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://upload.wikimedia.org/wikipedia/commons/b/b4/Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg', 0.1, 'DRUNet']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description="Denoise an image with a variety of denoisers and noise levels using the DeepInverse library (https://deepinv.github.io/). We only include lightweight models such as DnCNN and MedianFilter, as this example is intended to run on a CPU. We also automatically resize the input image to 512 pixels to reduce computation time. For more advanced models, please run the code locally.",
)
demo.launch()