"""Gradio app that upscales images 4x with the Real-ESRGAN x4plus anime 6B model.

On first import the checkpoint is downloaded next to this script. It is loaded
into a pure-PyTorch re-implementation of the RRDBNet generator and served
through a simple Gradio interface.
"""

import os

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# Define the model URL and local path
MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
MODEL_PATH = "RealESRGAN_x4_Anime6B.pth"

# Seconds to wait for the download server to connect / send data before giving up.
DOWNLOAD_TIMEOUT = 60


def download_model(url, path, timeout=DOWNLOAD_TIMEOUT):
    """Download *url* to *path*, streaming to disk and replacing atomically.

    The data is written to ``path + ".part"`` and renamed into place only once
    the transfer completed, so an interrupted download never leaves a truncated
    checkpoint at *path* that a later run would then try to load.

    Args:
        url: HTTP(S) URL of the file to fetch (redirects are followed).
        path: Destination file path.
        timeout: Connect/read timeout in seconds passed to ``requests``.

    Raises:
        RuntimeError: The server answered with a status other than 200.
        requests.RequestException: Network failure or timeout.
    """
    tmp_path = path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Could not download model: {response.status_code}")
            with open(tmp_path, "wb") as f:
                # Stream in 1 MiB chunks instead of buffering the whole file in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, path)
    finally:
        # After a successful os.replace the temp file is gone; otherwise clean it up.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    download_model(MODEL_URL, MODEL_PATH)


class _ResidualDenseBlock(nn.Module):
    """Five densely connected 3x3 convs with a residual connection scaled by 0.2."""

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # Attribute names must match the checkpoint's state-dict keys (conv1..conv5).
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class _RRDB(nn.Module):
    """Three residual dense blocks chained, wrapped in a 0.2-scaled residual."""

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = _ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = _ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = _ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x


class _RRDBNet(nn.Module):
    """x4 RRDBNet generator with module names laid out like BasicSR's RRDBNet.

    The layer names (conv_first, body.N.rdbK.convM, conv_body, conv_up1/2,
    conv_hr, conv_last) follow BasicSR's layout so a Real-ESRGAN checkpoint can
    be loaded with ``strict=True``.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32):
        super().__init__()
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = nn.Sequential(*(_RRDB(num_feat, num_grow_ch) for _ in range(num_block)))
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = self.conv_first(x)
        feat = feat + self.conv_body(self.body(feat))
        # Two nearest-neighbour 2x upsamples give the overall 4x scale.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))


class RealESRGAN:
    """Inference wrapper around an x4 RRDBNet generator loaded from a checkpoint.

    Args:
        model_path: Path to a Real-ESRGAN ``.pth`` checkpoint.
        device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``) to run on.
        num_block: Number of RRDB blocks in the generator. The default 6 matches
            the anime 6B checkpoint; other x4 RRDBNet checkpoints may use a
            different count.
    """

    def __init__(self, model_path, device, num_block=6):
        self.device = device  # Store the device
        net = _RRDBNet(num_block=num_block)
        # weights_only=True refuses to unpickle arbitrary objects from the downloaded file.
        state = torch.load(model_path, map_location="cpu", weights_only=True)
        # Real-ESRGAN release checkpoints nest the weights under "params_ema" or "params";
        # a bare state dict (no such key) is used as-is.
        for key in ("params_ema", "params"):
            if key in state:
                state = state[key]
                break
        net.load_state_dict(state, strict=True)
        self.model = net.to(device).eval()  # Move to device and set eval mode

    def upscale(self, image_tensor):
        """Run the generator on *image_tensor* and clamp the output to [0, 1].

        *image_tensor* should already be on ``self.device``. The calling code
        in this file passes a (1, 3, H, W) float tensor in [0, 1].
        """
        with torch.no_grad():
            return self.model(image_tensor).clamp(0, 1)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = RealESRGAN(MODEL_PATH, device)


def upscale_image(image):
    """Upscale a PIL image 4x with the loaded Real-ESRGAN model.

    Args:
        image: PIL image from the Gradio input. Gradio passes ``None`` when
            nothing was uploaded.

    Returns:
        The upscaled image as an RGB ``PIL.Image``.

    Raises:
        gr.Error: No image was provided. Gradio shows this message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Pre-process: drop alpha/palette, convert to a [0, 1] float CHW tensor and add a batch dim.
    image = image.convert("RGB")
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

    # NOTE(review): the whole image goes through in a single pass; very large inputs
    # may run out of GPU memory. Consider tiling if that becomes a problem.
    sr_image_tensor = model.upscale(image_tensor)

    # Post-process: CHW -> HWC, then scale to 0-255. Round instead of truncating.
    sr_array = sr_image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((sr_array * 255.0).round().astype(np.uint8))


# Create a Gradio interface
interface = gr.Interface(
    fn=upscale_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats.",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()