"""Gradio web app that upscales images with a Real-ESRGAN anime model."""

import os

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

# Release asset for the x4 anime (6-block) Real-ESRGAN checkpoint.
# MODEL_PATH is derived from the URL so the existence check below matches
# exactly the file that the download writes.
MODEL_URL = (
    "https://github.com/xinntao/Real-ESRGAN/releases/download/"
    "v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
)
MODEL_PATH = os.path.join("weights", os.path.basename(MODEL_URL))


def _ensure_weights(path=MODEL_PATH, url=MODEL_URL):
    """Download the checkpoint at *url* to *path* unless it already exists.

    Uses torch.hub.download_url_to_file rather than shelling out to wget, so
    no external binary is required and a failed download raises an error
    instead of being silently ignored.
    """
    if os.path.exists(path):
        return
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.hub.download_url_to_file(url, path)


# Download the required model checkpoint if it doesn't exist.
_ensure_weights()


class RealESRGAN:
    """Load an upscaling model through torch.hub and run inference with it."""

    def __init__(self, device):
        # Remember the device so inputs can be moved onto it before inference.
        self.device = device
        # NOTE(review): confirm that the "xinntao/Real-ESRGAN" hub repo really
        # exposes a "restoration" entrypoint accepting model=/device= kwargs.
        # The checkpoint downloaded to MODEL_PATH is not passed in here.
        self.model = torch.hub.load(
            "xinntao/Real-ESRGAN",
            "restoration",
            model="RealESRGAN_x4_Anime6B.pth",
            device=device,
        )
        self.model.eval()

    def upscale(self, image_tensor):
        """Run the model on *image_tensor* and return its output clamped to [0, 1].

        The input is moved to ``self.device`` first: a CPU tensor fed to a
        model living on CUDA would fail. The result stays on ``self.device``.
        """
        with torch.no_grad():
            return self.model(image_tensor.to(self.device)).clamp(0, 1)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = RealESRGAN(device)


def upscale_image(image):
    """Upscale a PIL image with the module-level model and return a PIL image.

    Args:
        image: PIL image from the Gradio input; may be None when the user
            submits without uploading anything.

    Returns:
        The upscaled RGB PIL image, or None when *image* is None.
    """
    if image is None:
        return None

    # ToTensor yields a CHW float tensor in [0, 1]; unsqueeze adds the batch dim.
    image_tensor = transforms.ToTensor()(image.convert("RGB")).unsqueeze(0)

    sr_image_tensor = model.upscale(image_tensor)

    # Move to host memory before NumPy conversion, round (not truncate) to
    # 8-bit, and reorder CHW -> HWC as PIL expects. Values are already clamped
    # to [0, 1], so the uint8 cast cannot overflow.
    sr_array = (
        sr_image_tensor.squeeze(0)
        .detach()
        .cpu()
        .mul(255)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
        .numpy()
    )
    return Image.fromarray(sr_array)


# gr.Image replaces the gr.inputs / gr.outputs namespaces removed in Gradio 4.
interface = gr.Interface(
    fn=upscale_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats.",
)

if __name__ == "__main__":
    interface.launch()