Spaces:
Runtime error
Runtime error
File size: 2,243 Bytes
bc5e1ec 56d2483 c08ae31 ea89a8b 56d2483 c08ae31 55425f6 c08ae31 56d2483 c08ae31 56d2483 c08ae31 4cce72f 56d2483 c08ae31 56d2483 4cce72f 56d2483 ea89a8b 56d2483 ea89a8b 56d2483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
import requests
import numpy as np
# Define the model URL and local path
MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
MODEL_PATH = "RealESRGAN_x4_Anime6B.pth"


def _download_model(url, path, timeout=60):
    """Download *url* to *path* unless *path* already exists.

    The response body is streamed in chunks into a temporary ``<path>.part``
    file, which is renamed onto *path* only after the whole transfer
    succeeded. An interrupted or failed download therefore never leaves a
    truncated checkpoint that a later run's ``os.path.exists`` check would
    mistake for a complete one.

    Args:
        url: Remote location of the checkpoint.
        path: Local destination file.
        timeout: Seconds to wait for the connection and between received
            bytes; without it a stalled server would block startup forever.

    Raises:
        RuntimeError: If the server answers with a status other than 200.
        requests.RequestException: On network errors or timeouts.
    """
    if os.path.exists(path):
        return
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Could not download model: {response.status_code}")
        try:
            with open(tmp_path, "wb") as f:
                # Stream in 1 MiB chunks instead of holding the whole file in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            # Atomic rename: *path* either doesn't exist or is complete.
            os.replace(tmp_path, path)
        finally:
            # Only still present if something failed before the rename.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


# Download the model if it doesn't exist
_download_model(MODEL_URL, MODEL_PATH)
# Define the Real-ESRGAN network and wrapper class.
#
# The xinntao/Real-ESRGAN repository does not publish a ``real_esrgan``
# torch.hub entrypoint, so ``torch.hub.load`` cannot build the model (and it
# would never have used the locally downloaded weights anyway). Instead the
# RRDBNet generator that Real-ESRGAN uses is defined here in plain PyTorch
# and the checkpoint is loaded into it directly.


class ResidualDenseBlock(torch.nn.Module):
    """Five densely connected 3x3 convolutions with a 0.2-scaled residual."""

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        conv = torch.nn.Conv2d
        self.conv1 = conv(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = conv(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = conv(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = conv(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = conv(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):
    """Residual-in-residual dense block: three dense blocks plus a scaled skip."""

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):
    """RRDBNet generator with a fixed 4x upsampling head (two nearest-2x stages).

    Attribute names mirror the upstream implementation so that Real-ESRGAN
    state-dict keys (``conv_first``, ``body.<i>.rdb<j>.conv<k>``, ``conv_body``,
    ``conv_up1``, ``conv_up2``, ``conv_hr``, ``conv_last``) load unchanged.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32):
        super().__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)])
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        interpolate = torch.nn.functional.interpolate
        feat = self.conv_first(x)
        feat = feat + self.conv_body(self.body(feat))
        feat = self.lrelu(self.conv_up1(interpolate(feat, scale_factor=2, mode="nearest")))
        feat = self.lrelu(self.conv_up2(interpolate(feat, scale_factor=2, mode="nearest")))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))


class RealESRGAN:
    """Load Real-ESRGAN RRDBNet weights from disk and run inference with them.

    Attributes:
        model: The ``RRDBNet`` in eval mode, on ``device``.
        device: The device string/object the model was moved to.
    """

    def __init__(self, model_path, device, num_block=6):
        """Build the network and load the checkpoint at *model_path*.

        Args:
            model_path: Path to a ``.pth`` checkpoint. Either a dict holding
                the weights under ``"params_ema"`` or ``"params"`` (the layout
                of the official Real-ESRGAN releases), or a bare state dict.
            device: Target device, e.g. ``"cuda"`` or ``"cpu"``.
            num_block: Number of RRDB blocks; 6 matches the anime-6B model.

        Raises:
            RuntimeError: If the checkpoint keys/shapes don't match the network.
        """
        self.device = device  # Store the device
        network = RRDBNet(num_block=num_block)
        # Load onto the CPU first so a CUDA-saved checkpoint also works on CPU.
        checkpoint = torch.load(model_path, map_location="cpu")
        if isinstance(checkpoint, dict):
            # Upstream prefers the EMA weights when both are present.
            for key in ("params_ema", "params"):
                if key in checkpoint:
                    checkpoint = checkpoint[key]
                    break
        network.load_state_dict(checkpoint, strict=True)
        self.model = network.to(device).eval()  # Set the model to eval mode

    def upscale(self, image_tensor):
        """Run the network on *image_tensor* and clamp the result to [0, 1].

        Args:
            image_tensor: Float tensor of shape (N, 3, H, W) on ``self.device``;
                values are expected in [0, 1] (assumption — matches the clamp).

        Returns:
            Tensor of shape (N, 3, 4H, 4W) clamped to [0, 1].
        """
        with torch.no_grad():
            sr_image_tensor = self.model(image_tensor).clamp(0, 1)
        return sr_image_tensor
# Run on the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the upscaler once at import time; the request handler reuses this
# module-level instance for every call.
model = RealESRGAN(MODEL_PATH, device)
# Define the image upscaling function
def upscale_image(image):
    """Upscale a PIL image with the module-level Real-ESRGAN ``model``.

    Args:
        image: Input ``PIL.Image.Image`` from the Gradio component, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The upscaled RGB ``PIL.Image.Image``, or ``None`` if *image* is ``None``.
    """
    # Gradio passes None for an empty input; return an empty output instead
    # of crashing on ``None.convert``.
    if image is None:
        return None

    # ToTensor maps the 8-bit HWC image to a float CHW tensor in [0, 1];
    # unsqueeze adds the batch dimension before moving it to the device.
    image_tensor = transforms.ToTensor()(image.convert("RGB")).unsqueeze(0).to(device)

    # Perform upscaling
    sr_image_tensor = model.upscale(image_tensor)

    # Back to 8-bit HWC for PIL. Round (rather than truncate) when quantizing,
    # otherwise nearly every pixel ends up one level darker than it should.
    sr_array = (
        sr_image_tensor.squeeze(0)
        .mul(255)
        .round()
        .clamp(0, 255)
        .to(torch.uint8)
        .permute(1, 2, 0)  # CHW -> HWC
        .cpu()
        .numpy()
    )
    return Image.fromarray(sr_array)
# Gradio UI: one PIL image in, one PIL image out, matching what
# ``upscale_image`` accepts and returns.
_upscale_input = gr.Image(type="pil")
_upscale_output = gr.Image(type="pil")

interface = gr.Interface(
    fn=upscale_image,
    inputs=_upscale_input,
    outputs=_upscale_output,
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats.",
)
# Launch the interface
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    interface.launch()