# app.py — last commit 4cce72f ("Update app.py", verified) by NickKolok; 2.24 kB.
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
import requests
import numpy as np
# Where to fetch the Real-ESRGAN anime weights from, and where to cache them locally.
MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
MODEL_PATH = "RealESRGAN_x4_Anime6B.pth"
# Seconds `requests` waits to connect and between received bytes before giving up,
# so a stalled server cannot hang application startup forever.
DOWNLOAD_TIMEOUT = 60


def download_model(url, path, timeout=DOWNLOAD_TIMEOUT):
    """Download the file at *url* and store it at *path*.

    The response is streamed in chunks (instead of buffered whole in memory)
    into a temporary ``<path>.part`` file, which is renamed onto *path* only
    after every byte has been written. An interrupted download therefore never
    leaves a truncated file at *path* that a later existence check would
    mistake for a valid model.

    Args:
        url: HTTP(S) URL of the file to fetch.
        path: Destination file path.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.

    Raises:
        RuntimeError: if the server answers with a status other than 200.
        requests.RequestException: on network errors or timeouts.
    """
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Could not download model: {response.status_code}")
        try:
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        except BaseException:
            # Don't leave a half-written temp file behind, then propagate.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Atomic on the same filesystem: *path* is either absent or complete.
    os.replace(tmp_path, path)


# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    download_model(MODEL_URL, MODEL_PATH)
# The Real-ESRGAN GitHub repo ships no torch.hub entrypoint, and the released
# .pth files hold only a state dict. The generator architecture is therefore
# defined here. Layer names mirror BasicSR's RRDBNet so the released
# checkpoints load with strict=True.


class ResidualDenseBlock(torch.nn.Module):
    """Five densely connected 3x3 convolutions with a 0.2-scaled residual."""

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # Each conv sees the block input plus every previous conv's output.
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):
    """Three chained residual dense blocks wrapped in a 0.2-scaled residual."""

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):
    """x4 RRDB generator: features -> RRDB trunk -> two nearest-neighbour x2 upsamples.

    Defaults match the RealESRGAN_x4plus_anime_6B release
    (num_feat=64, num_block=6, num_grow_ch=32).
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32):
        super().__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        # nn.Sequential keeps the "body.<i>.rdbN..." checkpoint key layout.
        self.body = torch.nn.Sequential(*(RRDB(num_feat, num_grow_ch) for _ in range(num_block)))
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        interpolate = torch.nn.functional.interpolate
        feat = self.conv_first(x)
        feat = feat + self.conv_body(self.body(feat))
        feat = self.lrelu(self.conv_up1(interpolate(feat, scale_factor=2, mode="nearest")))
        feat = self.lrelu(self.conv_up2(interpolate(feat, scale_factor=2, mode="nearest")))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))


class RealESRGAN:
    """Real-ESRGAN x4 upscaler loaded from a local checkpoint file.

    Args:
        model_path: Path to a Real-ESRGAN/BasicSR ``.pth`` checkpoint.
        device: Torch device string (e.g. ``"cpu"`` or ``"cuda"``) to run on.
        num_block: Number of RRDB blocks in the checkpoint's generator
            (6 for the "anime_6B" release).
    """

    def __init__(self, model_path, device, num_block=6):
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints; on torch >= 1.13 consider passing weights_only=True.
        checkpoint = torch.load(model_path, map_location="cpu")
        # Real-ESRGAN releases keep the weights under "params_ema" (or "params");
        # fall back to treating the file as a bare state dict.
        if "params_ema" in checkpoint:
            state_dict = checkpoint["params_ema"]
        elif "params" in checkpoint:
            state_dict = checkpoint["params"]
        else:
            state_dict = checkpoint
        self.model = RRDBNet(num_block=num_block)
        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()  # Set the model to eval mode
        self.model.to(device)
        self.device = device  # Store the device

    def upscale(self, image_tensor):
        """Run the generator without autograd and clamp the output to [0, 1].

        Assumes *image_tensor* is (N, 3, H, W) float in [0, 1] on
        ``self.device``; returns a tensor of shape (N, 3, 4H, 4W).
        """
        with torch.no_grad():
            sr_image_tensor = self.model(image_tensor).clamp(0, 1)
        return sr_image_tensor
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the upscaler once at import time so every request reuses the same weights.
model = RealESRGAN(MODEL_PATH, device)
# Define the image upscaling function
def upscale_image(image):
    """Upscale a PIL image with the module-level Real-ESRGAN ``model``.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        The upscaled image as an RGB PIL image, or ``None`` if *image* is ``None``.
    """
    # Gradio passes None for an empty input; show an empty output instead of crashing.
    if image is None:
        return None
    # Pre-process: force 3-channel RGB, convert to a float CHW tensor in [0, 1],
    # add the batch dimension and move it to the model's device.
    image = image.convert("RGB")
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
    # Perform upscaling
    sr_image_tensor = model.upscale(image_tensor)
    # Post-process: drop the batch dim and go CHW -> HWC. Round when mapping
    # [0, 1] back to [0, 255]; a bare astype(uint8) truncates and darkens
    # every pixel slightly.
    hwc = sr_image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    pixels = np.clip(np.round(hwc * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
# Gradio UI: one PIL image in, the upscaled PIL image out.
interface = gr.Interface(
    upscale_image,          # fn
    gr.Image(type="pil"),   # inputs
    gr.Image(type="pil"),   # outputs
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description=(
        "Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. "
        "Supports inputs in various formats."
    ),
)


def _main():
    """Start the Gradio server for the upscaler UI."""
    interface.launch()


# Only serve when executed as a script, not when imported.
if __name__ == "__main__":
    _main()