Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -31,6 +31,7 @@ import glob
|
|
31 |
import pathlib
|
32 |
from functools import partial
|
33 |
from pprint import pprint
|
|
|
34 |
|
35 |
import numpy as np
|
36 |
from PIL import Image
|
@@ -150,6 +151,11 @@ opt.excluded_keys = ['inpainting_mode', 'is_running', 'active_palettes', 'curren
|
|
150 |
opt.prep_time = 20
|
151 |
|
152 |
|
|
|
|
|
|
|
|
|
|
|
153 |
### Event handlers
|
154 |
|
155 |
def add_palette(state):
|
@@ -385,6 +391,11 @@ def register(state, drawpad, model):
|
|
385 |
|
386 |
@spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
387 |
def run(state, drawpad):
|
|
|
|
|
|
|
|
|
|
|
388 |
model.device = torch.device('cuda')
|
389 |
model.reset_seed(model.generator, opt.seed)
|
390 |
model.reset_latent()
|
@@ -395,6 +406,17 @@ def run(state, drawpad):
|
|
395 |
|
396 |
tic = time.time()
|
397 |
while True:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
398 |
yield [state, model()]
|
399 |
toc = time.time()
|
400 |
tdelta = toc - tic
|
@@ -412,12 +434,14 @@ def show_element():
|
|
412 |
return gr.update(visible=True)
|
413 |
|
414 |
|
415 |
-
@spaces.GPU
|
416 |
def draw(state, drawpad):
|
417 |
if not state.is_running:
|
418 |
print('[WARNING] Streaming is currently off, update ignored.')
|
419 |
return
|
420 |
|
|
|
|
|
|
|
421 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
422 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
423 |
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
@@ -441,13 +465,15 @@ def draw(state, drawpad):
|
|
441 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
442 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
443 |
|
444 |
-
for i in range(len(has_masks)):
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
|
|
451 |
|
452 |
### Load examples
|
453 |
|
|
|
31 |
import pathlib
|
32 |
from functools import partial
|
33 |
from pprint import pprint
|
34 |
+
from multiprocessing.connection import Client, Listener
|
35 |
|
36 |
import numpy as np
|
37 |
from PIL import Image
|
|
|
151 |
opt.prep_time = 20
|
152 |
|
153 |
|
154 |
+
### Shared memory hack for ZeroGPU
|
155 |
+
opt.address = ('localhost', 6000)
|
156 |
+
= b'secret password'
|
157 |
+
|
158 |
+
|
159 |
### Event handlers
|
160 |
|
161 |
def add_palette(state):
|
|
|
391 |
|
392 |
@spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
|
393 |
def run(state, drawpad):
|
394 |
+
# ZeroGPU hack.
|
395 |
+
listener = Listener(opt.address, authkey=opt.authkey)
|
396 |
+
conn = listener.accept()
|
397 |
+
|
398 |
+
# Reset model.
|
399 |
model.device = torch.device('cuda')
|
400 |
model.reset_seed(model.generator, opt.seed)
|
401 |
model.reset_latent()
|
|
|
406 |
|
407 |
tic = time.time()
|
408 |
while True:
|
409 |
+
# Receive real-time mask inputs from the main process.
|
410 |
+
msg = conn.recv()
|
411 |
+
print(msg + ' Received!!!')
|
412 |
+
# for i in range(opt.max_palettes):
|
413 |
+
# model.update_single_layer(
|
414 |
+
# idx=i,
|
415 |
+
# mask=masks[i],
|
416 |
+
# mask_strength=mask_strengths[i],
|
417 |
+
# mask_std=mask_stds[i],
|
418 |
+
# )
|
419 |
+
|
420 |
yield [state, model()]
|
421 |
toc = time.time()
|
422 |
tdelta = toc - tic
|
|
|
434 |
return gr.update(visible=True)
|
435 |
|
436 |
|
|
|
437 |
def draw(state, drawpad):
|
438 |
if not state.is_running:
|
439 |
print('[WARNING] Streaming is currently off, update ignored.')
|
440 |
return
|
441 |
|
442 |
+
# ZeroGPU hack.
|
443 |
+
conn = Client(opt.address, authkey=opt.authkey)
|
444 |
+
|
445 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
446 |
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
447 |
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
|
|
465 |
# mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
466 |
# mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
467 |
|
468 |
+
# for i in range(len(has_masks)):
|
469 |
+
# model.update_single_layer(
|
470 |
+
# idx=i,
|
471 |
+
# mask=masks[i],
|
472 |
+
# mask_strength=mask_strengths[i],
|
473 |
+
# mask_std=mask_stds[i],
|
474 |
+
# )
|
475 |
+
conn.send('Hello!!!!')
|
476 |
+
conn.close()
|
477 |
|
478 |
### Load examples
|
479 |
|