"""Gradio demo: add Gaussian noise to an image and denoise it with DeepInverse."""

import functools

import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image


# Map of UI names to DeepInverse denoiser constructors. The insertion order is
# the order shown in the dropdown.
_DENOISER_FACTORIES = {
    "DnCNN": dinv.models.DnCNN,
    "DRUNet": dinv.models.DRUNet,
    "BM3D": dinv.models.BM3D,
    "MedianFilter": dinv.models.MedianFilter,
}


@functools.lru_cache(maxsize=None)
def _load_denoiser(name):
    """Build (once) and return the denoiser registered under ``name``.

    The cache is effectively bounded: only the names in
    ``_DENOISER_FACTORIES`` succeed, and ``lru_cache`` does not cache raised
    exceptions. Caching avoids re-constructing the model (and reloading any
    pretrained weights) on every request.

    Raises:
        ValueError: if ``name`` is not a known denoiser.
    """
    try:
        factory = _DENOISER_FACTORIES[name]
    except KeyError:
        raise ValueError("Invalid denoiser") from None
    model = factory()
    model.eval()
    return model


def pil_to_torch(image):
    """Convert an H x W (x C) uint8 image to a float tensor (1, 3, H, W) in [0, 1].

    Accepts a PIL image or a NumPy array (whatever ``gr.Image`` yields).
    Grayscale input is replicated to three channels and an alpha channel is
    dropped, so every image reaches the denoiser with three channels.
    NOTE(review): assumes the models are used in their RGB configuration;
    confirm against the deepinv defaults.
    """
    array = np.asarray(image)
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    elif array.shape[-1] == 4:
        array = array[..., :3]
    tensor = torch.tensor(array.transpose((2, 0, 1))).float() / 255
    return tensor.unsqueeze(0)


def torch_to_pil(image):
    """Convert a (1, C, H, W) float tensor in [0, 1] back to a PIL image.

    Values are clipped to [0, 1] before quantisation. The noisy image (and
    possibly the denoiser output) falls outside that range, and casting
    out-of-range floats straight to uint8 wraps around instead of saturating.
    """
    array = image.squeeze(0).detach().cpu().numpy().transpose((1, 2, 0))
    array = np.clip(array, 0.0, 1.0) * 255
    return PIL.Image.fromarray(np.round(array).astype(np.uint8))


def image_mod(image, noise_level, denoiser):
    """Corrupt ``image`` with additive Gaussian noise, then denoise it.

    Args:
        image: input image from ``gr.Image`` (NumPy array or PIL image).
        noise_level: standard deviation of the Gaussian noise on the [0, 1]
            intensity scale; also passed to the denoiser as its sigma.
        denoiser: name of the denoiser, one of ``_DENOISER_FACTORIES``.

    Returns:
        Tuple ``(noisy, denoised)`` of PIL images.

    Raises:
        gr.Error: if no input image was provided.
        ValueError: if ``denoiser`` is not a known name.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    model = _load_denoiser(denoiser)
    noise_level = float(noise_level)

    clean = pil_to_torch(image)
    noisy = clean + torch.randn_like(clean) * noise_level
    with torch.no_grad():
        # Denoise the noisy observation, not the clean input.
        estimated = model(noisy, noise_level)
    return torch_to_pil(noisy), torch_to_pil(estimated)


input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')
noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.4, 0.5], value=0.1, label='Noise Level')
denoiser = gr.Dropdown(choices=list(_DENOISER_FACTORIES), value='DRUNet', label='Denoiser')

demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
)

demo.launch()