File size: 3,336 Bytes
2ec2ebd
 
e2d57c7
2ec2ebd
d415ad5
fe9c201
d415ad5
 
 
 
 
 
7eec65f
d415ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
7eec65f
d415ad5
 
 
 
 
 
 
 
 
 
 
 
e2d57c7
d415ad5
740a0ae
c3eb335
d415ad5
7eec65f
d415ad5
 
 
740a0ae
2ec2ebd
e18f8f6
d415ad5
 
 
 
 
d329d58
 
87aaca5
7eec65f
d415ad5
 
 
 
 
 
 
 
7eec65f
d415ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df0a063
 
d415ad5
304f5cd
 
 
 
 
 
 
 
 
 
 
 
 
c3eb335
d415ad5
c3eb335
3e50201
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import random

import gradio as gr
import numpy as np

from diffusion_lens import get_images

# Upper bound for user-selectable seeds: the largest signed 32-bit integer.
MAX_SEED = int(np.iinfo(np.int32).max)

# Description
title = r"""
<h1 align="center">Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</h1>
"""

description = r"""
<b>A demo for the paper <a href='https://arxiv.org/abs/2403.05846' target='_blank'>Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</a>.<br>
"""

article = r"""
---
πŸ“ **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{toker2024diffusion,
  title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines},
  author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan},
  journal={arXiv preprint arXiv:2403.05846},
  year={2024}
}
}```
πŸ“§ **Contact**
<br>
If you have any questions, please feel free to open an issue or directly reach us out at <b>tok@cs.technuin.ac.il</b>.
"""


# Number of layers generate_images iterates over for each model, keyed by the
# model names offered in the UI's model selector.
model_num_of_layers = dict(
    (
        ("Stable Diffusion 1.4", 12),
        ("Stable Diffusion 2.1", 22),
    )
)

def generate_images(prompt, model, seed):
    """Stream one generated image per layer of the selected model.

    For every layer count ``skip_layers`` from ``N - 1`` down to ``0`` (where
    ``N = model_num_of_layers[model]``), call ``get_images`` and append its
    first image to a running gallery. The gallery is yielded after each step
    so the UI can update progressively.

    Args:
        prompt: Text prompt passed to ``get_images``.
        model: Model display name; must be a key of ``model_num_of_layers``.
        seed: Seed for generation. ``-1`` means "draw a random seed in
            ``[0, MAX_SEED]``".

    Yields:
        The list of ``(image, label)`` pairs produced so far. Labels run from
        ``layer_1`` up to ``layer_N``.

    Raises:
        KeyError: If ``model`` is not in ``model_num_of_layers``.
    """
    # The Slider may deliver a float (e.g. 42.0); normalise to int before
    # comparing and passing it on (assumes get_images expects an int seed).
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    print('calling diffusion lens with model:', model, 'and seed:', seed)
    gr.Info('Generating images from intermediate layers..')

    all_images = []  # accumulated (image, label) pairs shown in the gallery
    num_layers = model_num_of_layers[model]
    # Count skip_layers down so the gallery fills from layer_1 to layer_N.
    for skip_layers in reversed(range(num_layers)):
        images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
        # Label relative to this model's own depth. A hard-coded 12 would
        # produce negative labels for models with more layers.
        all_images.append((images[0], f'layer_{num_layers - skip_layers}'))
        yield all_images

# Build the demo UI: header text, a results gallery, the prompt/model/seed
# controls, and a generate button, all wired to generate_images.
with gr.Blocks() as demo:

    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Column():
        # Receives the streamed (image, label) list from generate_images.
        gallery = gr.Gallery(label="Generated Images", columns=4, rows=3, object_fit="contain", height="auto")

    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            value="A photo of Steve Jobs",
        )

    # Derive the choices from model_num_of_layers. Every selectable model then
    # has a layer count, and the two lists cannot drift apart.
    model = gr.Radio(
        list(model_num_of_layers),
        value="Stable Diffusion 1.4",
        label="Model",
    )

    # -1 (the default) tells generate_images to pick a random seed.
    seed = gr.Slider(
        minimum=-1,
        maximum=MAX_SEED,
        value=-1,
        step=1,
        label="Seed Value",
    )

    inputs = [
        prompt,
        model,
        seed,
    ]
    outputs = [gallery]

    generate_button = gr.Button("Generate Image")

    # Regenerate when any of these fire:
    #   - Enter is pressed in the prompt box;
    #   - the button is clicked;
    #   - the user changes the seed or the model.
    # trigger_mode="always_last": while a run is in progress, only the most
    # recent pending trigger is kept and run afterwards.
    gr.on(
        triggers=[
            prompt.submit,
            generate_button.click,
            seed.input,
            model.input,
        ],
        fn=generate_images,
        inputs=inputs,
        outputs=outputs,
        show_progress="full",
        show_api=False,
        trigger_mode="always_last",
    )

    gr.Markdown(article)

# Enable request queueing with direct API access closed, then start the
# server with the API page hidden. Blocks.queue() returns the same Blocks
# instance, so the two calls can be chained.
demo.queue(api_open=False).launch(show_api=False)