File size: 1,863 Bytes
bc5e1ec
 
56d2483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import urllib.request

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

# Download the required model checkpoint if it doesn't exist.
# The URL and the local path are both derived from one filename so the
# existence check below matches the file that is actually downloaded
# (previously the local name contained a space and never matched, forcing
# a re-download on every start).
MODEL_FILENAME = "RealESRGAN_x4plus_anime_6B.pth"
MODEL_URL = f"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/{MODEL_FILENAME}"
MODEL_PATH = os.path.join("weights", MODEL_FILENAME)
if not os.path.exists(MODEL_PATH):
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated file that would pass the
    # existence check on the next start. urlretrieve raises on failure instead
    # of silently continuing like the old `os.system("wget ...")` call did.
    _partial_model_path = MODEL_PATH + ".part"
    urllib.request.urlretrieve(MODEL_URL, _partial_model_path)
    os.replace(_partial_model_path, MODEL_PATH)

# Define the Real-ESRGAN class using the pre-trained model path
class RealESRGAN:
    """Thin wrapper around a Real-ESRGAN model obtained through ``torch.hub``."""

    def __init__(self, device):
        """Load the model onto *device* and put it in evaluation mode.

        Args:
            device: Torch device spec (e.g. ``"cuda"`` or ``"cpu"``). Inputs
                passed to :meth:`upscale` are moved onto this device.
        """
        # Remember the target device so upscale() can move inputs onto it;
        # feeding a CPU tensor to a CUDA model would otherwise raise.
        self.device = device
        # NOTE(review): the hub repo/entrypoint ("xinntao/Real-ESRGAN",
        # "restoration") and the bare checkpoint filename are kept as-is;
        # confirm that repo actually exposes this entrypoint via hubconf.py.
        self.model = torch.hub.load("xinntao/Real-ESRGAN", "restoration", model="RealESRGAN_x4_Anime6B.pth", device=device)
        self.model.eval()

    def upscale(self, image_tensor):
        """Run the model on an image batch without tracking gradients.

        Args:
            image_tensor: Input batch; presumably (N, 3, H, W) floats in
                [0, 1] as produced by ``ToTensor().unsqueeze(0)`` — TODO
                confirm for other callers. It is moved to ``self.device``
                before inference.

        Returns:
            The model output with every value clamped to [0, 1].
        """
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            sr_image_tensor = self.model(image_tensor).clamp(0, 1)
        return sr_image_tensor

# Run on the GPU whenever PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One shared model instance, used by upscale_image() for every call.
model = RealESRGAN(device)

# Define the image upscaling function
def upscale_image(image):
    """Upscale a PIL image with the module-level Real-ESRGAN ``model``.

    Args:
        image: Input ``PIL.Image.Image`` (any mode; converted to RGB), or
            ``None`` when nothing was submitted.

    Returns:
        The upscaled RGB ``PIL.Image.Image``, or ``None`` if ``image`` is
        ``None``.
    """
    if image is None:
        # Guard against an empty submission (e.g. the UI submitting with no
        # image uploaded) instead of crashing on ``None.convert``.
        return None

    # Pre-process: force 3 channels, convert to a float tensor in [0, 1],
    # add a batch dimension, and place it on the same device as the model.
    image = image.convert("RGB")
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

    # Perform upscaling
    sr_image_tensor = model.upscale(image_tensor)

    # Post-process back to an 8-bit HWC array for PIL:
    #  - clamp so the uint8 cast can never wrap around,
    #  - round rather than truncate (truncation darkens every pixel slightly),
    #  - move to CPU, since .numpy() fails on CUDA tensors,
    #  - permute CHW -> HWC, the layout Image.fromarray expects.
    sr_image = (
        sr_image_tensor.squeeze(0)
        .clamp(0, 1)
        .mul(255)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )
    return Image.fromarray(sr_image)

# Create a Gradio interface
interface = gr.Interface(
    fn=upscale_image,
    # gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in
    # Gradio 4; the top-level gr.Image component takes the same `type` option.
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Upscaler with R-ESRGAN Anime 6B",
    description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats."
)

# Launch the interface only when this file is executed as a script.
def _launch():
    """Start the Gradio server for the module-level ``interface``."""
    interface.launch()


if __name__ == "__main__":
    _launch()