File size: 3,695 Bytes
3f9d71f
 
 
 
 
 
 
399866a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f9d71f
 
399866a
 
 
 
 
 
 
 
 
 
 
3f9d71f
 
 
399866a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f9d71f
 
399866a
3f9d71f
399866a
3f9d71f
399866a
 
 
 
 
3f9d71f
399866a
3f9d71f
399866a
 
 
 
 
 
3f9d71f
 
 
 
 
d1ed60e
3f9d71f
 
 
 
 
 
6b6047c
3f9d71f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Some preprocessing utilities have been taken from:
https://github.com/google-research/maxim/blob/main/maxim/run_eval.py
"""
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub.keras_mixin import from_pretrained_keras
from PIL import Image

from create_maxim_model import Model
from maxim.configs import MAXIM_CONFIGS

# Pretrained Keras model loaded once at import time via `huggingface_hub`.
# `init_new_model` reads its weights to populate a freshly constructed model.
_MODEL = from_pretrained_keras("sayakpaul/S-2_enhancement_lol")


def mod_padding_symmetric(image, factor=64):
    """Reflect-pad an image so its height and width are multiples of `factor`.

    The required padding is split between the two sides of each spatial axis.
    When the amount is odd, the extra row/column goes to the bottom/right so
    that the padded size is always an exact multiple of `factor`.

    Args:
        image: Rank-3 array or tensor. The first two axes are treated as
            height and width; the last axis is never padded.
        factor: Positive integer that the padded height and width must be
            divisible by.

    Returns:
        The padded image as produced by `tf.pad`.

    Note:
        `tf.pad` in "REFLECT" mode rejects a padding amount that is not
        smaller than the size of the axis being padded, so very small images
        (relative to `factor`) will raise an error here.
    """
    height, width = image.shape[0], image.shape[1]

    # Distance to the next multiple of `factor`; 0 when already aligned.
    padh = (-height) % factor
    padw = (-width) % factor

    # Give any odd remainder to the trailing side instead of dropping it.
    top, left = padh // 2, padw // 2
    paddings = [(top, padh - top), (left, padw - left), (0, 0)]
    return tf.pad(image, paddings, mode="REFLECT")


def make_shape_even(image):
    """Reflect-pad by one row and/or column so height and width are even.

    Padding is only ever added at the end (bottom/right) of the first two
    axes; the third axis is left untouched.
    """
    rows = image.shape[0]
    cols = image.shape[1]
    # `n % 2` is 1 exactly when `n` is odd, which is the padding needed.
    extra_rows = rows % 2
    extra_cols = cols % 2
    paddings = [(0, extra_rows), (0, extra_cols), (0, 0)]
    return tf.pad(image, paddings, mode="REFLECT")


def process_image(image: Image.Image):
    """Turn an input image into a padded, batched tensor for the model.

    Args:
        image: The image to enhance. Anything `np.asarray` can convert works
            (a PIL image or an array). The padding helpers pad exactly three
            axes, so the data is expected to be laid out as
            (height, width, channels).

    Returns:
        A tuple `(input_img, height, width, height_even, width_even)` where
        `input_img` is the padded image with a leading batch axis,
        `height`/`width` are the original spatial sizes, and
        `height_even`/`width_even` are the sizes after padding to even
        shapes (before padding to a multiple of 64).
    """
    # Scale to [0, 1] in float32 rather than NumPy's default float64
    # promotion, which would double the memory of every intermediate tensor.
    input_img = np.asarray(image, np.float32) / 255.0
    height, width = input_img.shape[0], input_img.shape[1]

    # Pad the image to have even shapes.
    input_img = make_shape_even(input_img)
    height_even, width_even = input_img.shape[0], input_img.shape[1]

    # Pad the image so both spatial sizes are multiples of 64.
    input_img = mod_padding_symmetric(input_img, factor=64)
    input_img = tf.expand_dims(input_img, axis=0)
    return input_img, height, width, height_even, width_even


def init_new_model(input_img):
    """Build a model sized for `input_img` and load the pretrained weights.

    Args:
        input_img: Batched input whose axes 1 and 2 give the spatial
            resolution the new model is configured for.

    Returns:
        A new `Model` instance configured for the "S-2" variant at the
        resolution of `input_img`, with weights copied from `_MODEL`.
    """
    # Work on a copy: updating the dict stored in MAXIM_CONFIGS would leak
    # per-request settings (such as the resolution) into the shared config.
    configs = dict(MAXIM_CONFIGS["S-2"])
    configs.update(
        {
            "variant": "S-2",
            "dropout_rate": 0.0,
            "num_outputs": 3,
            "use_bias": True,
            "num_supervision_scales": 3,
            "input_resolution": (input_img.shape[1], input_img.shape[2]),
        }
    )
    new_model = Model(**configs)
    new_model.set_weights(_MODEL.get_weights())
    return new_model


def infer(image):
    """Run the enhancement model on `image` and return the result as a PIL image.

    The input is padded by `process_image`, passed through a model built for
    the padded resolution, and the prediction is cropped back to the original
    height and width before being converted to 8-bit pixels.
    """
    batch, orig_h, orig_w, even_h, even_w = process_image(image)
    model = init_new_model(batch)

    outputs = model.predict(batch)
    # The prediction may be a list (possibly of lists); keep the last entry,
    # descending at most two levels.
    for _ in range(2):
        if not isinstance(outputs, list):
            break
        outputs = outputs[-1]

    # Drop the batch axis.
    restored = np.array(outputs[0], np.float32)

    # Locate the even-shaped image inside the padded output via the centre
    # offset, then keep only the original height/width.
    padded_h, padded_w = restored.shape[0], restored.shape[1]
    top = padded_h // 2 - even_h // 2
    left = padded_w // 2 - even_w // 2
    restored = restored[top : top + orig_h, left : left + orig_w, :]

    pixels = (np.clip(restored, 0.0, 1.0) * 255.0).astype(np.uint8)
    return Image.fromarray(pixels)


# Text shown on the demo page.
title = "Enhance low-light images."
description = (
    "The underlying model is [this](https://huggingface.co/sayakpaul/S-2_enhancement_lol). "
    "You can use the model to enhance low-light images, which may be useful for aiding vision-impaired people. "
    "To quickly try out the model, you can choose from the available sample images below, or you can submit your own image. "
    "Note that, internally, the model is re-initialized based on the spatial dimensions of the input image and this process is time-consuming."
)

# Image-in / image-out interface around `infer`, with bundled example inputs
# and flagging disabled.
iface = gr.Interface(
    infer,
    inputs="image",
    outputs="image",
    title=title,
    description=description,
    allow_flagging="never",
    examples=[["1.png"], ["111.png"], ["748.png"], ["a4541-DSC_0040-2.png"]],
)
iface.launch(debug=True)