Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import os | |
import requests | |
import numpy as np | |
# Remote location of the pretrained weights and the local file they are cached in.
MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
MODEL_PATH = "RealESRGAN_x4_Anime6B.pth"


def _download_model(url, destination, timeout=(10, 60), chunk_size=1 << 20):
    """Stream ``url`` to ``destination`` without leaving a truncated file behind.

    The payload is written to ``destination + ".part"`` and only renamed onto
    ``destination`` once the transfer has completed, so an interrupted run can
    never produce a half-written checkpoint that a later
    ``os.path.exists(destination)`` check would mistake for a valid one.

    Args:
        url: HTTP(S) address of the file to fetch.
        destination: Local path the finished download is stored at.
        timeout: ``(connect, read)`` timeout in seconds handed to ``requests``
            so a stalled connection fails instead of hanging forever.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        Exception: If the server answers with a status code other than 200.
    """
    partial_path = destination + ".part"
    try:
        # stream=True keeps the checkpoint out of memory; it is written chunk by chunk.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise Exception(f"Could not download model: {response.status_code}")
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic on the same filesystem: the final name appears only when complete.
        os.replace(partial_path, destination)
    finally:
        # After a successful rename the partial file no longer exists, so this
        # only removes leftovers from a failed or interrupted transfer.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    _download_model(MODEL_URL, MODEL_PATH)
# Real-ESRGAN generator (RRDBNet) and the wrapper class that loads and runs it.
# The submodule/attribute names below follow the RRDBNet layout used by the
# upstream Real-ESRGAN / BasicSR code so that checkpoint keys line up with
# this module's state dict; load_state_dict(strict=True) fails loudly if not.
class _ResidualDenseBlock(torch.nn.Module):
    """Dense block of five 3x3 convolutions with a scaled residual connection."""

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        conv = torch.nn.Conv2d
        self.conv1 = conv(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = conv(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = conv(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = conv(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = conv(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Each convolution sees the block input plus every earlier feature map.
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class _RRDB(torch.nn.Module):
    """Residual-in-residual dense block: three dense blocks plus a skip connection."""

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        self.rdb1 = _ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = _ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = _ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x


class _RRDBNet(torch.nn.Module):
    """x4 RRDBNet generator; defaults are the 6-block "anime" configuration."""

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32):
        super().__init__()
        conv = torch.nn.Conv2d
        self.conv_first = conv(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.nn.Sequential(
            *(_RRDB(num_feat, num_grow_ch) for _ in range(num_block))
        )
        self.conv_body = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = conv(num_feat, num_feat, 3, 1, 1)
        self.conv_last = conv(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        interpolate = torch.nn.functional.interpolate
        feat = self.conv_first(x)
        feat = feat + self.conv_body(self.body(feat))
        # Two nearest-neighbour x2 upsampling stages give the overall x4 scale.
        feat = self.lrelu(self.conv_up1(interpolate(feat, scale_factor=2, mode="nearest")))
        feat = self.lrelu(self.conv_up2(interpolate(feat, scale_factor=2, mode="nearest")))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))


class RealESRGAN:
    """Loads a Real-ESRGAN x4 checkpoint and upscales image tensors with it.

    The original implementation called ``torch.hub.load`` with a
    ``real_esrgan`` entry point; the network is instead built locally and the
    weights are read directly from the ``.pth`` file at ``model_path``.

    Attributes:
        model: The generator network, in eval mode, on ``device``.
        device: Device the network lives on and inputs are moved to.
    """

    def __init__(self, model_path, device):
        """Build the network and load its weights.

        Args:
            model_path: Path to a checkpoint holding the generator weights,
                either as a bare state dict or under a ``"params_ema"`` /
                ``"params"`` key.
            device: Torch device (or device string) to run inference on.

        Raises:
            RuntimeError: If the checkpoint keys or shapes do not match the
                network (raised by ``load_state_dict``).
        """
        self.device = device  # Store the device

        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            # Prefer the EMA weights when both sets are present.
            for key in ("params_ema", "params"):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break

        network = self._build_network()
        network.load_state_dict(state_dict, strict=True)
        self.model = network.to(device)
        self.model.eval()  # Set the model to eval mode

    @staticmethod
    def _build_network():
        """Return an untrained generator with the 6-block x4 anime layout."""
        return _RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32)

    def upscale(self, image_tensor):
        """Upscale a batch of images by a factor of four.

        Args:
            image_tensor: Float tensor shaped ``(N, 3, H, W)``; it is moved to
                ``self.device`` before inference.

        Returns:
            Tensor shaped ``(N, 3, 4*H, 4*W)`` with values clamped to ``[0, 1]``.
        """
        with torch.no_grad():
            sr_image_tensor = self.model(image_tensor.to(self.device)).clamp(0, 1)
        return sr_image_tensor
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Single upscaler instance created at import time and shared by every request.
model = RealESRGAN(MODEL_PATH, device)
# Define the image upscaling function
def upscale_image(image):
    """Upscale a PIL image with the module-level Real-ESRGAN model.

    Args:
        image: Input ``PIL.Image.Image`` in any mode, or ``None`` when the
            request was submitted without an image.

    Returns:
        The upscaled RGB ``PIL.Image.Image``, or ``None`` if ``image`` is
        ``None`` (instead of failing with ``AttributeError``).
    """
    if image is None:
        return None

    # Pre-process: force 3-channel RGB, convert to a float tensor via
    # ToTensor, add the batch dimension and move to the inference device.
    rgb_image = image.convert("RGB")
    image_tensor = transforms.ToTensor()(rgb_image).unsqueeze(0).to(device)

    # Perform upscaling
    sr_image_tensor = model.upscale(image_tensor)

    # Post-process: drop the batch dimension and scale to 0..255. Round to the
    # nearest integer before the uint8 cast; a bare astype() truncates, which
    # darkens the result by up to one level (e.g. 254.7 would become 254).
    sr_array = sr_image_tensor.squeeze(0).cpu().numpy() * 255.0
    sr_array = np.clip(np.rint(sr_array), 0, 255).astype(np.uint8)
    return Image.fromarray(sr_array.transpose(1, 2, 0))  # Change to HWC format
# Gradio front end: one PIL image in, one PIL image out, both routed through
# upscale_image.
interface = gr.Interface(
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats.",
    fn=upscale_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()