"""Gradio demo for Diffusion Lens.

Renders, for a given prompt, one image per text-encoder layer of the selected
Stable Diffusion model and streams them into a gallery as they are produced.
"""

import random

import gradio as gr
import numpy as np

from diffusion_lens import get_images

# Largest seed the UI offers; -1 on the slider means "pick a random seed".
MAX_SEED = np.iinfo(np.int32).max

# Description
title = r"""

Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines

"""

description = r"""
A demo for the paper Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines.
"""

article = r"""
---
📝 **Citation**
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{toker2024diffusion,
  title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines},
  author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan},
  journal={arXiv preprint arXiv:2403.05846},
  year={2024}
}
```
📧 **Contact**
If you have any questions, please feel free to open an issue or directly reach us out at tok@cs.technion.ac.il.
"""

# Number of text-encoder layers to visualise, keyed by the label shown in the
# "Model" radio button.
model_num_of_layers = {
    'Stable Diffusion 1.4': 12,
    'Stable Diffusion 2.1': 22,
}


def generate_images(prompt, model, seed):
    """Stream one image per text-encoder layer for ``prompt``.

    This is a generator: after every layer it yields the full list produced
    so far, so the gallery fills up progressively.

    Args:
        prompt: Text prompt forwarded to ``get_images``.
        model: A key of ``model_num_of_layers`` selecting the pipeline.
        seed: Seed forwarded to ``get_images``; ``-1`` requests a random seed
            in ``[0, MAX_SEED]``.

    Yields:
        A list of ``(image, caption)`` tuples, where the caption is
        ``layer_<k>`` with ``k`` counting up from 1 to the model's layer
        count.

    Raises:
        KeyError: If ``model`` is not a key of ``model_num_of_layers``.
    """
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    print('calling diffusion lens with model:', model, 'and seed:', seed)
    gr.Info('Generating images from intermediate layers..')

    num_layers = model_num_of_layers[model]
    all_images = []
    # Go from the most layers skipped to none skipped, so captions count up.
    for skip_layers in range(num_layers - 1, -1, -1):
        images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
        # Caption relative to the selected model's depth (not a fixed 12), so
        # models with a different layer count are labelled correctly.
        all_images.append((images[0], f'layer_{num_layers - skip_layers}'))
        yield all_images


with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Column():
        gallery = gr.Gallery(
            label="Generated Images",
            columns=4,
            rows=3,
            object_fit="contain",
            height="auto",
        )

    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            value="A photo of Steve Jobs",
        )
        model = gr.Radio(
            [
                "Stable Diffusion 1.4",
                "Stable Diffusion 2.1",
            ],
            value="Stable Diffusion 1.4",
            label="Model",
        )
        seed = gr.Slider(
            minimum=-1,
            maximum=MAX_SEED,
            value=-1,
            step=1,
            label="Seed Value",
        )
        generate_button = gr.Button("Generate Image")

    inputs = [
        prompt,
        model,
        seed,
    ]
    outputs = [gallery]

    # Regenerate on explicit submit/click and whenever the seed or model is
    # changed by the user.
    gr.on(
        triggers=[
            prompt.submit,
            generate_button.click,
            seed.input,
            model.input,
        ],
        fn=generate_images,
        inputs=inputs,
        outputs=outputs,
        show_progress="full",
        show_api=False,
        trigger_mode="always_last",
    )

    gr.Markdown(article)

demo.queue(api_open=False)
demo.launch(show_api=False)