adamelliotfields commited on
Commit
7736f5f
1 Parent(s): 8991603

Improve generate timing

Browse files
Files changed (2) hide show
  1. app.py +3 -15
  2. generate.py +8 -2
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import time
2
-
3
  import gradio as gr
4
 
5
  from generate import generate
@@ -23,24 +21,14 @@ def read_file(path: str) -> str:
23
 
24
 
25
  # don't request a GPU if input is bad
26
- def generate_btn_click(*args, **kwargs):
27
- start = time.perf_counter()
28
-
29
- if "prompt" in kwargs:
30
- prompt = kwargs.get("prompt")
31
- elif len(args) > 0:
32
  prompt = args[0]
33
  else:
34
  prompt = None
35
-
36
  if prompt is None or prompt.strip() == "":
37
  raise gr.Error("You must enter a prompt")
38
-
39
- images = generate(*args, **kwargs, Error=gr.Error)
40
- end = time.perf_counter()
41
- diff = end - start
42
- gr.Info(f"Generated {len(images)} images in {diff:.2f}s")
43
- return images
44
 
45
 
46
  with gr.Blocks(
 
 
 
1
  import gradio as gr
2
 
3
  from generate import generate
 
21
 
22
 
23
  # don't request a GPU if input is bad
24
+ def generate_btn_click(*args):
25
+ if len(args) > 0:
 
 
 
 
26
  prompt = args[0]
27
  else:
28
  prompt = None
 
29
  if prompt is None or prompt.strip() == "":
30
  raise gr.Error("You must enter a prompt")
31
+ return generate(*args)
 
 
 
 
 
32
 
33
 
34
  with gr.Blocks(
generate.py CHANGED
@@ -1,4 +1,5 @@
1
  import re
 
2
  from contextlib import contextmanager
3
  from datetime import datetime
4
  from itertools import product
@@ -6,6 +7,7 @@ from os import environ
6
  from types import MethodType
7
  from warnings import filterwarnings
8
 
 
9
  import spaces
10
  import tomesd
11
  import torch
@@ -242,10 +244,10 @@ def generate(
242
  deepcache_interval=1,
243
  tgate_step=0,
244
  tome_ratio=0,
245
- Error=Exception,
246
  ):
247
  if not torch.cuda.is_available():
248
- raise Error("CUDA not available")
249
 
250
  if seed is None:
251
  seed = int(datetime.now().timestamp())
@@ -263,6 +265,7 @@ def generate(
263
  )
264
 
265
  with torch.inference_mode():
 
266
  loader = Loader()
267
  pipe = loader.load(model, scheduler, karras, taesd, deepcache_interval, TORCH_DTYPE)
268
 
@@ -319,4 +322,7 @@ def generate(
319
  # spaces always start fresh
320
  loader.pipe = None
321
 
 
 
 
322
  return images
 
1
  import re
2
+ import time
3
  from contextlib import contextmanager
4
  from datetime import datetime
5
  from itertools import product
 
7
  from types import MethodType
8
  from warnings import filterwarnings
9
 
10
+ import gradio as gr
11
  import spaces
12
  import tomesd
13
  import torch
 
244
  deepcache_interval=1,
245
  tgate_step=0,
246
  tome_ratio=0,
247
+ progress=gr.Progress(track_tqdm=True),
248
  ):
249
  if not torch.cuda.is_available():
250
+ raise gr.Error("CUDA not available")
251
 
252
  if seed is None:
253
  seed = int(datetime.now().timestamp())
 
265
  )
266
 
267
  with torch.inference_mode():
268
+ start = time.perf_counter()
269
  loader = Loader()
270
  pipe = loader.load(model, scheduler, karras, taesd, deepcache_interval, TORCH_DTYPE)
271
 
 
322
  # spaces always start fresh
323
  loader.pipe = None
324
 
325
+ end = time.perf_counter()
326
+ diff = end - start
327
+ gr.Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
328
  return images