File size: 1,965 Bytes
49f65e4
 
269eef7
49f65e4
 
 
 
a8c5eee
49f65e4
a8c5eee
49f65e4
a8c5eee
49f65e4
 
c8067b8
7f5a3aa
 
 
df80b79
 
 
70913b7
 
 
df80b79
e198db2
 
 
 
 
 
7f5a3aa
49f65e4
 
 
d911574
9c7d136
0f5662c
49f65e4
269eef7
49f65e4
269eef7
 
 
 
49f65e4
269eef7
49f65e4
 
 
 
 
376476a
f8bccd7
ae52322
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
import gradio as gr

# Use the GPU when CUDA reports itself available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_upscaler(scale, weights_path):
    """Create a RealESRGAN model for ``scale`` on ``device`` and load its weights.

    ``weights_path`` is handed to ``load_weights`` together with
    ``download=True``.
    """
    upscaler = RealESRGAN(device, scale=scale)
    upscaler.load_weights(weights_path, download=True)
    return upscaler


# One model per supported factor, built once when the module is imported.
model2 = _load_upscaler(2, 'weights/RealESRGAN_x2.pth')
model4 = _load_upscaler(4, 'weights/RealESRGAN_x4.pth')
model8 = _load_upscaler(8, 'weights/RealESRGAN_x8.pth')


def inference(image, size):
    """Upscale ``image`` with the model selected by ``size``.

    Args:
        image: Input image, or ``None`` when nothing was uploaded. It must
            expose ``.size`` as a ``(width, height)`` pair and
            ``.convert('RGB')`` (the UI supplies it with ``type="pil"``).
        size: Scale label. ``'2x'`` selects ``model2``, ``'4x'`` selects
            ``model4``; every other value falls through to ``model8``.

    Returns:
        Whatever the selected model's ``predict`` returns for the RGB image.

    Raises:
        gr.Error: If no image was supplied, or if either dimension is
            5000 pixels or more.
    """
    if image is None:
        raise gr.Error("Image not uploaded")

    img_w, img_h = image.size
    # Reject the image as soon as its larger side reaches the 5000 px limit.
    if max(img_w, img_h) >= 5000:
        raise gr.Error("The image is too large.")

    # Ask PyTorch to release cached CUDA memory before running the model.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Anything that is not '2x' or '4x' is served by the 8x model.
    upscaler = model2 if size == '2x' else model4 if size == '4x' else model8
    result = upscaler.predict(image.convert('RGB'))

    print(f"Image size ({device}): {size} ... OK")
    return result


# Static text passed to the Gradio interface below (title, description, article).
title = "Face Real ESRGAN UpScale: 2x 4x 8x"

# HTML is allowed here; the <br> puts the Telegram link on its own line.
description = (
    "This is an unofficial demo for Real-ESRGAN. "
    "Scales the resolution of a photo. "
    "This model shows better results on faces compared to the original version."
    "<br>Telegram BOT: https://t.me/restoration_photo_bot"
)

# Centered footer with author and model-card links. The wrapper must end with
# a closing </div>; the previous text ended with a second opening <div>,
# which left the element unclosed.
article = (
    "<div style='text-align: center;'>"
    "Twitter <a href='https://twitter.com/DoEvent' target='_blank'>Max Skobeev</a>"
    " | <a href='https://huggingface.co/sberbank-ai/Real-ESRGAN' target='_blank'>Model card</a>"
    "</div>"
)


# Input widgets: the uploaded picture (delivered as a PIL image) and the
# scale selector, which starts on '2x'.
source_image = gr.Image(type="pil")
scale_choice = gr.Radio(
    ['2x', '4x', '8x'],
    type="value",
    value='2x',
    label='Resolution model',
)

# Output widget showing the upscaled picture.
upscaled_image = gr.Image(type="pil", label="Output")

demo = gr.Interface(
    fn=inference,
    inputs=[source_image, scale_choice],
    outputs=upscaled_image,
    title=title,
    description=description,
    article=article,
    examples=[['groot.jpeg', "2x"]],
    allow_flagging='never',
    cache_examples=False,
)

# Enable the request queue with the API closed, then start the app with
# errors shown in the UI and the API page hidden.
demo.queue(api_open=False).launch(show_error=True, show_api=False)