Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -320,7 +320,7 @@ def import_state(state, json_text):
|
|
320 |
### Main worker
|
321 |
|
322 |
|
323 |
-
def register(state, drawpad):
|
324 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
325 |
print('Generate!')
|
326 |
|
@@ -362,15 +362,15 @@ def register(state, drawpad):
|
|
362 |
# prompts, negative_prompts = preprocess_prompts(
|
363 |
# prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
|
364 |
|
365 |
-
|
366 |
background.convert('RGB'),
|
367 |
prompt=None,
|
368 |
negative_prompt=None,
|
369 |
)
|
370 |
-
state.prompts[0] =
|
371 |
-
state.neg_prompts[0] =
|
372 |
|
373 |
-
|
374 |
prompts=prompts,
|
375 |
negative_prompts=negative_prompts,
|
376 |
masks=masks.to(device),
|
@@ -384,23 +384,23 @@ def register(state, drawpad):
|
|
384 |
|
385 |
@spaces.GPU(duration=120)
|
386 |
def run(state, drawpad):
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
state.model.prepare()
|
392 |
|
393 |
-
state = register(state, drawpad)
|
394 |
state.is_running = True
|
395 |
|
396 |
tic = time.time()
|
397 |
while True:
|
398 |
-
yield [state,
|
399 |
toc = time.time()
|
400 |
tdelta = toc - tic
|
401 |
if tdelta > opt.run_time:
|
402 |
state.is_running = False
|
403 |
-
|
|
|
404 |
|
405 |
|
406 |
def hide_element():
|
@@ -412,7 +412,11 @@ def show_element():
|
|
412 |
|
413 |
|
414 |
def draw(state, drawpad):
|
|
|
|
|
|
|
415 |
if not state.is_running:
|
|
|
416 |
return
|
417 |
|
418 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
@@ -601,7 +605,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, head=head) as demo:
|
|
601 |
state.model_id = opt.model
|
602 |
state.style_name = '(None)'
|
603 |
state.quality_name = 'Standard v3.1'
|
604 |
-
state.model =
|
605 |
|
606 |
# State variables (one-hot).
|
607 |
state.active_palettes = 5
|
|
|
320 |
### Main worker
|
321 |
|
322 |
|
323 |
+
def register(state, drawpad, model):
|
324 |
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
325 |
print('Generate!')
|
326 |
|
|
|
362 |
# prompts, negative_prompts = preprocess_prompts(
|
363 |
# prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
|
364 |
|
365 |
+
model.update_background(
|
366 |
background.convert('RGB'),
|
367 |
prompt=None,
|
368 |
negative_prompt=None,
|
369 |
)
|
370 |
+
state.prompts[0] = model.background.prompt
|
371 |
+
state.neg_prompts[0] = model.background.negative_prompt
|
372 |
|
373 |
+
model.update_layers(
|
374 |
prompts=prompts,
|
375 |
negative_prompts=negative_prompts,
|
376 |
masks=masks.to(device),
|
|
|
384 |
|
385 |
@spaces.GPU(duration=120)
|
386 |
def run(state, drawpad):
|
387 |
+
model.device = torch.device('cuda')
|
388 |
+
model.reset_seed(model.generator, opt.seed)
|
389 |
+
model.reset_latent()
|
390 |
+
model.prepare()
|
|
|
391 |
|
392 |
+
state = register(state, drawpad, model)
|
393 |
state.is_running = True
|
394 |
|
395 |
tic = time.time()
|
396 |
while True:
|
397 |
+
yield [state, model()]
|
398 |
toc = time.time()
|
399 |
tdelta = toc - tic
|
400 |
if tdelta > opt.run_time:
|
401 |
state.is_running = False
|
402 |
+
state.model = None
|
403 |
+
return [state, model()]
|
404 |
|
405 |
|
406 |
def hide_element():
|
|
|
412 |
|
413 |
|
414 |
def draw(state, drawpad):
|
415 |
+
if not hasattr(state, 'model') or state.model is None:
|
416 |
+
print('[WARNING] Model is not registered, update ignored.')
|
417 |
+
return
|
418 |
if not state.is_running:
|
419 |
+
print('[WARNING] Streaming is currently off, update ignored.')
|
420 |
return
|
421 |
|
422 |
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
|
|
605 |
state.model_id = opt.model
|
606 |
state.style_name = '(None)'
|
607 |
state.quality_name = 'Standard v3.1'
|
608 |
+
state.model = None
|
609 |
|
610 |
# State variables (one-hot).
|
611 |
state.active_palettes = 5
|