awacke1 committed on
Commit
e6c79f5
1 Parent(s): 4842bdd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +457 -0
app.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import base64
5
+ from datetime import datetime
6
+ import numpy as np
7
+ import torch
8
+ import gradio as gr
9
+ from gradio_imageslider import ImageSlider
10
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
11
+ from controlnet_aux import AnylineDetector
12
+ from compel import Compel, ReturnedEmbeddingsType
13
+ from PIL import Image
14
+ import pandas as pd
15
+
16
# Configuration
# Environment flags: both are read here but not referenced again in the
# visible part of this file.
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = os.environ.get("SPACE_ID", None) is not None

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): half precision is selected even when device == "cpu";
# confirm the models run in fp16 on CPU before relying on that fallback.
dtype = torch.float16
# Opt-in flag; only reported below, no memory-saving code path is visible here.
LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"

print(f"device: {device}")
print(f"dtype: {dtype}")
print(f"low memory: {LOW_MEMORY}")
26
+
27
# Model initialization
# SDXL base checkpoint driven by a DDIM scheduler and the MistoLine ControlNet.
model = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
controlnet = ControlNetModel.from_pretrained(
    "TheMistoAI/MistoLine",
    torch_dtype=dtype,  # share the module-level dtype with the pipeline below
    revision="refs/pr/3",
    variant="fp16",
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    model,
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)

# Compel turns weighted prompt syntax into embeddings for both SDXL text
# encoders; only the second encoder supplies the pooled embedding.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
pipe = pipe.to(device)

# Line-art detector that produces the ControlNet conditioning image.
anyline = AnylineDetector.from_pretrained(
    "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
).to(device)
56
+
57
# Global variables for metadata and likes cache
# In-memory table with one row per generated image (not persisted to disk).
image_metadata = pd.DataFrame(
    columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created']
)
# JSON file, relative to the working directory, that persists vote counters.
LIKES_CACHE_FILE = "likes_cache.json"
60
+
61
def load_likes_cache():
    """Load the persisted vote counters from ``LIKES_CACHE_FILE``.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        is not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(LIKES_CACHE_FILE, 'r', encoding='utf-8') as f:
            cache = json.load(f)
    except FileNotFoundError:
        # First run: nothing has been saved yet.
        return {}
    except json.JSONDecodeError as err:
        # A truncated or hand-edited cache must not stop the app from starting.
        print(f"Ignoring unreadable {LIKES_CACHE_FILE}: {err}")
        return {}
    return cache if isinstance(cache, dict) else {}
66
+
67
def save_likes_cache(cache):
    """Persist the vote counters to ``LIKES_CACHE_FILE`` as JSON.

    The data is written to a sibling temporary file first and then moved into
    place, so an interrupted write cannot leave a half-written cache behind.

    Args:
        cache: JSON-serialisable mapping of file name to vote counters.
    """
    tmp_path = f"{LIKES_CACHE_FILE}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(cache, f)
    os.replace(tmp_path, LIKES_CACHE_FILE)
70
+
71
# Vote counters keyed by image file name, restored from disk at import time.
likes_cache = load_likes_cache()
72
+
73
def pad_image(image):
    """Pad an image with black bars so that it becomes square.

    Args:
        image: A PIL image.

    Returns:
        ``image`` itself when it is already square; otherwise a new image of
        the same mode whose side equals the longer edge, with the original
        centred along the shorter axis.
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    canvas = Image.new(image.mode, (side, side), (0, 0, 0))
    # One of the two offsets is always 0: the longer edge already fills the canvas.
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas
85
+
86
def create_download_link(filename):
    """Build an HTML anchor that downloads ``filename`` as an inline data URI.

    Args:
        filename: Path of the image file to embed.

    Returns:
        An ``<a>`` element whose ``href`` holds the base64-encoded file content
        and whose ``download`` attribute is the (HTML-escaped) base file name.
    """
    import html  # local import: the file's import header does not provide it

    with open(filename, "rb") as file:
        encoded_string = base64.b64encode(file.read()).decode('utf-8')
    # Escape the name so quotes or angle brackets cannot break out of the attribute.
    download_name = html.escape(os.path.basename(filename), quote=True)
    return (
        f'<a href="data:image/png;base64,{encoded_string}" '
        f'download="{download_name}">Download Image</a>'
    )
91
+
92
def save_image(image: Image.Image, prompt: str) -> str:
    """Save a generated image and register it in the metadata and likes cache.

    The file name is ``<timestamp>_<sanitised prompt>.png``; a numeric suffix
    is appended when that name already exists, so two generations with the
    same prompt within one second do not overwrite each other.

    Args:
        image: The generated image.
        prompt: The prompt used for generation (stored verbatim in metadata).

    Returns:
        The file name the image was written to.
    """
    global image_metadata, likes_cache
    created = datetime.now()  # one clock reading for both file name and metadata
    timestamp = created.strftime("%Y%m%d_%H%M%S")
    safe_prompt = ''.join(e for e in prompt if e.isalnum() or e.isspace())[:50]

    stem = f"{timestamp}_{safe_prompt}"
    filename = f"{stem}.png"
    suffix = 1
    while os.path.exists(filename):
        filename = f"{stem}_{suffix}.png"
        suffix += 1
    image.save(filename)

    new_row = pd.DataFrame({
        'Filename': [filename],
        'Prompt': [prompt],
        'Likes': [0],
        'Dislikes': [0],
        'Hearts': [0],
        'Created': [created],
    })
    image_metadata = pd.concat([image_metadata, new_row], ignore_index=True)
    likes_cache[filename] = {'likes': 0, 'dislikes': 0, 'hearts': 0}
    save_likes_cache(likes_cache)
    return filename
110
+
111
def get_image_gallery():
    """Return ``(file name, caption)`` pairs for every tracked image on disk.

    Rows whose file no longer exists are skipped.
    """
    return [
        (path, get_image_caption(path))
        for path in image_metadata['Filename'].tolist()
        if os.path.exists(path)
    ]
115
+
116
def get_image_caption(filename):
    """Build the gallery caption for ``filename``.

    Returns:
        The file name, prompt and vote counters when the file has an entry in
        the likes cache; otherwise just the file name. The prompt is left
        empty when the file has no metadata row (the likes cache is persisted
        to disk, the metadata table is not).
    """
    counts = likes_cache.get(filename)
    if counts is None:
        return filename
    likes = counts['likes']
    dislikes = counts['dislikes']
    hearts = counts['hearts']
    prompts = image_metadata.loc[image_metadata['Filename'] == filename, 'Prompt']
    prompt = prompts.iloc[0] if not prompts.empty else ""
    return f"{filename}\nPrompt: {prompt}\n👍 {likes} 👎 {dislikes} ❤️ {hearts}"
125
+
126
def delete_all_images():
    """Delete every tracked image file and reset metadata and the likes cache.

    Returns:
        ``(gallery items, metadata rows)`` reflecting the now-empty state.
    """
    global image_metadata, likes_cache
    for path in image_metadata['Filename']:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone; nothing left to clean up for this row.
            pass
    image_metadata = pd.DataFrame(columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created'])
    likes_cache = {}
    save_likes_cache(likes_cache)
    return get_image_gallery(), image_metadata.values.tolist()
135
+
136
def delete_image(filename):
    """Delete one image and drop it from the metadata and the likes cache.

    Only files this app tracks (present in the metadata table or the likes
    cache) are removed from disk, so an unexpected value can never delete an
    unrelated file.

    Args:
        filename: File name of the image to delete; ``None``/empty is a no-op.

    Returns:
        ``(gallery items, metadata rows)`` after the deletion.
    """
    global image_metadata, likes_cache
    if filename:
        tracked = filename in likes_cache or bool(
            (image_metadata['Filename'] == filename).any()
        )
        if tracked:
            try:
                os.remove(filename)
            except FileNotFoundError:
                # File already removed; still clean up the bookkeeping below.
                pass
        image_metadata = image_metadata[image_metadata['Filename'] != filename]
        # Rewrite the cache file only when an entry was actually removed.
        if likes_cache.pop(filename, None) is not None:
            save_likes_cache(likes_cache)
    return get_image_gallery(), image_metadata.values.tolist()
145
+
146
def vote(filename, vote_type):
    """Increment one vote counter for ``filename`` and refresh the views.

    Args:
        filename: File name of the image being voted on. Names without an
            entry in the likes cache (including ``None``) are ignored.
        vote_type: ``'likes'``, ``'dislikes'`` or ``'hearts'`` (case-insensitive).

    Returns:
        ``(gallery items, metadata rows)`` after the vote.

    Raises:
        KeyError: If ``vote_type`` is not a counter stored for the image.
    """
    global image_metadata, likes_cache
    if filename in likes_cache:
        key = vote_type.lower()
        counts = likes_cache[filename]
        counts[key] += 1
        save_likes_cache(likes_cache)
        # Mirror the counter into the metadata table ('likes' -> 'Likes'),
        # otherwise the table keeps showing the initial 0.
        column = key.capitalize()
        if column in image_metadata.columns:
            image_metadata.loc[image_metadata['Filename'] == filename, column] = counts[key]
    return get_image_gallery(), image_metadata.values.tolist()
152
+
153
def predict(
    input_image,
    prompt,
    negative_prompt,
    seed,
    guidance_scale=8.5,
    controlnet_conditioning_scale=0.5,
    strength=1.0,
    controlnet_start=0.0,
    controlnet_end=1.0,
    guassian_sigma=2.0,
    intensity_threshold=3,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the MistoLine ControlNet img2img pipeline on an uploaded image.

    The input is padded to a square, resized to 1024x1024, turned into a line
    drawing by Anyline, and that drawing conditions the SDXL generation. The
    result is saved to disk and registered in the gallery.

    Args:
        input_image: Uploaded PIL image.
        prompt: Positive prompt (Compel syntax is accepted).
        negative_prompt: Negative prompt.
        seed: Seed for the torch random generator.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: Weight of the ControlNet conditioning.
        strength: img2img strength.
        controlnet_start: Fraction of the steps at which ControlNet starts.
        controlnet_end: Fraction of the steps at which ControlNet stops.
        guassian_sigma: Anyline blur sigma (clamped to at least 0.01).
        intensity_threshold: Anyline intensity threshold.
        progress: Gradio progress tracker (follows tqdm bars).

    Returns:
        ``((padded, generated), padded, anyline image, download link HTML,
        gallery items, metadata rows)``, in the order of the UI outputs.

    Raises:
        gr.Error: If no image was uploaded, or the ControlNet start is not
            below the ControlNet end.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    control_start = float(controlnet_start)
    control_end = float(controlnet_end)
    if control_start >= control_end:
        raise gr.Error("ControlNet Start must be smaller than ControlNet End.")

    padded_image = pad_image(input_image).resize((1024, 1024)).convert("RGB")
    # Row 0 holds the positive prompt embeddings, row 1 the negative ones.
    conditioning, pooled = compel([prompt, negative_prompt])
    # The slider may deliver the seed as a float; torch expects an int.
    generator = torch.manual_seed(int(seed))
    last_time = time.time()
    anyline_image = anyline(
        padded_image,
        detect_resolution=1280,
        guassian_sigma=max(0.01, guassian_sigma),
        intensity_threshold=intensity_threshold,
    )

    images = pipe(
        image=padded_image,
        control_image=anyline_image,
        strength=strength,
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=1024,
        height=1024,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        # The pipeline's parameter names for the ControlNet window.
        control_guidance_start=control_start,
        control_guidance_end=control_end,
        generator=generator,
        num_inference_steps=30,
        guidance_scale=guidance_scale,
        eta=1.0,
    )
    print(f"Time taken: {time.time() - last_time}")
    generated_image = images.images[0]
    filename = save_image(generated_image, prompt)
    download_link = create_download_link(filename)
    return (
        (padded_image, generated_image),
        padded_image,
        anyline_image,
        download_link,
        get_image_gallery(),
        image_metadata.values.tolist(),
    )
204
+
205
# Page styling: centred intro, wider container, hidden Gradio footer.
css = """
#intro {
max-width: 100%;
text-align: center;
margin: 0 auto;
}
.gradio-container {max-width: 1200px !important}
footer {visibility: hidden}
"""
214
+
215
# UI definition: three tabs (generation, gallery/voting, metadata) plus the
# event wiring and the cached examples.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🎨 ArtForge: MistoLine ControlNet Masterpiece Gallery

Create, curate, and compete with AI-enhanced images using MistoLine ControlNet. Join our creative multiplayer experience! 🖼️🏆✨

This demo showcases the capabilities of [TheMistoAI/MistoLine](https://huggingface.co/TheMistoAI/MistoLine) ControlNet with SDXL.

- SDXL Controlnet: [TheMistoAI/MistoLine](https://huggingface.co/TheMistoAI/MistoLine)
- [Anyline with Controlnet Aux](https://github.com/huggingface/controlnet_aux)
- For upscaling, see [Enhance This Demo](https://huggingface.co/spaces/radames/Enhance-This-HiDiffusion-SDXL)
""",
        elem_id="intro",
    )

    with gr.Tab("Generate Images"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Input Image")
                prompt = gr.Textbox(
                    label="Prompt",
                    info="The prompt is very important to get the desired results. Please try to describe the image as best as you can. Accepts Compel Syntax",
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=2**64 - 1,
                    value=1415926535897932,
                    step=1,
                    label="Seed",
                    randomize=True,
                )
                with gr.Accordion(label="Advanced", open=False):
                    guidance_scale = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=8.5,
                        step=0.001,
                        label="Guidance Scale",
                    )
                    controlnet_conditioning_scale = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=0.5,
                        label="ControlNet Conditioning Scale",
                    )
                    strength = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=1,
                        label="Strength",
                    )
                    controlnet_start = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=0.0,
                        label="ControlNet Start",
                    )
                    controlnet_end = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.001,
                        value=1.0,
                        label="ControlNet End",
                    )
                    guassian_sigma = gr.Slider(
                        minimum=0.01,
                        maximum=10.0,
                        step=0.1,
                        value=2.0,
                        label="(Anyline) Guassian Sigma",
                    )
                    intensity_threshold = gr.Slider(
                        minimum=0,
                        maximum=255,
                        step=1,
                        value=3,
                        label="(Anyline) Intensity Threshold",
                    )

                btn = gr.Button("Generate")
            with gr.Column(scale=2):
                with gr.Group():
                    image_slider = ImageSlider(position=0.5)
                with gr.Row():
                    padded_image = gr.Image(type="pil", label="Padded Image")
                    anyline_image = gr.Image(type="pil", label="Anyline Image")
                download_link = gr.HTML(label="Download Generated Image")

    with gr.Tab("Gallery and Voting"):
        image_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")

        with gr.Row():
            like_button = gr.Button("👍 Like")
            dislike_button = gr.Button("👎 Dislike")
            heart_button = gr.Button("❤️ Heart")
            delete_image_button = gr.Button("🗑️ Delete Selected Image")

        # File name of the gallery item the user clicked last (None = nothing).
        selected_image = gr.State(None)

    with gr.Tab("Metadata and Management"):
        metadata_df = gr.Dataframe(
            label="Image Metadata",
            headers=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"],
            interactive=False
        )
        delete_all_button = gr.Button("🗑️ Delete All Images")

    inputs = [
        image_input,
        prompt,
        negative_prompt,
        seed,
        guidance_scale,
        controlnet_conditioning_scale,
        strength,
        controlnet_start,
        controlnet_end,
        guassian_sigma,
        intensity_threshold,
    ]
    outputs = [image_slider, padded_image, anyline_image, download_link, image_gallery, metadata_df]

    btn.click(fn=predict, inputs=inputs, outputs=outputs)

    def on_gallery_select(evt: gr.SelectData):
        """Map the clicked gallery position back to the image's file name.

        Gradio injects the event payload only for parameters annotated with
        ``gr.SelectData``. The position is resolved against the current
        gallery listing so the stored value matches the keys used by the
        vote and delete handlers.
        """
        items = get_image_gallery()
        index = evt.index
        if isinstance(index, int) and 0 <= index < len(items):
            return items[index][0]
        return None

    image_gallery.select(fn=on_gallery_select, inputs=[], outputs=[selected_image])

    like_button.click(fn=lambda x: vote(x, 'likes'), inputs=[selected_image], outputs=[image_gallery, metadata_df])
    dislike_button.click(fn=lambda x: vote(x, 'dislikes'), inputs=[selected_image], outputs=[image_gallery, metadata_df])
    heart_button.click(fn=lambda x: vote(x, 'hearts'), inputs=[selected_image], outputs=[image_gallery, metadata_df])
    delete_image_button.click(fn=delete_image, inputs=[selected_image], outputs=[image_gallery, metadata_df])
    delete_all_button.click(fn=delete_all_images, inputs=[], outputs=[image_gallery, metadata_df])

    # Populate the gallery and the metadata table when the page loads.
    demo.load(fn=lambda: (get_image_gallery(), image_metadata.values.tolist()), outputs=[image_gallery, metadata_df])

    # Each example row follows the order of `inputs`.
    gr.Examples(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        examples=[
            [
                "./examples/city.png",
                "hyperrealistic surreal cityscape scene at sunset, buildings",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                13113544138610326000,
                8.5,
                0.481,
                1.0,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/lara.jpeg",
                "photography of lara croft 8k high definition award winning",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                5436236241,
                8.5,
                0.8,
                1.0,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/cybetruck.jpeg",
                "photo of tesla cybertruck futuristic car 8k high definition on a sand dune in mars, future",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                383472451451,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/jesus.png",
                "a photorealistic painting of Jesus Christ, 4k high definition",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                13317204146129588000,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/anna-sullivan-DioLM8ViiO8-unsplash.jpg",
                "A crowded stadium with enthusiastic fans watching a daytime sporting event, the stands filled with colorful attire and the sun casting a warm glow",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                5623124123512,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/img_aef651cb-2919-499d-aa49-6d4e2e21a56e_1024.jpg",
                "a large red flower on a black background 4k high definition",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                23123412341234,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/huggingface.jpg",
                "photo realistic huggingface human emoji costume, round, yellow, (human skin)+++ (human texture)+++",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic, emoji cartoon, drawing, pixelated",
                12312353423,
                15.206,
                0.364,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
        ],
        cache_examples=True,
    )
456
+
457
# One generation at a time, at most 20 requests waiting in the queue.
# `concurrency_count` is no longer accepted by Gradio 4's Blocks.queue();
# `default_concurrency_limit` is its replacement.
demo.queue(default_concurrency_limit=1, max_size=20).launch(debug=True)