Spaces:
Sleeping
Sleeping
import random

import gradio as gr
import numpy as np

from diffusion_lens import get_images
# Upper bound for user-selectable seeds: the largest signed 32-bit integer (2**31 - 1).
MAX_SEED = int(np.iinfo("int32").max)
# Description | |
title = r""" | |
<h1 align="center">Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</h1> | |
""" | |
description = r""" | |
<b>A demo for the paper <a href='https://arxiv.org/abs/2403.05846' target='_blank'>Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</a>.<br> | |
""" | |
article = r""" | |
--- | |
π **Citation** | |
<br> | |
If our work is helpful for your research or applications, please cite us via: | |
``` | |
bibtex | |
@article{toker2024diffusion, | |
title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines}, | |
author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan}, | |
journal={arXiv preprint arXiv:2403.05846}, | |
year={2024} | |
} | |
``` | |
π§ **Abstact** | |
<br> | |
Text-to-image diffusion models (T2I) use a latent representation of a text prompt to guide the image generation process. | |
However, the process by which the encoder produces the text representation is unknown. | |
We propose the Diffusion Lens, a method for analyzing the text encoder of T2I models by generating images from its intermediate representations. | |
Using the Diffusion Lens, we perform an extensive analysis of two recent T2I models. | |
Exploring compound prompts, we find that complex scenes describing multiple objects are composed progressively and more slowly compared to simple scenes; | |
Exploring knowledge retrieval, we find that representation of uncommon concepts requires further computation compared to common concepts, | |
and that knowledge retrieval is gradual across layers. | |
Overall, our findings provide valuable insights into the text encoder component in T2I pipelines. | |
<br> | |
``` | |
π§ **Contact** | |
<br> | |
If you have any questions, please feel free to open an issue or directly reach us out at <b>tok@cs.technion.ac.il | |
</b>. | |
""" | |
model_num_of_layers = { | |
'Stable Diffusion 1.4': 12, | |
'Stable Diffusion 2.1': 22, | |
} | |
def generate_images(prompt, model, seed):
    """Stream one Diffusion Lens image per layer for ``prompt``.

    ``get_images`` is called once per layer, with ``skip_layers`` running from
    ``N - 1`` down to ``0`` (``N = model_num_of_layers[model]``). The first image
    of each call is labelled ``layer_1`` ... ``layer_N`` in that order.

    Args:
        prompt: Text prompt forwarded to ``get_images``.
        model: A key of ``model_num_of_layers``. Other values raise ``KeyError``.
        seed: Seed forwarded to ``get_images``. ``-1`` means "pick a random seed
            in ``[0, MAX_SEED]``".

    Yields:
        The accumulated list of ``(image, label)`` tuples after every layer, so
        the gallery can fill in progressively. The same list object grows
        between yields.
    """
    if seed == -1:
        # randint is inclusive on both ends, matching the slider's range.
        seed = random.randint(0, MAX_SEED)
    print('calling diffusion lens with model:', model, 'and seed:', seed)
    gr.Info('Generating images from intermediate layers..')

    num_layers = model_num_of_layers[model]
    gallery_items = []
    # Largest skip first, so the earliest layer's image is produced first.
    for skip_layers in reversed(range(num_layers)):
        images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
        gallery_items.append((images[0], f'layer_{num_layers - skip_layers}'))
        yield gallery_items
# Demo page: prompt/model/seed controls, a gallery that generate_images fills
# one layer at a time, and the citation/abstract/contact footer.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            value="A photo of Steve Jobs",
        )
        # Offer exactly the models that have a layer count in
        # model_num_of_layers. The two lists cannot drift apart, and a choice
        # can never raise KeyError inside generate_images.
        model = gr.Radio(
            list(model_num_of_layers),
            value="Stable Diffusion 1.4",
            label="Model",
        )
        seed = gr.Slider(
            minimum=0,
            maximum=MAX_SEED,
            value=42,
            step=1,
            label="Seed Value",
        )
        inputs = [
            prompt,
            model,
            seed,
        ]
        generate_button = gr.Button("Generate Image")

    with gr.Column():
        gallery = gr.Gallery(label="Generated Images", columns=6, rows=1, object_fit="contain", height="auto")
        outputs = [gallery]

    # Only the button starts a run. Prompt/seed/model edits deliberately do
    # not, since each run calls get_images once per layer. Per Gradio's event
    # docs, "always_last" allows one more run to be queued while one is pending.
    gr.on(
        triggers=[generate_button.click],
        fn=generate_images,
        inputs=inputs,
        outputs=outputs,
        show_progress="full",
        show_api=False,
        trigger_mode="always_last",
    )

    gr.Markdown(article)
# Configure the request queue (API access disabled) whenever the module is
# loaded. Only start the web server when the file is run as a script
# (e.g. `python app.py`), so importing `demo` has no launch side effect.
demo.queue(api_open=False)

if __name__ == "__main__":
    demo.launch(show_api=False)