import random

import gradio as gr
import numpy as np

from diffusion_lens import get_images
# Largest value representable by NumPy's int32; used in this file as the
# upper bound of the seed slider and of randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Markdown heading rendered at the top of the demo page.
title = r"""
Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines
"""
# Short markdown introduction rendered directly below the title.
description = r"""
A demo for the paper Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines.
Visit our project webpage for more information.
"""
# Markdown footer rendered at the bottom of the demo page: citation (BibTeX),
# the paper abstract and contact details.
article = r"""
---
📝 **Citation**
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{toker2024diffusion,
title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines},
author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan},
journal={arXiv preprint arXiv:2403.05846},
year={2024}
}
```
📧 **Abstract**
```
Text-to-image diffusion models (T2I) use a latent representation of a text prompt to guide the image generation process.
However, the process by which the encoder produces the text representation is unknown.
We propose the Diffusion Lens, a method for analyzing the text encoder of T2I models by generating images from its intermediate representations.
Using the Diffusion Lens, we perform an extensive analysis of two recent T2I models.
Exploring compound prompts, we find that complex scenes describing multiple objects are composed progressively and more slowly compared to simple scenes;
Exploring knowledge retrieval, we find that representation of uncommon concepts requires further computation compared to common concepts,
and that knowledge retrieval is gradual across layers.
Overall, our findings provide valuable insights into the text encoder component in T2I pipelines.
```
📧 **Contact**
```
If you have any questions, please feel free to open an issue or reach out to us directly at tok@cs.technion.ac.il
```
.
"""
# Layer count used for each selectable model; the keys must match the
# choices offered by the "Model" radio button.
model_num_of_layers = {
    "Stable Diffusion 1.4": 12,
    "Stable Diffusion 2.1": 22,
}
# def run_for_examples(prompt, model, seed, skip):
# return generate_images(prompt, model, seed, skip);
def generate_images(prompt, model, seed, skip):
    """Yield a growing gallery of images, one per sampled encoder depth.

    This is a generator so the UI can show each image as soon as it is
    produced instead of waiting for the whole sweep.

    Args:
        prompt: Text prompt forwarded to ``get_images``.
        model: Display name of the model; must be a key of
            ``model_num_of_layers``.
        seed: Seed forwarded to ``get_images``. ``-1`` means "draw a random
            seed in ``[0, MAX_SEED]``".
        skip: Step between consecutive ``skip_layers`` values; coerced to
            ``int`` because it is used as a ``range`` step.

    Yields:
        The list of ``(image, caption)`` tuples accumulated so far. Captions
        are ``layer_<k>``, starting at ``layer_1`` and advancing by ``skip``.

    Raises:
        KeyError: If ``model`` is not in ``model_num_of_layers``.
        ValueError: If ``skip`` is zero (invalid ``range`` step).
    """
    seed = random.randint(0, MAX_SEED) if seed == -1 else seed
    print('calling diffusion lens with model:', model, 'and seed:', seed)
    gr.Info('Generating images from intermediate layers..')

    all_images = []  # (image, caption) pairs collected so far
    max_num_of_layers = model_num_of_layers[model]
    # UI sliders may deliver a float; range() only accepts integers.
    step = int(skip)

    # Count down from (max - 1) to 0 so captions run upward from layer_1.
    # NOTE(review): presumably skip_layers is the number of final encoder
    # layers get_images leaves out - confirm against diffusion_lens.
    for skip_layers in range(max_num_of_layers - 1, -1, -step):
        images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
        all_images.append((images[0], f'layer_{max_num_of_layers - skip_layers}'))
        # Emit the partial gallery after every layer for progressive display.
        yield all_images

    gr.Info('Image generation complete')
# Build the demo page: header text, input controls, output gallery, example
# prompts and the footer article, then start the app.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Input controls.
    with gr.Column():
        prompt = gr.Textbox(label="Prompt", value="A photo of Steve Jobs")
        model = gr.Radio(
            ["Stable Diffusion 1.4", "Stable Diffusion 2.1"],
            value="Stable Diffusion 1.4",
            label="Model",
        )
        seed = gr.Slider(
            minimum=0, maximum=MAX_SEED, value=42, step=1, label="Seed Value"
        )
        skip = gr.Slider(
            minimum=1,
            maximum=6,
            value=3,
            step=1,
            label="# Layers to Skip Between Generations",
        )
        inputs = [prompt, model, seed, skip]
        generate_button = gr.Button("Generate Image")

    # Output area.
    with gr.Column():
        gallery = gr.Gallery(
            label="Generated Images",
            columns=6,
            rows=1,
            object_fit="contain",
            height="auto",
        )
        outputs = [gallery]

    # Generation is started only by the button click.
    gr.on(
        triggers=[generate_button.click],
        fn=generate_images,
        inputs=inputs,
        outputs=outputs,
        show_progress="full",
        show_api=False,
        trigger_mode="always_last",
    )

    # Every example uses the same model, seed and skip; only the prompt varies.
    example_prompts = (
        "A photo of an Aye-aye.",
        "A photo of an Beagle.",
        "A green cat and a blue dog.",
    )
    examples = [[text, "Stable Diffusion 2.1", 42, 1] for text in example_prompts]
    gr.Examples(
        examples=examples,
        inputs=inputs,
        fn=generate_images,
        outputs=outputs,
        cache_examples=True,
    )

    gr.Markdown(article)

# Enable the request queue with its API closed, then serve the app.
demo.queue(api_open=False)
demo.launch(show_api=False)