"""Gradio demo: SDXL img2img guided by the MistoLine ControlNet.

The app has three tabs:

* "Generate Images" runs the pipeline on an uploaded image.
* "Gallery and Voting" shows every generated image and lets users vote on
  or delete the selected one.
* "Metadata and Management" shows the in-memory metadata table.

Generated images are written to the current working directory. Vote counts
are persisted to ``likes_cache.json``; the metadata table lives in memory
only and is rebuilt as images are generated.
"""

import base64
import html
import json
import os
import time
from datetime import datetime

import gradio as gr
import numpy as np
import pandas as pd
import torch
from compel import Compel, ReturnedEmbeddingsType
from controlnet_aux import AnylineDetector
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from gradio_imageslider import ImageSlider
from PIL import Image

# Configuration
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = os.environ.get("SPACE_ID", None) is not None

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"

print(f"device: {device}")
print(f"dtype: {dtype}")
print(f"low memory: {LOW_MEMORY}")

# Model initialization
model = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
controlnet = ControlNetModel.from_pretrained(
    "TheMistoAI/MistoLine",
    torch_dtype=torch.float16,
    revision="refs/pr/3",
    variant="fp16",
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    model,
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)

compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
pipe = pipe.to(device)

anyline = AnylineDetector.from_pretrained(
    "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
).to(device)

# Global variables for metadata and likes cache
METADATA_COLUMNS = ["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"]
VOTE_TYPES = ("likes", "dislikes", "hearts")

image_metadata = pd.DataFrame(columns=METADATA_COLUMNS)
LIKES_CACHE_FILE = "likes_cache.json"


def load_likes_cache():
    """Load the persisted vote counts.

    Returns:
        The dict stored in ``LIKES_CACHE_FILE``, or an empty dict when the
        file is missing, is not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(LIKES_CACHE_FILE, "r") as f:
            cache = json.load(f)
    except FileNotFoundError:
        return {}
    except json.JSONDecodeError as err:
        # A corrupt cache must not prevent the app from starting.
        print(f"Ignoring unreadable {LIKES_CACHE_FILE}: {err}")
        return {}
    return cache if isinstance(cache, dict) else {}


def save_likes_cache(cache):
    """Write ``cache`` to ``LIKES_CACHE_FILE`` as JSON."""
    with open(LIKES_CACHE_FILE, "w") as f:
        json.dump(cache, f)


likes_cache = load_likes_cache()


def metadata_rows():
    """Return the metadata table as a list of rows for ``gr.Dataframe``."""
    return image_metadata.values.tolist()


def pad_image(image):
    """Center ``image`` on a square canvas whose side is its longer edge.

    A square input is returned unchanged. The padding uses fill value ``0``,
    which is accepted by every PIL mode (a 3-tuple is only valid for 3-band
    modes).
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    canvas = Image.new(image.mode, (side, side), 0)
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas


def create_download_link(filename):
    """Return an HTML anchor that downloads ``filename`` via a data URI.

    The file content is embedded base64-encoded, so the link works without
    the file being served separately.
    """
    with open(filename, "rb") as file:
        encoded_string = base64.b64encode(file.read()).decode("utf-8")
    # The name is derived from the user's prompt, so escape it for HTML.
    download_name = html.escape(os.path.basename(filename), quote=True)
    return (
        f'<a href="data:image/png;base64,{encoded_string}" '
        f'download="{download_name}">Download Image</a>'
    )


def save_image(image: Image.Image, prompt: str) -> str:
    """Save ``image`` as a PNG and register it in the metadata and vote cache.

    The filename is ``<timestamp>_<sanitized prompt>.png`` where the prompt
    part keeps only alphanumerics and single spaces, truncated to 50 chars.

    Returns:
        The filename the image was written to.
    """
    global image_metadata
    now = datetime.now()
    kept = "".join(ch for ch in prompt if ch.isalnum() or ch.isspace())
    # Collapse newlines/tabs/runs of spaces so the name stays on one line.
    safe_prompt = " ".join(kept.split())[:50]
    filename = f"{now.strftime('%Y%m%d_%H%M%S')}_{safe_prompt}.png"
    image.save(filename)

    new_row = pd.DataFrame(
        {
            "Filename": [filename],
            "Prompt": [prompt],
            "Likes": [0],
            "Dislikes": [0],
            "Hearts": [0],
            # Stored as text so the table can be sent to the browser as-is.
            "Created": [now.isoformat(timespec="seconds")],
        }
    )
    image_metadata = pd.concat([image_metadata, new_row], ignore_index=True)

    likes_cache[filename] = {"likes": 0, "dislikes": 0, "hearts": 0}
    save_likes_cache(likes_cache)
    return filename


def get_image_gallery():
    """Return ``(filename, caption)`` pairs for every tracked image on disk."""
    return [
        (file, get_image_caption(file))
        for file in image_metadata["Filename"].tolist()
        if os.path.exists(file)
    ]


def get_image_caption(filename):
    """Build the gallery caption (name, prompt, vote counts) for ``filename``.

    Falls back to the bare filename when no votes are recorded for it.
    """
    votes = likes_cache.get(filename)
    if votes is None:
        return filename
    prompts = image_metadata.loc[image_metadata["Filename"] == filename, "Prompt"]
    # The vote cache is persisted but the metadata is not, so a cached file
    # may have no metadata row.
    prompt = prompts.iloc[0] if not prompts.empty else ""
    likes = votes.get("likes", 0)
    dislikes = votes.get("dislikes", 0)
    hearts = votes.get("hearts", 0)
    return f"{filename}\nPrompt: {prompt}\n👍 {likes} 👎 {dislikes} ❤️ {hearts}"


def delete_all_images():
    """Delete every tracked image file and reset metadata and votes.

    Returns:
        ``(gallery_items, metadata_rows)`` for refreshing the UI.
    """
    global image_metadata
    for file in image_metadata["Filename"]:
        if os.path.exists(file):
            os.remove(file)
    image_metadata = pd.DataFrame(columns=METADATA_COLUMNS)
    likes_cache.clear()
    save_likes_cache(likes_cache)
    return get_image_gallery(), metadata_rows()


def delete_image(filename):
    """Delete one tracked image and its metadata/vote entries.

    Files that are not listed in the metadata table are never removed, so a
    stray value cannot delete arbitrary files.

    Returns:
        ``(gallery_items, metadata_rows)`` for refreshing the UI.
    """
    global image_metadata
    tracked = bool(filename) and (image_metadata["Filename"] == filename).any()
    if tracked:
        if os.path.exists(filename):
            os.remove(filename)
        image_metadata = image_metadata[
            image_metadata["Filename"] != filename
        ].reset_index(drop=True)
    if filename in likes_cache:
        del likes_cache[filename]
        save_likes_cache(likes_cache)
    return get_image_gallery(), metadata_rows()


def vote(filename, vote_type):
    """Add one vote of ``vote_type`` to ``filename``.

    Args:
        filename: Image to vote on; ignored when it has no vote entry
            (including ``None`` when nothing is selected).
        vote_type: One of ``"likes"``, ``"dislikes"``, ``"hearts"``
            (case-insensitive).

    Returns:
        ``(gallery_items, metadata_rows)`` for refreshing the UI.

    Raises:
        ValueError: If ``vote_type`` is not a known vote type.
    """
    key = vote_type.lower()
    if key not in VOTE_TYPES:
        raise ValueError(f"Unknown vote type: {vote_type!r}")
    if filename in likes_cache:
        likes_cache[filename][key] = likes_cache[filename].get(key, 0) + 1
        save_likes_cache(likes_cache)
        # Keep the metadata table in step with the cache; column names are
        # the capitalized vote types.
        image_metadata.loc[
            image_metadata["Filename"] == filename, key.capitalize()
        ] = likes_cache[filename][key]
    return get_image_gallery(), metadata_rows()


def select_image(evt: gr.SelectData):
    """Return the filename of the gallery item the user clicked.

    Gradio passes the event payload only to parameters annotated with
    ``gr.SelectData``. The index is mapped back through the gallery listing
    so the stored value is the on-disk filename used by vote/delete.
    """
    gallery = get_image_gallery()
    index = evt.index
    if isinstance(index, int) and 0 <= index < len(gallery):
        return gallery[index][0]
    return None


def refresh_views():
    """Return ``(gallery_items, metadata_rows)`` for the initial page load."""
    return get_image_gallery(), metadata_rows()


# NOTE: the previous `@gr.on(queue_pred_done=True)` decorator was removed;
# `gr.on` has no such parameter and `predict` is wired up explicitly below.
def predict(
    input_image,
    prompt,
    negative_prompt,
    seed,
    guidance_scale=8.5,
    controlnet_conditioning_scale=0.5,
    strength=1.0,
    controlnet_start=0.0,
    controlnet_end=1.0,
    guassian_sigma=2.0,
    intensity_threshold=3,
    progress=gr.Progress(track_tqdm=True),
):
    """Run Anyline edge detection and the SDXL ControlNet img2img pipeline.

    Args:
        input_image: PIL image uploaded by the user.
        prompt: Positive prompt (Compel syntax is accepted).
        negative_prompt: Negative prompt.
        seed: Seed for the torch random generator.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: Weight of the ControlNet conditioning.
        strength: img2img strength.
        controlnet_start: Fraction of the steps at which ControlNet starts.
        controlnet_end: Fraction of the steps at which ControlNet stops.
        guassian_sigma: Anyline gaussian sigma (clamped to at least 0.01).
        intensity_threshold: Anyline intensity threshold.
        progress: Gradio progress tracker.

    Returns:
        ``((padded, generated), padded, anyline_image, download_html,
        gallery_items, metadata_rows)`` matching the UI ``outputs`` list.

    Raises:
        gr.Error: If no image was uploaded or the ControlNet range is empty.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    control_start = float(controlnet_start)
    control_end = float(controlnet_end)
    if control_start >= control_end:
        # The pipeline rejects an empty guidance window; report it clearly.
        raise gr.Error("ControlNet Start must be smaller than ControlNet End.")

    # Convert before padding so the canvas fill and resize operate on RGB.
    padded_image = pad_image(input_image.convert("RGB")).resize((1024, 1024))
    conditioning, pooled = compel([prompt, negative_prompt])
    generator = torch.manual_seed(int(seed))
    last_time = time.time()
    anyline_image = anyline(
        padded_image,
        detect_resolution=1280,
        guassian_sigma=max(0.01, guassian_sigma),
        intensity_threshold=intensity_threshold,
    )

    images = pipe(
        image=padded_image,
        control_image=anyline_image,
        strength=strength,
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=1024,
        height=1024,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        # diffusers names these control_guidance_start / control_guidance_end.
        control_guidance_start=control_start,
        control_guidance_end=control_end,
        generator=generator,
        num_inference_steps=30,
        guidance_scale=guidance_scale,
        eta=1.0,
    )
    print(f"Time taken: {time.time() - last_time}")

    generated_image = images.images[0]
    filename = save_image(generated_image, prompt)
    download_link = create_download_link(filename)

    return (
        (padded_image, generated_image),
        padded_image,
        anyline_image,
        download_link,
        get_image_gallery(),
        metadata_rows(),
    )


css = """
#intro {
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
.gradio-container {max-width: 1200px !important}
footer {visibility: hidden}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🎨 ArtForge: MistoLine ControlNet Masterpiece Gallery

Create, curate, and compete with AI-enhanced images using MistoLine ControlNet. Join our creative multiplayer experience! 🖼️🏆✨

This demo showcases the capabilities of [TheMistoAI/MistoLine](https://huggingface.co/TheMistoAI/MistoLine) ControlNet with SDXL.

- SDXL Controlnet: [TheMistoAI/MistoLine](https://huggingface.co/TheMistoAI/MistoLine)
- [Anyline with Controlnet Aux](https://github.com/huggingface/controlnet_aux)
- For upscaling, see [Enhance This Demo](https://huggingface.co/spaces/radames/Enhance-This-HiDiffusion-SDXL)
        """,
        elem_id="intro",
    )

    with gr.Tab("Generate Images"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Input Image")
                prompt = gr.Textbox(
                    label="Prompt",
                    info=(
                        "The prompt is very important to get the desired results. "
                        "Please try to describe the image as best as you can. "
                        "Accepts Compel Syntax"
                    ),
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=2**64 - 1,
                    value=1415926535897932,
                    step=1,
                    label="Seed",
                    randomize=True,
                )
                with gr.Accordion(label="Advanced", open=False):
                    guidance_scale = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=8.5,
                        step=0.001,
                        label="Guidance Scale",
                    )
                    controlnet_conditioning_scale = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=0.5,
                        label="ControlNet Conditioning Scale",
                    )
                    strength = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=1,
                        label="Strength",
                    )
                    controlnet_start = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=0.0,
                        label="ControlNet Start",
                    )
                    controlnet_end = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.001,
                        value=1.0,
                        label="ControlNet End",
                    )
                    guassian_sigma = gr.Slider(
                        minimum=0.01,
                        maximum=10.0,
                        step=0.1,
                        value=2.0,
                        label="(Anyline) Guassian Sigma",
                    )
                    intensity_threshold = gr.Slider(
                        minimum=0,
                        maximum=255,
                        step=1,
                        value=3,
                        label="(Anyline) Intensity Threshold",
                    )
                btn = gr.Button("Generate")
            with gr.Column(scale=2):
                with gr.Group():
                    image_slider = ImageSlider(position=0.5)
                with gr.Row():
                    padded_image = gr.Image(type="pil", label="Padded Image")
                    anyline_image = gr.Image(type="pil", label="Anyline Image")
                download_link = gr.HTML(label="Download Generated Image")

    with gr.Tab("Gallery and Voting"):
        image_gallery = gr.Gallery(
            label="Generated Images", show_label=True, columns=4, height="auto"
        )
        with gr.Row():
            like_button = gr.Button("👍 Like")
            dislike_button = gr.Button("👎 Dislike")
            heart_button = gr.Button("❤️ Heart")
        delete_image_button = gr.Button("🗑️ Delete Selected Image")
        # Filename of the gallery item the user last clicked.
        selected_image = gr.State(None)

    with gr.Tab("Metadata and Management"):
        metadata_df = gr.Dataframe(
            label="Image Metadata",
            headers=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"],
            interactive=False,
        )
        delete_all_button = gr.Button("🗑️ Delete All Images")

    inputs = [
        image_input,
        prompt,
        negative_prompt,
        seed,
        guidance_scale,
        controlnet_conditioning_scale,
        strength,
        controlnet_start,
        controlnet_end,
        guassian_sigma,
        intensity_threshold,
    ]
    outputs = [
        image_slider,
        padded_image,
        anyline_image,
        download_link,
        image_gallery,
        metadata_df,
    ]

    btn.click(fn=predict, inputs=inputs, outputs=outputs)

    image_gallery.select(fn=select_image, inputs=None, outputs=[selected_image])

    like_button.click(
        fn=lambda x: vote(x, "likes"),
        inputs=[selected_image],
        outputs=[image_gallery, metadata_df],
    )
    dislike_button.click(
        fn=lambda x: vote(x, "dislikes"),
        inputs=[selected_image],
        outputs=[image_gallery, metadata_df],
    )
    heart_button.click(
        fn=lambda x: vote(x, "hearts"),
        inputs=[selected_image],
        outputs=[image_gallery, metadata_df],
    )
    delete_image_button.click(
        fn=delete_image,
        inputs=[selected_image],
        outputs=[image_gallery, metadata_df],
    )
    delete_all_button.click(
        fn=delete_all_images,
        inputs=[],
        outputs=[image_gallery, metadata_df],
    )

    demo.load(fn=refresh_views, outputs=[image_gallery, metadata_df])

    gr.Examples(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        examples=[
            [
                "./examples/city.png",
                "hyperrealistic surreal cityscape scene at sunset, buildings",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                13113544138610326000,
                8.5,
                0.481,
                1.0,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/lara.jpeg",
                "photography of lara croft 8k high definition award winning",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                5436236241,
                8.5,
                0.8,
                1.0,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/cybetruck.jpeg",
                "photo of tesla cybertruck futuristic car 8k high definition on a sand dune in mars, future",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                383472451451,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/jesus.png",
                "a photorealistic painting of Jesus Christ, 4k high definition",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                13317204146129588000,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/anna-sullivan-DioLM8ViiO8-unsplash.jpg",
                "A crowded stadium with enthusiastic fans watching a daytime sporting event, the stands filled with colorful attire and the sun casting a warm glow",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                5623124123512,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/img_aef651cb-2919-499d-aa49-6d4e2e21a56e_1024.jpg",
                "a large red flower on a black background 4k high definition",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                23123412341234,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/huggingface.jpg",
                "photo realistic huggingface human emoji costume, round, yellow, (human skin)+++ (human texture)+++",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic, emoji cartoon, drawing, pixelated",
                12312353423,
                15.206,
                0.364,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
        ],
        cache_examples=True,
    )

demo.queue(concurrency_count=1, max_size=20).launch(debug=True)