sayakpaul HF staff commited on
Commit
399866a
1 Parent(s): 66747c6

Create new file

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub.keras_mixin import from_pretrained_keras
2
+
3
+ from PIL import Image
4
+
5
+ import numpy as np
6
+
7
+ from create_maxim_model import Model
8
+ from maxim.configs import MAXIM_CONFIGS
9
+
10
+
11
+ _MODEL = from_pretrained_keras("sayakpaul/S-2_enhancement_lol")
12
+
13
+
14
+ def mod_padding_symmetric(image, factor=64):
15
+ """Padding the image to be divided by factor."""
16
+ height, width = image.shape[0], image.shape[1]
17
+ height_pad, width_pad = ((height + factor) // factor) * factor, (
18
+ (width + factor) // factor
19
+ ) * factor
20
+ padh = height_pad - height if height % factor != 0 else 0
21
+ padw = width_pad - width if width % factor != 0 else 0
22
+ image = tf.pad(
23
+ image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)], mode="REFLECT"
24
+ )
25
+ return image
26
+
27
+ def _convert_input_type_range(img):
28
+ """Convert the type and range of the input image.
29
+
30
+ It converts the input image to np.float32 type and range of [0, 1].
31
+ It is mainly used for pre-processing the input image in colorspace
32
+ convertion functions such as rgb2ycbcr and ycbcr2rgb.
33
+ Args:
34
+ img (ndarray): The input image. It accepts:
35
+ 1. np.uint8 type with range [0, 255];
36
+ 2. np.float32 type with range [0, 1].
37
+ Returns:
38
+ (ndarray): The converted image with type of np.float32 and range of
39
+ [0, 1].
40
+ """
41
+ img_type = img.dtype
42
+ img = img.astype(np.float32)
43
+ if img_type == np.float32:
44
+ pass
45
+ elif img_type == np.uint8:
46
+ img /= 255.0
47
+ else:
48
+ raise TypeError(
49
+ "The img type should be np.float32 or np.uint8, " f"but got {img_type}"
50
+ )
51
+ return img
52
+
53
+
54
+ def _convert_output_type_range(img, dst_type):
55
+ """Convert the type and range of the image according to dst_type.
56
+
57
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
58
+ images will be converted to np.uint8 type with range [0, 255]. If
59
+ `dst_type` is np.float32, it converts the image to np.float32 type with
60
+ range [0, 1].
61
+ It is mainly used for post-processing images in colorspace convertion
62
+ functions such as rgb2ycbcr and ycbcr2rgb.
63
+ Args:
64
+ img (ndarray): The image to be converted with np.float32 type and
65
+ range [0, 255].
66
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
67
+ converts the image to np.uint8 type with range [0, 255]. If
68
+ dst_type is np.float32, it converts the image to np.float32 type
69
+ with range [0, 1].
70
+ Returns:
71
+ (ndarray): The converted image with desired type and range.
72
+ """
73
+ if dst_type not in (np.uint8, np.float32):
74
+ raise TypeError(
75
+ "The dst_type should be np.float32 or np.uint8, " f"but got {dst_type}"
76
+ )
77
+ if dst_type == np.uint8:
78
+ img = img.round()
79
+ else:
80
+ img /= 255.0
81
+
82
+ return img.astype(dst_type)
83
+
84
+
85
+ def make_shape_even(image):
86
+ """Pad the image to have even shapes."""
87
+ height, width = image.shape[0], image.shape[1]
88
+ padh = 1 if height % 2 != 0 else 0
89
+ padw = 1 if width % 2 != 0 else 0
90
+ image = tf.pad(image, [(0, padh), (0, padw), (0, 0)], mode="REFLECT")
91
+ return image
92
+
93
+
94
+ def process_image(image: Image):
95
+ input_img = np.asarray(image) / 255.0
96
+ height, width = input_img.shape[0], input_img.shape[1]
97
+
98
+ # Padding images to have even shapes
99
+ input_img = make_shape_even(input_img)
100
+ height_even, width_even = input_img.shape[0], input_img.shape[1]
101
+
102
+ # padding images to be multiplies of 64
103
+ input_img = mod_padding_symmetric(input_img, factor=64)
104
+ input_img = tf.expand_dims(input_img, axis=0)
105
+ return input_img, height_even, width_even
106
+
107
+
108
+ def init_new_model(input_img):
109
+ configs = MAXIM_CONFIGS.get("S-2")
110
+ configs.update(
111
+ {
112
+ "variant": "S-2",
113
+ "dropout_rate": 0.0,
114
+ "num_outputs": 3,
115
+ "use_bias": True,
116
+ "num_supervision_scales": 3,
117
+ }
118
+ )
119
+ configs.update({"input_resolution": (input_img.shape[1], input_img.shape[2])})
120
+ new_model = Model(**configs)
121
+ new_model.set_weights(_MODEL.get_weights())
122
+ return new_model
123
+
124
+
125
+ def infer(image):
126
+ preprocessed_image, height_even, width_even = process_image(image)
127
+ new_model = init_new_model(preprocessed_image)
128
+
129
+ preds = new_model.predict(preprocessed_image)
130
+ if isinstance(preds, list):
131
+ preds = preds[-1]
132
+ if isinstance(preds, list):
133
+ preds = preds[-1]
134
+
135
+ preds = np.array(preds[0], np.float32)
136
+
137
+ new_height, new_width = preds.shape[0], preds.shape[1]
138
+ h_start = new_height // 2 - height_even // 2
139
+ h_end = h_start + height
140
+ w_start = new_width // 2 - width_even // 2
141
+ w_end = w_start + width
142
+ preds = preds[h_start:h_end, w_start:w_end, :]
143
+
144
+ return Image.fromarray(np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)))