File size: 3,864 Bytes
aaead26
 
 
7fb327a
aaead26
d61123d
7fb327a
aaead26
 
 
7fb327a
 
aaead26
d2d3f28
 
aaead26
7fb327a
fa31c55
7fb327a
 
 
 
 
aaead26
 
7fb327a
 
aaead26
 
d61123d
7fb327a
 
d2d3f28
 
7fb327a
 
d2d3f28
aaead26
 
 
7fb327a
aaead26
 
 
7fb327a
aaead26
 
 
 
 
 
 
 
 
7fb327a
 
d2d3f28
acd07e4
fa31c55
aaead26
fa31c55
7fb327a
fa31c55
 
7fb327a
 
aaead26
acd07e4
7fb327a
fa31c55
7fb327a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2d3f28
7fb327a
aaead26
7fb327a
 
aaead26
 
 
7fb327a
aaead26
d2d3f28
aaead26
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import numpy as np
import random
from huggingface_hub import hf_hub_download

import spaces  # [uncomment to use ZeroGPU]
from diffusers import FluxPipeline
import torch

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # Replace to the model you would like to use
torch_dtype = torch.bfloat16

# Only the pruned pipeline is instantiated here; the unpruned baseline is
# intentionally not loaded in this file.

# Load the stock pipeline first, then swap its transformer for the pruned one.
pruned_pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# SECURITY: the checkpoint is assigned directly as the transformer, so it is
# presumably a fully pickled module rather than a state dict. Unpickling it
# executes code shipped in the Hub repo -- only load from a repo you trust.
# `weights_only=False` is spelled out because newer PyTorch releases default
# to `weights_only=True`, which refuses to unpickle arbitrary objects.
pruned_pipe.transformer = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
    map_location="cpu",
    weights_only=False,
)
pruned_pipe = pruned_pipe.to(device)


# Largest value representable as a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU
def generate_images(prompt, seed, steps):
    """Generate one image from ``prompt`` with the pruned FLUX pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the random generator. ``None`` (e.g. an emptied
            number field) selects a random seed in ``[0, MAX_SEED]``.
        steps: Number of inference steps; coerced to ``int``.

    Returns:
        The first image of the pipeline output.
    """
    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # Create the generator on the same device the pipeline was moved to,
    # instead of hard-coding "cuda", so the CPU fallback keeps working.
    generator = torch.Generator(device).manual_seed(int(seed))
    result = pruned_pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=int(steps),
    )
    return result.images[0]


# Sample prompts for the demo.
examples = [
    "A clock tower floating in a sea of clouds",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
    "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
]

# Stylesheet snippet: horizontally centers the element with id
# "col-container" and caps its width at 640px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Markdown shown at the top of the page: title plus a link to the original model's Space.
header = """
# 🌱 EcoDiff Pruned FLUX-Schnell (20% Pruning Ratio)

We are not able to host two FLUX models in the same space, so we only show the pruned model here. **👉 [Click here to compare with the Original FLUX Model](https://huggingface.co/spaces/black-forest-labs/FLUX.1-schnell)**.
"""

# HTML badge row linking to the paper and the pruned-model repository.
# The badge label is spelled "arXiv" so it matches the image's alt text.
header_2 = """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
    <a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
    <a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""

# Build the demo page: header, input controls, example prompts, output image.
with gr.Blocks() as demo:
    gr.Markdown(header)
    gr.HTML(header_2)

    # Input controls share one row; `scale` sets their relative widths.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=5,
            step=1,
            scale=1,
        )
        generate_btn = gr.Button("Generate Images")

    # Reuse the module-level prompt list rather than repeating it inline,
    # so the examples are maintained in a single place.
    gr.Examples(
        examples=examples,
        inputs=[prompt],
    )

    with gr.Row():
        ecodiff_output = gr.Image(label="EcoDiff Output")

    # Run generation on a button click or when the prompt is submitted.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[
            prompt,
            seed,
            steps,
        ],
        outputs=[ecodiff_output],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()