Spaces:
Running
Running
update model
Browse files- app.py +20 -19
- examples/01.jpg +0 -0
- examples/01.png +0 -0
- examples/02.jpg +0 -0
- examples/02.png +0 -0
- examples/03.jpg +0 -0
- examples/03.png +0 -0
- examples/04.jpg +0 -0
- examples/04.png +0 -0
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
import imageio
|
3 |
import numpy as np
|
@@ -93,14 +94,15 @@ class Model:
|
|
93 |
detector_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "waifu_dect.onnx")
|
94 |
anime_seg_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
|
95 |
|
96 |
-
providers = ['
|
|
|
97 |
g_mapping = onnx.load(g_mapping_path)
|
98 |
w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
|
99 |
w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
|
100 |
w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
|
101 |
self.w_avg = w_avg
|
102 |
-
self.g_mapping = rt.InferenceSession(g_mapping_path, providers=providers)
|
103 |
-
self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=providers)
|
104 |
self.encoder = rt.InferenceSession(encoder_path, providers=providers)
|
105 |
self.detector = rt.InferenceSession(detector_path, providers=providers)
|
106 |
detector_meta = self.detector.get_modelmeta().custom_metadata_map
|
@@ -130,7 +132,7 @@ class Model:
|
|
130 |
mask = np.transpose(mask, (1, 2, 0))
|
131 |
mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
|
132 |
mask = transform.resize(mask, (h0, w0))
|
133 |
-
img0 = (img0*mask + 255*(1-mask)).astype(np.uint8)
|
134 |
return img0
|
135 |
|
136 |
def encode_img(self, img):
|
@@ -247,10 +249,12 @@ def get_thumbnail(img):
|
|
247 |
|
248 |
|
249 |
def gen_fn(method, seed, psi1, psi2, noise):
|
250 |
-
|
|
|
|
|
251 |
w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
|
252 |
img_out = model.get_img(w, noise)
|
253 |
-
return img_out, w, get_thumbnail(img_out)
|
254 |
|
255 |
|
256 |
def encode_img_fn(img, noise):
|
@@ -259,7 +263,7 @@ def encode_img_fn(img, noise):
|
|
259 |
img = model.remove_bg(img)
|
260 |
imgs = model.detect(img, 0.2, 0.03)
|
261 |
if len(imgs) == 0:
|
262 |
-
return "failed to detect
|
263 |
w = model.encode_img(imgs[0])
|
264 |
img_out = model.get_img(w, noise)
|
265 |
return "success", imgs[0], img_out, w, get_thumbnail(img_out)
|
@@ -278,8 +282,7 @@ if __name__ == '__main__':
|
|
278 |
app = gr.Blocks()
|
279 |
with app:
|
280 |
gr.Markdown("# full-body anime GAN\n\n"
|
281 |
-
"![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.full-body-anime-gan)\n\n"
|
282 |
-
"the model is not well, just use for fun.")
|
283 |
with gr.Tabs():
|
284 |
with gr.TabItem("generate image"):
|
285 |
with gr.Row():
|
@@ -287,9 +290,9 @@ if __name__ == '__main__':
|
|
287 |
gr.Markdown("generate image randomly or by seed")
|
288 |
with gr.Row():
|
289 |
gen_input1 = gr.Radio(label="method", value="random",
|
290 |
-
choices=["random", "
|
291 |
-
gen_input2 = gr.
|
292 |
-
gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=
|
293 |
gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
|
294 |
gen_input5 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="noise strength")
|
295 |
with gr.Group():
|
@@ -304,7 +307,7 @@ if __name__ == '__main__':
|
|
304 |
with gr.Column():
|
305 |
gr.Markdown("you'd better upload a standing full-body image")
|
306 |
encode_img_input = gr.Image(label="input image")
|
307 |
-
examples_data = [[f"examples/{x:02d}.
|
308 |
encode_img_examples = gr.Dataset(components=[encode_img_input], samples=examples_data)
|
309 |
with gr.Group():
|
310 |
encode_img_submit = gr.Button("Run", variant="primary")
|
@@ -319,11 +322,10 @@ if __name__ == '__main__':
|
|
319 |
with gr.TabItem("generate video"):
|
320 |
with gr.Row():
|
321 |
with gr.Column():
|
322 |
-
gr.Markdown("
|
323 |
with gr.Row():
|
324 |
with gr.Column():
|
325 |
-
gr.
|
326 |
-
select_img1_dropdown = gr.Radio(label="source", value="current generated image",
|
327 |
choices=["current generated image",
|
328 |
"current encoded image"], type="index")
|
329 |
with gr.Group():
|
@@ -331,8 +333,7 @@ if __name__ == '__main__':
|
|
331 |
select_img1_output_img = gr.Image(label="selected image 1")
|
332 |
select_img1_output_w = gr.Variable()
|
333 |
with gr.Column():
|
334 |
-
gr.
|
335 |
-
select_img2_dropdown = gr.Radio(label="source", value="current generated image",
|
336 |
choices=["current generated image",
|
337 |
"current encoded image"], type="index")
|
338 |
with gr.Group():
|
@@ -345,7 +346,7 @@ if __name__ == '__main__':
|
|
345 |
with gr.Column():
|
346 |
generate_video_output = gr.Video(label="output video")
|
347 |
gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4, gen_input5],
|
348 |
-
[gen_output1, select_img_input_w1, select_img_input_img1])
|
349 |
encode_img_submit.click(encode_img_fn, [encode_img_input, gen_input5],
|
350 |
[encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
|
351 |
select_img_input_img2])
|
|
|
1 |
+
import random
|
2 |
import gradio as gr
|
3 |
import imageio
|
4 |
import numpy as np
|
|
|
94 |
detector_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "waifu_dect.onnx")
|
95 |
anime_seg_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
|
96 |
|
97 |
+
providers = ['CPUExecutionProvider']
|
98 |
+
gpu_providers = ['CUDAExecutionProvider']
|
99 |
g_mapping = onnx.load(g_mapping_path)
|
100 |
w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
|
101 |
w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
|
102 |
w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
|
103 |
self.w_avg = w_avg
|
104 |
+
self.g_mapping = rt.InferenceSession(g_mapping_path, providers=gpu_providers + providers)
|
105 |
+
self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=gpu_providers + providers)
|
106 |
self.encoder = rt.InferenceSession(encoder_path, providers=providers)
|
107 |
self.detector = rt.InferenceSession(detector_path, providers=providers)
|
108 |
detector_meta = self.detector.get_modelmeta().custom_metadata_map
|
|
|
132 |
mask = np.transpose(mask, (1, 2, 0))
|
133 |
mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
|
134 |
mask = transform.resize(mask, (h0, w0))
|
135 |
+
img0 = (img0 * mask + 255 * (1 - mask)).astype(np.uint8)
|
136 |
return img0
|
137 |
|
138 |
def encode_img(self, img):
|
|
|
249 |
|
250 |
|
251 |
def gen_fn(method, seed, psi1, psi2, noise):
|
252 |
+
if method == 0:
|
253 |
+
seed = random.randint(0, 2 ** 32 - 1)
|
254 |
+
z = RandomState(int(seed)).randn(1, 1024)
|
255 |
w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
|
256 |
img_out = model.get_img(w, noise)
|
257 |
+
return img_out, seed, w, get_thumbnail(img_out)
|
258 |
|
259 |
|
260 |
def encode_img_fn(img, noise):
|
|
|
263 |
img = model.remove_bg(img)
|
264 |
imgs = model.detect(img, 0.2, 0.03)
|
265 |
if len(imgs) == 0:
|
266 |
+
return "failed to detect anime character", None, None, None, None
|
267 |
w = model.encode_img(imgs[0])
|
268 |
img_out = model.get_img(w, noise)
|
269 |
return "success", imgs[0], img_out, w, get_thumbnail(img_out)
|
|
|
282 |
app = gr.Blocks()
|
283 |
with app:
|
284 |
gr.Markdown("# full-body anime GAN\n\n"
|
285 |
+
"![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.full-body-anime-gan)\n\n")
|
|
|
286 |
with gr.Tabs():
|
287 |
with gr.TabItem("generate image"):
|
288 |
with gr.Row():
|
|
|
290 |
gr.Markdown("generate image randomly or by seed")
|
291 |
with gr.Row():
|
292 |
gen_input1 = gr.Radio(label="method", value="random",
|
293 |
+
choices=["random", "seed"], type="index")
|
294 |
+
gen_input2 = gr.Slider(minimum=0, maximum=2 ** 32 - 1, step=1, value=0, label="seed")
|
295 |
+
gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 1")
|
296 |
gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
|
297 |
gen_input5 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="noise strength")
|
298 |
with gr.Group():
|
|
|
307 |
with gr.Column():
|
308 |
gr.Markdown("you'd better upload a standing full-body image")
|
309 |
encode_img_input = gr.Image(label="input image")
|
310 |
+
examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 5)]
|
311 |
encode_img_examples = gr.Dataset(components=[encode_img_input], samples=examples_data)
|
312 |
with gr.Group():
|
313 |
encode_img_submit = gr.Button("Run", variant="primary")
|
|
|
322 |
with gr.TabItem("generate video"):
|
323 |
with gr.Row():
|
324 |
with gr.Column():
|
325 |
+
gr.Markdown("generate video between 2 images")
|
326 |
with gr.Row():
|
327 |
with gr.Column():
|
328 |
+
select_img1_dropdown = gr.Radio(label="Select image 1", value="current generated image",
|
|
|
329 |
choices=["current generated image",
|
330 |
"current encoded image"], type="index")
|
331 |
with gr.Group():
|
|
|
333 |
select_img1_output_img = gr.Image(label="selected image 1")
|
334 |
select_img1_output_w = gr.Variable()
|
335 |
with gr.Column():
|
336 |
+
select_img2_dropdown = gr.Radio(label="Select image 2", value="current generated image",
|
|
|
337 |
choices=["current generated image",
|
338 |
"current encoded image"], type="index")
|
339 |
with gr.Group():
|
|
|
346 |
with gr.Column():
|
347 |
generate_video_output = gr.Video(label="output video")
|
348 |
gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4, gen_input5],
|
349 |
+
[gen_output1, gen_input2, select_img_input_w1, select_img_input_img1])
|
350 |
encode_img_submit.click(encode_img_fn, [encode_img_input, gen_input5],
|
351 |
[encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
|
352 |
select_img_input_img2])
|
examples/01.jpg
ADDED
examples/01.png
DELETED
Binary file (405 kB)
|
|
examples/02.jpg
ADDED
examples/02.png
DELETED
Binary file (331 kB)
|
|
examples/03.jpg
ADDED
examples/03.png
DELETED
Binary file (369 kB)
|
|
examples/04.jpg
ADDED
examples/04.png
DELETED
Binary file (452 kB)
|
|