File size: 2,243 Bytes
bc5e1ec
 
56d2483
 
 
c08ae31
ea89a8b
56d2483
c08ae31
55425f6
c08ae31
56d2483
c08ae31
 
 
 
 
 
 
 
 
 
56d2483
c08ae31
4cce72f
 
 
 
 
56d2483
 
 
 
 
 
c08ae31
56d2483
 
 
 
 
 
4cce72f
56d2483
 
 
 
 
ea89a8b
56d2483
 
 
 
 
 
ea89a8b
 
56d2483
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
import requests
import numpy as np

# Where the pretrained weights are published, and the local file they are
# cached in between runs.
MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
MODEL_PATH = "RealESRGAN_x4_Anime6B.pth"


def _download_model(url, destination, timeout=60, chunk_size=1 << 20):
    """Stream *url* into *destination* without leaving a partial file behind.

    The payload is written to ``<destination>.part`` and only renamed onto
    *destination* once the transfer has finished, so an interrupted download
    can never be mistaken for a complete checkpoint on the next start.

    Args:
        url: HTTP(S) address of the file to fetch.
        destination: Local path the finished file is stored at.
        timeout: Seconds to wait for the connection and for each read.
        chunk_size: Number of bytes written per iteration.

    Raises:
        RuntimeError: If the server answers with a status other than 200.
    """
    partial_path = f"{destination}.part"
    try:
        # stream=True avoids holding the whole checkpoint in memory.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Could not download model: {response.status_code}")
            with open(partial_path, "wb") as out_file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    out_file.write(chunk)
        # Atomic on the same filesystem: readers see the old state or the full file.
        os.replace(partial_path, destination)
    finally:
        # Only present if something failed before the rename above.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    _download_model(MODEL_URL, MODEL_PATH)

# Plain-PyTorch definition of the RRDBNet generator. Attribute names are chosen
# so that a state dict using the keys `conv_first`, `body.<i>.rdb<j>.conv<k>`,
# `conv_body`, `conv_up1`, `conv_up2`, `conv_hr` and `conv_last` loads with
# strict=True (presumably the layout of the released checkpoint; a mismatch
# raises at load time instead of silently producing garbage).
class _ResidualDenseBlock(torch.nn.Module):
    """Five densely connected 3x3 convolutions with a scaled residual."""

    def __init__(self, num_feat, num_grow_ch):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Each convolution sees the block input plus every earlier output.
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class _RRDB(torch.nn.Module):
    """Three residual dense blocks wrapped in another scaled residual."""

    def __init__(self, num_feat, num_grow_ch):
        super().__init__()
        self.rdb1 = _ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = _ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = _ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x


class _RRDBNet(torch.nn.Module):
    """RRDB trunk followed by two nearest-neighbour 2x upsampling stages (4x)."""

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32):
        super().__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.nn.Sequential(
            *[_RRDB(num_feat, num_grow_ch) for _ in range(num_block)]
        )
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        interpolate = torch.nn.functional.interpolate
        feat = self.conv_first(x)
        feat = feat + self.conv_body(self.body(feat))
        feat = self.lrelu(self.conv_up1(interpolate(feat, scale_factor=2, mode="nearest")))
        feat = self.lrelu(self.conv_up2(interpolate(feat, scale_factor=2, mode="nearest")))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))


# Define the Real-ESRGAN class
class RealESRGAN:
    """Inference wrapper around a 4x RRDBNet generator loaded from a local file.

    Attributes:
        model: The generator network, in eval mode, living on ``device``.
        device: The device the network was moved to.
    """

    def __init__(self, model_path, device):
        """Build the network and load its weights from *model_path*.

        The weights are read from the local checkpoint rather than through
        ``torch.hub.load``: the hub loader needs a ``hubconf.py`` entry point
        in the target repository, which xinntao/Real-ESRGAN does not appear
        to provide, and it would ignore the file that was already downloaded.

        Args:
            model_path: Path to the ``.pth`` checkpoint.
            device: Device (e.g. ``"cpu"`` or ``"cuda"``) to run inference on.

        Raises:
            RuntimeError: If the checkpoint keys or shapes do not match the
                network (raised by ``load_state_dict`` with ``strict=True``).
        """
        # Load onto the CPU first so a CUDA-saved checkpoint also works on
        # CPU-only hosts; weights_only=True refuses arbitrary pickled objects.
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
        # The weights may be wrapped under "params_ema" or "params"; a bare
        # state dict is accepted as-is.
        for wrapper_key in ("params_ema", "params"):
            if wrapper_key in checkpoint:
                checkpoint = checkpoint[wrapper_key]
                break

        network = self._build_network()
        network.load_state_dict(checkpoint, strict=True)
        self.model = network.to(device)
        self.model.eval()  # Set the model to eval mode
        self.device = device  # Store the device

    @staticmethod
    def _build_network():
        """Return an untrained generator: 6 RRDB blocks, 64 features, 32 growth channels."""
        return _RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32)

    def upscale(self, image_tensor):
        """Run the generator on *image_tensor* and clamp the result to [0, 1].

        Args:
            image_tensor: Input batch already placed on the model's device
                (assumed NCHW float in [0, 1], as produced by the caller).

        Returns:
            The generator output, clamped to [0, 1], computed without
            gradient tracking.
        """
        with torch.no_grad():
            sr_image_tensor = self.model(image_tensor).clamp(0, 1)
        return sr_image_tensor

# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Single shared upscaler, built once when the module is loaded.
model = RealESRGAN(MODEL_PATH, device)

# Define the image upscaling function
def upscale_image(image):
    """Upscale a PIL image with the module-level ``model``.

    Args:
        image: PIL image to upscale. ``None`` (no image supplied) is
            tolerated and yields ``None``.

    Returns:
        The upscaled picture as an RGB PIL image, or ``None`` when no input
        image was given.
    """
    # Nothing to do without an input; avoids an AttributeError on None.
    if image is None:
        return None

    # Pre-process: force three channels, scale to [0, 1] floats, add a batch
    # dimension and move the tensor to the inference device.
    rgb_image = image.convert("RGB")
    to_tensor = transforms.ToTensor()
    image_tensor = to_tensor(rgb_image).unsqueeze(0).to(device)

    # Perform upscaling
    sr_image_tensor = model.upscale(image_tensor)

    # Post-process: drop the batch dimension and quantize to 8 bits. Round to
    # the nearest level; a bare astype() truncates and biases pixels darker.
    sr_array = sr_image_tensor.squeeze(0).cpu().numpy()
    sr_array = np.rint(sr_array * 255).astype(np.uint8)
    return Image.fromarray(sr_array.transpose(1, 2, 0))  # CHW -> HWC for PIL

# Assemble the web UI: one image goes in, the upscaled image comes out.
image_input = gr.Image(type="pil")
image_output = gr.Image(type="pil")
interface = gr.Interface(
    fn=upscale_image,
    inputs=image_input,
    outputs=image_output,
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats.",
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()