Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
apolinario
commited on
Commit
•
53066e3
1
Parent(s):
38d05ac
Swap to hybrid backend
Browse files
app.py
CHANGED
@@ -5,7 +5,11 @@ from torch import autocast
|
|
5 |
from diffusers import StableDiffusionPipeline
|
6 |
from datasets import load_dataset
|
7 |
from PIL import Image
|
|
|
|
|
8 |
import re
|
|
|
|
|
9 |
|
10 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
11 |
|
@@ -21,27 +25,44 @@ torch.backends.cudnn.benchmark = True
|
|
21 |
word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
|
22 |
word_list = word_list_dataset["train"]['text']
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
25 |
#When running locally you can also remove this filter
|
26 |
for filter in word_list:
|
27 |
if re.search(rf"\b{filter}\b", prompt):
|
28 |
raise gr.Error("Unsafe content found. Please try again with different prompts.")
|
29 |
|
30 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
31 |
-
|
32 |
-
images_list = pipe(
|
33 |
-
[prompt] * samples,
|
34 |
-
num_inference_steps=steps,
|
35 |
-
guidance_scale=scale,
|
36 |
-
generator=generator,
|
37 |
-
)
|
38 |
images = []
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
46 |
|
47 |
|
@@ -298,6 +319,7 @@ with block:
|
|
298 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
299 |
|
300 |
with gr.Row(elem_id="advanced-options"):
|
|
|
301 |
samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
|
302 |
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
|
303 |
scale = gr.Slider(
|
@@ -311,13 +333,13 @@ with block:
|
|
311 |
randomize=True,
|
312 |
)
|
313 |
|
314 |
-
ex = gr.Examples(examples=examples, fn=infer, inputs=
|
315 |
ex.dataset.headers = [""]
|
316 |
|
317 |
|
318 |
-
text.submit(infer, inputs=
|
319 |
|
320 |
-
btn.click(infer, inputs=
|
321 |
|
322 |
advanced_button.click(
|
323 |
None,
|
@@ -350,4 +372,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
|
|
350 |
"""
|
351 |
)
|
352 |
|
353 |
-
block.queue(max_size=25).launch()
|
|
|
5 |
from diffusers import StableDiffusionPipeline
|
6 |
from datasets import load_dataset
|
7 |
from PIL import Image
|
8 |
+
from io import BytesIO
|
9 |
+
import base64
|
10 |
import re
|
11 |
+
import os
|
12 |
+
import requests
|
13 |
|
14 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
15 |
|
|
|
25 |
word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
|
26 |
word_list = word_list_dataset["train"]['text']
|
27 |
|
28 |
+
is_gpu_busy = False
|
29 |
+
def infer(prompt):
|
30 |
+
global is_gpu_busy
|
31 |
+
samples = 4
|
32 |
+
steps = 50
|
33 |
+
scale = 7.5
|
34 |
#When running locally you can also remove this filter
|
35 |
for filter in word_list:
|
36 |
if re.search(rf"\b{filter}\b", prompt):
|
37 |
raise gr.Error("Unsafe content found. Please try again with different prompts.")
|
38 |
|
39 |
+
#generator = torch.Generator(device=device).manual_seed(seed)
|
40 |
+
print("Is GPU busy? ", is_gpu_busy)
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
images = []
|
42 |
+
if(not is_gpu_busy):
|
43 |
+
is_gpu_busy = True
|
44 |
+
images_list = pipe(
|
45 |
+
[prompt] * samples,
|
46 |
+
num_inference_steps=steps,
|
47 |
+
guidance_scale=scale,
|
48 |
+
#generator=generator,
|
49 |
+
)
|
50 |
+
is_gpu_busy = False
|
51 |
+
safe_image = Image.open(r"unsafe.png")
|
52 |
+
for i, image in enumerate(images_list["sample"]):
|
53 |
+
if(images_list["nsfw_content_detected"][i]):
|
54 |
+
images.append(safe_image)
|
55 |
+
else:
|
56 |
+
images.append(image)
|
57 |
+
else:
|
58 |
+
url = os.getenv('JAX_BACKEND_URL')
|
59 |
+
payload = {'prompt': prompt}
|
60 |
+
images_request = requests.post(url, json = payload)
|
61 |
+
for image in images_request.json()["images"]:
|
62 |
+
image_decoded = Image.open(BytesIO(base64.b64decode(image)))
|
63 |
+
images.append(image_decoded)
|
64 |
+
|
65 |
+
|
66 |
return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
67 |
|
68 |
|
|
|
319 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
320 |
|
321 |
with gr.Row(elem_id="advanced-options"):
|
322 |
+
gr.Markdown("Advanced settings are temporarily unavailable")
|
323 |
samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
|
324 |
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
|
325 |
scale = gr.Slider(
|
|
|
333 |
randomize=True,
|
334 |
)
|
335 |
|
336 |
+
ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
|
337 |
ex.dataset.headers = [""]
|
338 |
|
339 |
|
340 |
+
text.submit(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
|
341 |
|
342 |
+
btn.click(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
|
343 |
|
344 |
advanced_button.click(
|
345 |
None,
|
|
|
372 |
"""
|
373 |
)
|
374 |
|
375 |
+
block.queue(max_size=25, concurrency_count=2).launch()
|