add gradio app
Browse files- README.md +3 -0
- app.py +70 -0
- climategan/generator.py +5 -1
- climategan/masker.py +3 -3
- inferences.py +108 -0
README.md
CHANGED
@@ -10,6 +10,9 @@ title: ClimateGAN
|
|
10 |
emoji: π
|
11 |
colorFrom: blue
|
12 |
colorTo: green
|
|
|
|
|
|
|
13 |
# datasets:
|
14 |
# -
|
15 |
---
|
|
|
10 |
emoji: π
|
11 |
colorFrom: blue
|
12 |
colorTo: green
|
13 |
+
sdk: gradio
|
14 |
+
sdk_version: 4.6
|
15 |
+
app_file: app.py
|
16 |
# datasets:
|
17 |
# -
|
18 |
---
|
app.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/app.py # noqa: E501
|
2 |
+
# thank you @NimaBoscarino
|
3 |
+
|
4 |
+
import os
|
5 |
+
import gradio as gr
|
6 |
+
import googlemaps
|
7 |
+
from skimage import io
|
8 |
+
from urllib import parse
|
9 |
+
from inferences import ClimateGAN
|
10 |
+
|
11 |
+
|
12 |
+
def predict(api_key):
|
13 |
+
def _predict(*args):
|
14 |
+
print("args: ", args)
|
15 |
+
image = place = None
|
16 |
+
if len(args) == 1:
|
17 |
+
image = args[0]
|
18 |
+
else:
|
19 |
+
assert len(args) == 2, "Unknown number of inputs {}".format(len(args))
|
20 |
+
image, place = args
|
21 |
+
|
22 |
+
if api_key and place:
|
23 |
+
geocode_result = gmaps.geocode(place)
|
24 |
+
|
25 |
+
address = geocode_result[0]["formatted_address"]
|
26 |
+
static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
|
27 |
+
img_np = io.imread(static_map_url)
|
28 |
+
else:
|
29 |
+
img_np = image
|
30 |
+
flood, wildfire, smog = model.inference(img_np)
|
31 |
+
return img_np, flood, wildfire, smog
|
32 |
+
|
33 |
+
return _predict
|
34 |
+
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
|
38 |
+
api_key = os.environ.get("GMAPS_API_KEY")
|
39 |
+
gmaps = None
|
40 |
+
if api_key is not None:
|
41 |
+
gmaps = googlemaps.Client(key=api_key)
|
42 |
+
|
43 |
+
model = ClimateGAN(model_path="config/model/masker")
|
44 |
+
|
45 |
+
inputs = inputs = [gr.inputs.Image(label="Input Image")]
|
46 |
+
if api_key:
|
47 |
+
inputs += [gr.inputs.Textbox(label="Address or place name")]
|
48 |
+
|
49 |
+
gr.Interface(
|
50 |
+
predict(api_key),
|
51 |
+
inputs=[
|
52 |
+
gr.inputs.Textbox(label="Address or place name"),
|
53 |
+
gr.inputs.Image(label="Input Image"),
|
54 |
+
],
|
55 |
+
outputs=[
|
56 |
+
gr.outputs.Image(type="numpy", label="Original image"),
|
57 |
+
gr.outputs.Image(type="numpy", label="Flooding"),
|
58 |
+
gr.outputs.Image(type="numpy", label="Wildfire"),
|
59 |
+
gr.outputs.Image(type="numpy", label="Smog"),
|
60 |
+
],
|
61 |
+
title="ClimateGAN: Visualize Climate Change",
|
62 |
+
description='Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at <a href="https://thisclimatedoesnotexist.com/">ThisClimateDoesNotExist.com</a>.<br>Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.', # noqa: E501
|
63 |
+
article="<p style='text-align: center'>This project is an unofficial clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>", # noqa: E501
|
64 |
+
# examples=[
|
65 |
+
# "Vancouver Art Gallery",
|
66 |
+
# "Chicago Bean",
|
67 |
+
# "Duomo Siracusa",
|
68 |
+
# ],
|
69 |
+
css=".footer{display:none !important}",
|
70 |
+
).launch()
|
climategan/generator.py
CHANGED
@@ -101,6 +101,10 @@ class OmniGenerator(nn.Module):
|
|
101 |
if self.verbose > 0:
|
102 |
print(" - Add Empty Painter")
|
103 |
|
|
|
|
|
|
|
|
|
104 |
def __str__(self):
|
105 |
return strings.generator(self)
|
106 |
|
@@ -381,7 +385,7 @@ class OmniGenerator(nn.Module):
|
|
381 |
val_painter_opts = Dict(yaml.safe_load(f))
|
382 |
|
383 |
# load checkpoint
|
384 |
-
state_dict = torch.load(ckpt_path)
|
385 |
|
386 |
# create dummy painter from loaded opts
|
387 |
painter = create_painter(val_painter_opts)
|
|
|
101 |
if self.verbose > 0:
|
102 |
print(" - Add Empty Painter")
|
103 |
|
104 |
+
@property
|
105 |
+
def device(self):
|
106 |
+
return next(self.parameters()).device
|
107 |
+
|
108 |
def __str__(self):
|
109 |
return strings.generator(self)
|
110 |
|
|
|
385 |
val_painter_opts = Dict(yaml.safe_load(f))
|
386 |
|
387 |
# load checkpoint
|
388 |
+
state_dict = torch.load(ckpt_path, map_location=self.device)
|
389 |
|
390 |
# create dummy painter from loaded opts
|
391 |
painter = create_painter(val_painter_opts)
|
climategan/masker.py
CHANGED
@@ -186,18 +186,18 @@ class MaskSpadeDecoder(nn.Module):
|
|
186 |
for i in range(self.num_layers):
|
187 |
self.spade_blocks.append(
|
188 |
SPADEResnetBlock(
|
189 |
-
int(self.z_nc / (2
|
190 |
int(self.z_nc / (2 ** (i + 1))),
|
191 |
cond_nc,
|
192 |
spade_use_spectral_norm,
|
193 |
spade_param_free_norm,
|
194 |
spade_kernel_size,
|
195 |
spade_activation,
|
196 |
-
)
|
197 |
)
|
198 |
self.spade_blocks = nn.Sequential(*self.spade_blocks)
|
199 |
|
200 |
-
self.final_nc = int(self.z_nc / (2
|
201 |
self.mask_conv = Conv2dBlock(
|
202 |
self.final_nc,
|
203 |
1,
|
|
|
186 |
for i in range(self.num_layers):
|
187 |
self.spade_blocks.append(
|
188 |
SPADEResnetBlock(
|
189 |
+
int(self.z_nc / (2**i)),
|
190 |
int(self.z_nc / (2 ** (i + 1))),
|
191 |
cond_nc,
|
192 |
spade_use_spectral_norm,
|
193 |
spade_param_free_norm,
|
194 |
spade_kernel_size,
|
195 |
spade_activation,
|
196 |
+
)
|
197 |
)
|
198 |
self.spade_blocks = nn.Sequential(*self.spade_blocks)
|
199 |
|
200 |
+
self.final_nc = int(self.z_nc / (2**self.num_layers))
|
201 |
self.mask_conv = Conv2dBlock(
|
202 |
self.final_nc,
|
203 |
1,
|
inferences.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
|
2 |
+
# thank you @NimaBoscarino
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from skimage.color import rgba2rgb
|
6 |
+
from skimage.transform import resize
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from climategan.trainer import Trainer
|
10 |
+
|
11 |
+
|
12 |
+
def uint8(array):
|
13 |
+
"""
|
14 |
+
convert an array to np.uint8 (does not rescale or anything else than changing dtype)
|
15 |
+
Args:
|
16 |
+
array (np.array): array to modify
|
17 |
+
Returns:
|
18 |
+
np.array(np.uint8): converted array
|
19 |
+
"""
|
20 |
+
return array.astype(np.uint8)
|
21 |
+
|
22 |
+
|
23 |
+
def resize_and_crop(img, to=640):
|
24 |
+
"""
|
25 |
+
Resizes an image so that it keeps the aspect ratio and the smallest dimensions
|
26 |
+
is `to`, then crops this resized image in its center so that the output is `to x to`
|
27 |
+
without aspect ratio distortion
|
28 |
+
Args:
|
29 |
+
img (np.array): np.uint8 255 image
|
30 |
+
Returns:
|
31 |
+
np.array: [0, 1] np.float32 image
|
32 |
+
"""
|
33 |
+
# resize keeping aspect ratio: smallest dim is 640
|
34 |
+
h, w = img.shape[:2]
|
35 |
+
if h < w:
|
36 |
+
size = (to, int(to * w / h))
|
37 |
+
else:
|
38 |
+
size = (int(to * h / w), to)
|
39 |
+
|
40 |
+
r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
|
41 |
+
r_img = uint8(r_img)
|
42 |
+
|
43 |
+
# crop in the center
|
44 |
+
H, W = r_img.shape[:2]
|
45 |
+
|
46 |
+
top = (H - to) // 2
|
47 |
+
left = (W - to) // 2
|
48 |
+
|
49 |
+
rc_img = r_img[top : top + to, left : left + to, :]
|
50 |
+
|
51 |
+
return rc_img / 255.0
|
52 |
+
|
53 |
+
|
54 |
+
def to_m1_p1(img):
|
55 |
+
"""
|
56 |
+
rescales a [0, 1] image to [-1, +1]
|
57 |
+
Args:
|
58 |
+
img (np.array): float32 numpy array of an image in [0, 1]
|
59 |
+
i (int): Index of the image being rescaled
|
60 |
+
Raises:
|
61 |
+
ValueError: If the image is not in [0, 1]
|
62 |
+
Returns:
|
63 |
+
np.array(np.float32): array in [-1, +1]
|
64 |
+
"""
|
65 |
+
if img.min() >= 0 and img.max() <= 1:
|
66 |
+
return (img.astype(np.float32) - 0.5) * 2
|
67 |
+
raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
|
68 |
+
|
69 |
+
|
70 |
+
# No need to do any timing in this, since it's just for the HF Space
|
71 |
+
class ClimateGAN:
|
72 |
+
def __init__(self, model_path) -> None:
|
73 |
+
torch.set_grad_enabled(False)
|
74 |
+
self.target_size = 640
|
75 |
+
self.trainer = Trainer.resume_from_path(
|
76 |
+
model_path,
|
77 |
+
setup=True,
|
78 |
+
inference=True,
|
79 |
+
new_exp=None,
|
80 |
+
)
|
81 |
+
|
82 |
+
# Does all three inferences at the moment.
|
83 |
+
def inference(self, orig_image):
|
84 |
+
image = self._preprocess_image(orig_image)
|
85 |
+
|
86 |
+
# Retrieve numpy events as a dict {event: array[BxHxWxC]}
|
87 |
+
outputs = self.trainer.infer_all(
|
88 |
+
image,
|
89 |
+
numpy=True,
|
90 |
+
bin_value=0.5,
|
91 |
+
)
|
92 |
+
|
93 |
+
return (
|
94 |
+
outputs["flood"].squeeze(),
|
95 |
+
outputs["wildfire"].squeeze(),
|
96 |
+
outputs["smog"].squeeze(),
|
97 |
+
)
|
98 |
+
|
99 |
+
def _preprocess_image(self, img):
|
100 |
+
# rgba to rgb
|
101 |
+
data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
|
102 |
+
|
103 |
+
# to args.target_size
|
104 |
+
data = resize_and_crop(data, self.target_size)
|
105 |
+
|
106 |
+
# resize() produces [0, 1] images, rescale to [-1, 1]
|
107 |
+
data = to_m1_p1(data)
|
108 |
+
return data
|