ironjr commited on
Commit
1773a6a
1 Parent(s): ad46e72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -8
app.py CHANGED
@@ -31,6 +31,7 @@ import glob
31
  import pathlib
32
  from functools import partial
33
  from pprint import pprint
 
34
 
35
  import numpy as np
36
  from PIL import Image
@@ -150,6 +151,11 @@ opt.excluded_keys = ['inpainting_mode', 'is_running', 'active_palettes', 'curren
150
  opt.prep_time = 20
151
 
152
 
 
 
 
 
 
153
  ### Event handlers
154
 
155
  def add_palette(state):
@@ -385,6 +391,11 @@ def register(state, drawpad, model):
385
 
386
  @spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
387
  def run(state, drawpad):
 
 
 
 
 
388
  model.device = torch.device('cuda')
389
  model.reset_seed(model.generator, opt.seed)
390
  model.reset_latent()
@@ -395,6 +406,17 @@ def run(state, drawpad):
395
 
396
  tic = time.time()
397
  while True:
 
 
 
 
 
 
 
 
 
 
 
398
  yield [state, model()]
399
  toc = time.time()
400
  tdelta = toc - tic
@@ -412,12 +434,14 @@ def show_element():
412
  return gr.update(visible=True)
413
 
414
 
415
- @spaces.GPU
416
  def draw(state, drawpad):
417
  if not state.is_running:
418
  print('[WARNING] Streaming is currently off, update ignored.')
419
  return
420
 
 
 
 
421
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
422
  foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
423
  user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
@@ -441,13 +465,15 @@ def draw(state, drawpad):
441
  # mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
442
  # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
443
 
444
- for i in range(len(has_masks)):
445
- model.update_single_layer(
446
- idx=i,
447
- mask=masks[i],
448
- mask_strength=mask_strengths[i],
449
- mask_std=mask_stds[i],
450
- )
 
 
451
 
452
  ### Load examples
453
 
 
31
  import pathlib
32
  from functools import partial
33
  from pprint import pprint
34
+ from multiprocessing.connection import Client, Listener
35
 
36
  import numpy as np
37
  from PIL import Image
 
151
  opt.prep_time = 20
152
 
153
 
154
+ ### Shared memory hack for ZeroGPU
155
+ opt.address = ('localhost', 6000)
156
+ = b'secret password'
157
+
158
+
159
  ### Event handlers
160
 
161
  def add_palette(state):
 
391
 
392
  @spaces.GPU(duration=(opt.prep_time + opt.run_time + 5))
393
  def run(state, drawpad):
394
+ # ZeroGPU hack.
395
+ listener = Listener(opt.address, authkey=opt.authkey)
396
+ conn = listener.accept()
397
+
398
+ # Reset model.
399
  model.device = torch.device('cuda')
400
  model.reset_seed(model.generator, opt.seed)
401
  model.reset_latent()
 
406
 
407
  tic = time.time()
408
  while True:
409
+ # Receive real-time mask inputs from the main process.
410
+ msg = conn.recv()
411
+ print(msg + ' Received!!!')
412
+ # for i in range(opt.max_palettes):
413
+ # model.update_single_layer(
414
+ # idx=i,
415
+ # mask=masks[i],
416
+ # mask_strength=mask_strengths[i],
417
+ # mask_std=mask_stds[i],
418
+ # )
419
+
420
  yield [state, model()]
421
  toc = time.time()
422
  tdelta = toc - tic
 
434
  return gr.update(visible=True)
435
 
436
 
 
437
  def draw(state, drawpad):
438
  if not state.is_running:
439
  print('[WARNING] Streaming is currently off, update ignored.')
440
  return
441
 
442
+ # ZeroGPU hack.
443
+ conn = Client(opt.address, authkey=opt.authkey)
444
+
445
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
446
  foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
447
  user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
 
465
  # mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
466
  # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
467
 
468
+ # for i in range(len(has_masks)):
469
+ # model.update_single_layer(
470
+ # idx=i,
471
+ # mask=masks[i],
472
+ # mask_strength=mask_strengths[i],
473
+ # mask_std=mask_stds[i],
474
+ # )
475
+ conn.send('Hello!!!!')
476
+ conn.close()
477
 
478
  ### Load examples
479