Spaces:
Sleeping
Sleeping
File size: 3,864 Bytes
aaead26 7fb327a aaead26 d61123d 7fb327a aaead26 7fb327a aaead26 d2d3f28 aaead26 7fb327a fa31c55 7fb327a aaead26 7fb327a aaead26 d61123d 7fb327a d2d3f28 7fb327a d2d3f28 aaead26 7fb327a aaead26 7fb327a aaead26 7fb327a d2d3f28 acd07e4 fa31c55 aaead26 fa31c55 7fb327a fa31c55 7fb327a aaead26 acd07e4 7fb327a fa31c55 7fb327a d2d3f28 7fb327a aaead26 7fb327a aaead26 7fb327a aaead26 d2d3f28 aaead26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import numpy as np
import random
from huggingface_hub import hf_hub_download
import spaces # [uncomment to use ZeroGPU]
from diffusers import FluxPipeline
import torch
# Run on the GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # Replace with the model you would like to use
torch_dtype = torch.bfloat16

# The unpruned reference pipeline is intentionally not loaded in this Space.
# pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# pipe = pipe.to(device)

# Load the base FLUX pipeline, then swap in the pruned transformer.
pruned_pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# The downloaded .pkl is presumably a whole pickled transformer module (it is
# assigned directly to `pipe.transformer`), not a state dict. Loading such an
# object requires weights_only=False; torch>=2.6 defaults to weights_only=True
# and would refuse to unpickle it.
# SECURITY NOTE(review): unpickling executes arbitrary code, so this must only
# ever point at a repository you trust.
pruned_pipe.transformer = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
    map_location="cpu",
    weights_only=False,
)
pruned_pipe = pruned_pipe.to(device)
# Seed upper bound (largest signed 32-bit integer) and an image-size cap.
MAX_SEED, MAX_IMAGE_SIZE = np.iinfo("int32").max, 1024
@spaces.GPU
def generate_images(prompt, seed, steps):
    """Generate one image for *prompt* with the pruned FLUX pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the torch random generator. It is cast to ``int``
            because ``manual_seed`` needs an integer, and UI components may
            deliver numbers as floats.
        steps: Number of denoising steps (``num_inference_steps``), cast to
            ``int`` for the same reason.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # Create the generator on the same device the pipeline was moved to.
    # A hard-coded "cuda" generator fails on CPU-only hosts, and it also
    # mismatches a CPU pipeline.
    generator = torch.Generator(device).manual_seed(int(seed))
    result = pruned_pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=int(steps),
    )
    return result.images[0]
# Sample prompts offered in the UI's examples gallery.
examples = list(
    (
        "A clock tower floating in a sea of clouds",
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        "An astronaut riding a green horse",
        "A delicious ceviche cheesecake slice",
        (
            "A sprawling cyberpunk metropolis at night, with towering skyscrapers "
            "emitting neon lights of every color, holographic billboards "
            "advertising alien languages"
        ),
    )
)
# CSS that centers an element with elem_id="col-container" and caps its width.
# NOTE(review): this string is never passed to gr.Blocks(css=...) below, and no
# component in this file sets elem_id="col-container", so it currently has no effect.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Markdown rendered at the top of the page via gr.Markdown(header).
header = """
# 🌱 EcoDiff Pruned FLUX-Schnell (20% Pruning Ratio)
We are not able to host two FLUX models in the same space, so we only show the pruned model here. **👉 [Click here to compare with the Original FLUX Model](https://huggingface.co/spaces/black-forest-labs/FLUX.1-schnell)**.
"""
# HTML badge row (paper + model links) rendered under the header via gr.HTML.
# The badge label previously read "ariXv"; fixed to "arXiv".
header_2 = """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""
# Build the UI: header, prompt/seed/steps inputs, example prompts, and a single
# output image for the pruned model. Both the button click and pressing Enter
# in the prompt box trigger generation.
with gr.Blocks() as demo:
    gr.Markdown(header)
    gr.HTML(header_2)
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        # Bound the seed so torch.Generator.manual_seed never receives an
        # out-of-range value; MAX_SEED is the module-level 32-bit cap.
        seed = gr.Number(
            label="Seed",
            value=44,
            precision=0,
            minimum=0,
            maximum=MAX_SEED,
            scale=1,
        )
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=5,
            step=1,
            scale=1,
        )
    generate_btn = gr.Button("Generate Images")
    # Reuse the shared module-level prompt list instead of duplicating it here.
    gr.Examples(
        examples=examples,
        inputs=[prompt],
    )
    with gr.Row():
        # Only the pruned model is shown; the original-model output is disabled.
        # original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[
            prompt,
            seed,
            steps,
        ],
        outputs=[ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()
|