NickKolok's picture
Initial - as done by ChatGPT
56d2483 verified
raw
history blame
1.86 kB
import os
import urllib.request

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
# Download the required model checkpoint if it doesn't exist.
# NOTE(review): the Real-ESRGAN v0.2.2.4 release publishes the anime 6B weights as
# "RealESRGAN_x4plus_anime_6B.pth"; confirm against the release page if it moves.
MODEL_URL = (
    "https://github.com/xinntao/Real-ESRGAN/releases/download/"
    "v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
)
# Derive the local path from the URL so the existence check always refers to the
# file that is actually downloaded (the previous hard-coded path contained a stray
# space, never matched, and forced a re-download on every start).
MODEL_PATH = os.path.join("weights", os.path.basename(MODEL_URL))
if not os.path.exists(MODEL_PATH):
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    # Download under a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated file that would pass the
    # existence check above on the next start. Download errors propagate.
    _partial_path = MODEL_PATH + ".part"
    try:
        urllib.request.urlretrieve(MODEL_URL, _partial_path)
        os.replace(_partial_path, MODEL_PATH)
    finally:
        if os.path.exists(_partial_path):
            os.remove(_partial_path)
# Wrapper around the Real-ESRGAN network used for upscaling.
class RealESRGAN:
    """Load the upscaling network once and run inference on image tensors.

    NOTE(review): the ``torch.hub.load("xinntao/Real-ESRGAN", "restoration", ...)``
    entry point and its ``model=`` / ``device=`` keywords are unverified, and this
    call does not read the locally downloaded ``MODEL_PATH`` checkpoint. Confirm the
    repository actually exposes this hub entry point before relying on it.
    """

    def __init__(self, device):
        """Load the network and switch it to inference mode.

        Args:
            device: torch device string (e.g. "cuda" or "cpu"); it is passed to the
                loader and also used to place input tensors in ``upscale``.
        """
        self.device = device
        self.model = torch.hub.load(
            "xinntao/Real-ESRGAN",
            "restoration",
            model="RealESRGAN_x4_Anime6B.pth",
            device=device,
        )
        self.model.eval()

    def upscale(self, image_tensor):
        """Run the network on ``image_tensor`` without tracking gradients.

        The input is moved to ``self.device`` first, so callers may pass a CPU
        tensor even when the model runs on CUDA.

        Args:
            image_tensor: batched input tensor for the network.

        Returns:
            The network output, on ``self.device``, clamped to [0, 1].
        """
        with torch.no_grad():
            sr_image_tensor = self.model(image_tensor.to(self.device)).clamp(0, 1)
        return sr_image_tensor
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Module-level upscaler instance used by upscale_image for every request.
model = RealESRGAN(device)
# Define the image upscaling function
def upscale_image(image):
    """Upscale a PIL image with the module-level Real-ESRGAN ``model``.

    Args:
        image: input PIL image, or None when the Gradio input is empty.

    Returns:
        The upscaled image as a PIL image, or None if ``image`` is None.
    """
    # Gradio can pass None when the user submits without uploading an image.
    if image is None:
        return None
    # Pre-process: force 3-channel RGB, convert to a float tensor, add batch dim.
    image_tensor = transforms.ToTensor()(image.convert("RGB")).unsqueeze(0)
    # Move the input to the model's device; a CUDA model cannot take a CPU tensor.
    sr_image_tensor = model.upscale(image_tensor.to(device))
    # Post-process: drop the batch dim, scale to [0, 255] with rounding, and
    # bring the data back to the CPU (``.numpy()`` only works on CPU tensors).
    # RealESRGAN.upscale clamps to [0, 1], so the uint8 cast cannot overflow.
    sr_image = (
        sr_image_tensor.squeeze(0)
        .mul(255)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)  # CHW -> HWC, the layout Image.fromarray expects
        .contiguous()
        .cpu()
        .numpy()
    )
    return Image.fromarray(sr_image)
# Create a Gradio interface
interface = gr.Interface(
    fn=upscale_image,
    # gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0;
    # the top-level gr.Image component serves as both input and output.
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats.",
)
# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()