lunarfish commited on
Commit
52eef38
1 Parent(s): 54a1c0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -1
app.py CHANGED
@@ -1,3 +1,81 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load("models/lunarfish/furrydiffusion").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ from datasets import load_dataset
6
+ from PIL import Image
7
+ import re
8
+ import os
9
 
10
+ model_id = "models/lunarfish/furrydiffusion"
11
+ device = "cpu"
12
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
13
+ pipe = pipe.to(device)
14
+
15
+ def infer(prompt, samples, steps, scale, seed):
16
+ generator = torch.Generator(device=device).manual_seed(seed)
17
+ images_list = pipe(
18
+ [prompt] * samples,
19
+ num_inference_steps=steps,
20
+ guidance_scale=scale,
21
+ generator=generator,
22
+ )
23
+ images = []
24
+ safe_image = Image.open(r"unsafe.png")
25
+ for i, image in enumerate(images_list["sample"]):
26
+ if(images_list["nsfw_content_detected"][i]):
27
+ images.append(safe_image)
28
+ else:
29
+ images.append(image)
30
+ return images
31
+
32
+
33
+
34
+ block = gr.Blocks()
35
+
36
+ with block:
37
+ with gr.Group():
38
+ with gr.Box():
39
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
40
+ text = gr.Textbox(
41
+ label="Enter your prompt",
42
+ show_label=False,
43
+ max_lines=1,
44
+ placeholder="Enter your prompt",
45
+ ).style(
46
+ border=(True, False, True, True),
47
+ rounded=(True, False, False, True),
48
+ container=False,
49
+ )
50
+ btn = gr.Button("Generate image").style(
51
+ margin=False,
52
+ rounded=(False, True, True, False),
53
+ )
54
+ gallery = gr.Gallery(
55
+ label="Generated images", show_label=False, elem_id="gallery"
56
+ ).style(grid=[2], height="auto")
57
+
58
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
59
+
60
+ with gr.Row(elem_id="advanced-options"):
61
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
62
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
63
+ scale = gr.Slider(
64
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
65
+ )
66
+ seed = gr.Slider(
67
+ label="Seed",
68
+ minimum=0,
69
+ maximum=2147483647,
70
+ step=1,
71
+ randomize=True,
72
+ )
73
+ text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
74
+ btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
75
+ advanced_button.click(
76
+ None,
77
+ [],
78
+ text,
79
+ )
80
+
81
+ block.launch()