Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -421,7 +421,6 @@ def import_state(state, json_text):
|
|
421 |
|
422 |
### Main worker
|
423 |
|
424 |
-
|
425 |
def register(state, drawpad, model):
|
426 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
427 |
print('Generate!')
|
@@ -436,13 +435,13 @@ def register(state, drawpad, model):
|
|
436 |
print('Inpainting mode: ', inpainting_mode)
|
437 |
|
438 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
439 |
-
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
440 |
-
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
441 |
|
442 |
palette = torch.tensor([
|
443 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
444 |
for s in opt.colors[1:]
|
445 |
-
]) # (N, 3)
|
446 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
447 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
448 |
has_masks = list(range(opt.max_palettes))
|
@@ -542,13 +541,13 @@ def draw(state, drawpad):
|
|
542 |
# conn = Client(opt.address, authkey=opt.authkey)
|
543 |
|
544 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
545 |
-
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
546 |
-
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
547 |
|
548 |
palette = torch.tensor([
|
549 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
550 |
for s in opt.colors[1:]
|
551 |
-
]) # (N, 3)
|
552 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
553 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
554 |
has_masks = list(range(opt.max_palettes))
|
|
|
421 |
|
422 |
### Main worker
|
423 |
|
|
|
424 |
def register(state, drawpad, model):
|
425 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
426 |
print('Generate!')
|
|
|
435 |
print('Inpainting mode: ', inpainting_mode)
|
436 |
|
437 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
438 |
+
foreground_mask = torch.tensor(user_input[..., -1], device=model.device)[None, None] # (1, 1, H, W)
|
439 |
+
user_input = torch.tensor(user_input[..., :-1], device=model.device) # (H, W, 3)
|
440 |
|
441 |
palette = torch.tensor([
|
442 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
443 |
for s in opt.colors[1:]
|
444 |
+
], device=model.device) # (N, 3)
|
445 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
446 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
447 |
has_masks = list(range(opt.max_palettes))
|
|
|
541 |
# conn = Client(opt.address, authkey=opt.authkey)
|
542 |
|
543 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
544 |
+
foreground_mask = torch.tensor(user_input[..., -1], device=model.device)[None, None] # (1, 1, H, W)
|
545 |
+
user_input = torch.tensor(user_input[..., :-1], device=model.device) # (H, W, 3)
|
546 |
|
547 |
palette = torch.tensor([
|
548 |
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
549 |
for s in opt.colors[1:]
|
550 |
+
], device=model.device) # (N, 3)
|
551 |
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
552 |
# has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
553 |
has_masks = list(range(opt.max_palettes))
|