ironjr commited on
Commit
855f3df
·
verified ·
1 Parent(s): 734882d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -421,7 +421,6 @@ def import_state(state, json_text):
421
 
422
  ### Main worker
423
 
424
-
425
  def register(state, drawpad, model):
426
  seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
427
  print('Generate!')
@@ -436,13 +435,13 @@ def register(state, drawpad, model):
436
  print('Inpainting mode: ', inpainting_mode)
437
 
438
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
439
- foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
440
- user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
441
 
442
  palette = torch.tensor([
443
  tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
444
  for s in opt.colors[1:]
445
- ]) # (N, 3)
446
  masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
447
  # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
448
  has_masks = list(range(opt.max_palettes))
@@ -542,13 +541,13 @@ def draw(state, drawpad):
542
  # conn = Client(opt.address, authkey=opt.authkey)
543
 
544
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
545
- foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
546
- user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
547
 
548
  palette = torch.tensor([
549
  tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
550
  for s in opt.colors[1:]
551
- ]) # (N, 3)
552
  masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
553
  # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
554
  has_masks = list(range(opt.max_palettes))
 
421
 
422
  ### Main worker
423
 
 
424
  def register(state, drawpad, model):
425
  seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
426
  print('Generate!')
 
435
  print('Inpainting mode: ', inpainting_mode)
436
 
437
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
438
+ foreground_mask = torch.tensor(user_input[..., -1], device=model.device)[None, None] # (1, 1, H, W)
439
+ user_input = torch.tensor(user_input[..., :-1], device=model.device) # (H, W, 3)
440
 
441
  palette = torch.tensor([
442
  tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
443
  for s in opt.colors[1:]
444
+ ], device=model.device) # (N, 3)
445
  masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
446
  # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
447
  has_masks = list(range(opt.max_palettes))
 
541
  # conn = Client(opt.address, authkey=opt.authkey)
542
 
543
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
544
+ foreground_mask = torch.tensor(user_input[..., -1], device=model.device)[None, None] # (1, 1, H, W)
545
+ user_input = torch.tensor(user_input[..., :-1], device=model.device) # (H, W, 3)
546
 
547
  palette = torch.tensor([
548
  tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
549
  for s in opt.colors[1:]
550
+ ], device=model.device) # (N, 3)
551
  masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
552
  # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
553
  has_masks = list(range(opt.max_palettes))