# Hugging Face Spaces app (Space status when this page was captured: Sleeping).
import random

import gradio as gr
import numpy as np

from diffusion_lens import get_images
# Largest value a signed 32-bit integer can hold; used as the upper bound
# both for the seed slider and for randomly drawn seeds.
MAX_SEED = int(np.iinfo(np.int32).max)
# Static page text, rendered through gr.Markdown in the UI definition.

# Page heading.
title = r"""
<h1 align="center">Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</h1>
"""

# One-line blurb linking to the paper.
description = r"""
<b>Based on the paper <a href='https://arxiv.org/abs/2403.05846' target='_blank'>Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</a>.</b><br>
"""

# Footer with the citation entry and contact details.
article = r"""
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{toker2024diffusion,
  title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines},
  author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan},
  journal={arXiv preprint arXiv:2403.05846},
  year={2024}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue or reach out to us directly at <b>tok@cs.technion.ac.il</b>.
"""
# Upper bound of the skip_layers sweep for each model, keyed by the label
# shown in the "Model" radio group.
# NOTE(review): presumably the text-encoder depth of each pipeline - confirm
# against diffusion_lens.get_images.
model_num_of_layers = {
    "Stable Diffusion 1.4": 12,
    "Stable Diffusion 2.1": 22,
}
def generate_images(prompt, model, seed):
    """Yield a growing gallery of images, one per intermediate layer.

    Calls ``get_images`` once for every ``skip_layers`` value from the
    model's layer count down to 0 and, after each call, yields the list of
    ``(image, caption)`` pairs collected so far so the gallery can update
    incrementally.

    Args:
        prompt: Text prompt forwarded to ``get_images``.
        model: Key of ``model_num_of_layers`` selecting the pipeline.
        seed: Seed forwarded to ``get_images``; ``-1`` means "draw a random
            seed in ``[0, MAX_SEED]``".

    Yields:
        The accumulated list of ``(image, "layer_<k>")`` tuples, where ``k``
        is ``model_num_of_layers[model] - skip_layers``.

    Raises:
        KeyError: If ``model`` is not a key of ``model_num_of_layers``.
    """
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    print('calling diffusion lens with model:', model, 'and seed:', seed)
    gr.Info('Generating images from intermediate layers..')

    max_num_of_layers = model_num_of_layers[model]
    all_images = []
    for skip_layers in range(max_num_of_layers, -1, -1):
        images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
        # Caption by layer index relative to the *selected* model's depth, so
        # the labels run 0..max for every model rather than only for 12 layers.
        layer_index = max_num_of_layers - skip_layers
        all_images.append((images[0], f'layer_{layer_index}'))
        yield all_images
# UI definition: heading, results gallery, input controls, and the event
# wiring that runs generate_images. The unused legacy "Select Model" dropdown
# and "Enter Seed" number box were dropped: they were rendered but never
# passed to any event handler.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    gallery = gr.Gallery(
        label="Generated Images",
        columns=6,
        rows=2,
        object_fit="contain",
        height="auto",
    )

    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            value="a cat, masterpiece, best quality, high quality",
        )
        # Choices must stay in sync with the keys of model_num_of_layers,
        # which generate_images indexes with the selected value.
        model = gr.Radio(
            [
                "Stable Diffusion 1.4",
                "Stable Diffusion 2.1",
            ],
            value="Stable Diffusion 1.4",
            label="Model",
        )
        # -1 is the sentinel generate_images replaces with a random seed.
        seed = gr.Slider(
            minimum=-1,
            maximum=MAX_SEED,
            value=-1,
            step=1,
            label="Seed Value",
        )
        generate_button = gr.Button("Generate Image")

    inputs = [
        prompt,
        model,
        seed,
    ]
    outputs = [gallery]

    # Submitting the prompt, clicking the button, or changing the seed/model
    # all run the same handler.
    gr.on(
        triggers=[
            prompt.submit,
            generate_button.click,
            seed.input,
            model.input,
        ],
        fn=generate_images,
        inputs=inputs,
        outputs=outputs,
        show_progress="full",
        show_api=False,
        trigger_mode="always_last",
    )

    gr.Markdown(article)
# Enable the request queue (with its API closed) and start the server without
# exposing the API page.
demo.queue(api_open=False).launch(show_api=False)