# Spaces: Runtime error
import os
import shutil
import urllib.request

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
# Real-ESRGAN anime checkpoint (4x, 6 RRDB blocks), fetched on first start.
# The local file name is derived from the URL so that the existence check
# below matches exactly what gets downloaded; the asset name must match the
# one published in the GitHub release.
MODEL_URL = (
    "https://github.com/xinntao/Real-ESRGAN/releases/download/"
    "v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
)
MODEL_PATH = os.path.join("weights", os.path.basename(MODEL_URL))


def _download_file(url, path):
    """Download *url* to *path*, staging the bytes in ``path + ".part"``.

    The final *path* only appears after the transfer completed, so an
    interrupted download is never mistaken for a valid checkpoint on the
    next start. Network/HTTP failures propagate as the exceptions raised by
    ``urllib.request.urlopen`` instead of being silently ignored.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    partial_path = path + ".part"
    try:
        with urllib.request.urlopen(url) as response, open(partial_path, "wb") as out:
            shutil.copyfileobj(response, out)
        # Atomic rename: readers see either no file or the complete file.
        os.replace(partial_path, path)
    finally:
        # Only present here if the transfer or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)


if not os.path.exists(MODEL_PATH):
    _download_file(MODEL_URL, MODEL_PATH)
# Real-ESRGAN network definition. Attribute names mirror
# ``basicsr.archs.rrdbnet_arch`` so the module's state-dict keys
# (``conv_first.*``, ``body.<i>.rdb<j>.conv<k>.*``, ``conv_up1.*`` ...) match
# the official release checkpoints and load with ``strict=True``.
class ResidualDenseBlock(torch.nn.Module):
    """Five densely connected 3x3 convolutions with a residual scaled by 0.2."""

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        conv = torch.nn.Conv2d
        self.conv1 = conv(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = conv(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = conv(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = conv(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = conv(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):
    """Three chained residual dense blocks wrapped in a residual scaled by 0.2."""

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):
    """4x RRDB super-resolution network (defaults match the anime 6B model).

    Takes an ``(N, num_in_ch, H, W)`` tensor and returns
    ``(N, num_out_ch, 4H, 4W)``: two nearest-neighbour 2x upsamplings, each
    followed by a convolution.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32):
        super().__init__()
        conv = torch.nn.Conv2d
        self.conv_first = conv(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.nn.Sequential(
            *(RRDB(num_feat, num_grow_ch) for _ in range(num_block))
        )
        self.conv_body = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_last = conv(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        interpolate = torch.nn.functional.interpolate
        feat = self.conv_first(x)
        feat = feat + self.conv_body(self.body(feat))
        feat = self.lrelu(self.conv_up1(interpolate(feat, scale_factor=2, mode="nearest")))
        feat = self.lrelu(self.conv_up2(interpolate(feat, scale_factor=2, mode="nearest")))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))


class RealESRGAN:
    """Image upscaler that runs the Real-ESRGAN anime 6B weights.

    The Real-ESRGAN GitHub repo exposes no ``torch.hub`` entry point, so the
    network is built locally and the downloaded checkpoint is loaded into it.

    Args:
        device: torch device spec (e.g. ``"cuda"`` or ``"cpu"``) to run on.
        model_path: checkpoint file to load; ``None`` means the module-level
            ``MODEL_PATH``.
    """

    def __init__(self, device, model_path=None):
        if model_path is None:
            model_path = MODEL_PATH
        self.device = torch.device(device)
        # Load onto the CPU first; the whole module is moved in one step below.
        checkpoint = torch.load(model_path, map_location="cpu")
        self.model = RRDBNet()
        self.model.load_state_dict(self._extract_state_dict(checkpoint), strict=True)
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the weights from a checkpoint dict.

        Prefers the ``params_ema`` entry, then ``params``; otherwise the
        checkpoint is assumed to already be a bare state dict.
        """
        for key in ("params_ema", "params"):
            if key in checkpoint:
                return checkpoint[key]
        return checkpoint

    def upscale(self, image_tensor):
        """Upscale a batch of images.

        Args:
            image_tensor: float tensor ``(N, 3, H, W)``, presumably in [0, 1]
                (as produced by ``ToTensor``); any device.

        Returns:
            CPU float tensor ``(N, 3, 4H, 4W)`` clamped to [0, 1].
        """
        with torch.no_grad():
            # The input must live on the same device as the weights.
            sr_image_tensor = self.model(image_tensor.to(self.device)).clamp(0, 1)
        return sr_image_tensor.cpu()
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the upscaler once at import time; every call to ``upscale_image``
# reuses this instance.
model = RealESRGAN(device)
# Define the image upscaling function
def upscale_image(image):
    """Upscale a PIL image with the module-level Real-ESRGAN ``model``.

    Args:
        image: input ``PIL.Image`` in any mode (converted to RGB first), or
            ``None`` when the user submitted without uploading anything.

    Returns:
        The upscaled RGB ``PIL.Image``, or ``None`` if *image* is ``None``.
    """
    if image is None:
        return None

    # ToTensor yields a float CHW tensor in [0, 1]; the model expects a batch.
    image_tensor = transforms.ToTensor()(image.convert("RGB")).unsqueeze(0)

    sr_image_tensor = model.upscale(image_tensor)

    # Back to an HWC uint8 array. ``.cpu()`` is required before ``.numpy()``
    # for CUDA tensors, and values are rounded to the nearest level rather than
    # truncated. Using torch ops here also avoids needing numpy in scope.
    sr_array = (
        sr_image_tensor.squeeze(0)
        .cpu()
        .mul(255)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
        .numpy()
    )
    return Image.fromarray(sr_array)
# Create a Gradio interface: one PIL image in, one PIL image out.
# ``gr.Image`` replaces the ``gr.inputs`` / ``gr.outputs`` namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4.
interface = gr.Interface(
    fn=upscale_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats.",
)
def _main():
    """Launch the Gradio ``interface`` defined above."""
    interface.launch()


# Only serve the UI when run as a script, not when imported.
if __name__ == "__main__":
    _main()