Jordan Legg committed
Commit 242b4ef • 1 Parent(s): 3d05f5b

mixed precision

Files changed (1)
app.py +45 -163
app.py CHANGED
@@ -5,181 +5,63 @@ import spaces
import torch
from diffusers import FluxPipeline

-# Check for CUDA and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using device: {device}")

# Load the model
-pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
-pipe = pipe.to(device)

-# Convert text encoders to full precision
-pipe.text_encoder = pipe.text_encoder.to(torch.float32)
-if hasattr(pipe, 'text_encoder_2'):
-    pipe.text_encoder_2 = pipe.text_encoder_2.to(torch.float32)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
-    try:
-        if randomize_seed:
-            seed = random.randint(0, MAX_SEED)
-        generator = torch.Generator(device=device).manual_seed(seed)
-
-        # Use full precision for text encoding
-        with torch.no_grad():
-            text_inputs = pipe.tokenizer(prompt, return_tensors="pt").to(device)
-            text_embeddings = pipe.text_encoder(text_inputs.input_ids)[0]
-
-        # Use mixed precision for the rest of the pipeline
-        with torch.autocast(device_type=device, dtype=torch.float16):
-            image = pipe(
-                prompt_embeds=text_embeddings,
-                width=width,
-                height=height,
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                guidance_scale=0.0
-            ).images[0]
-
-        return image, seed
-    except Exception as e:
-        print(f"Error during inference: {e}")
-        return None, seed
-
-examples = [
-    "a tiny astronaut hatching from an egg on the moon",
-    "a cat holding a sign that says hello world",
-    "an anime illustration of a wiener schnitzel",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 720px;
-}
-.container {
-    margin: 0 auto;
-    padding: 20px;
-    border-radius: 10px;
-    background-color: #f0f0f0;
-}
-.title {
-    text-align: center;
-    color: #2c3e50;
-    margin-bottom: 20px;
-}
-.subtitle {
-    text-align: center;
-    color: #34495e;
-    margin-bottom: 30px;
-}
-.speed-info {
-    background-color: #e74c3c;
-    color: white;
-    padding: 10px;
-    border-radius: 5px;
-    text-align: center;
-    margin-bottom: 20px;
-}
-.prompt-container {
-    display: flex;
-    gap: 10px;
-    margin-bottom: 20px;
-}
-.advanced-settings {
-    background-color: #ecf0f1;
-    padding: 15px;
-    border-radius: 5px;
-    margin-top: 20px;
-}
-"""

-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.HTML(
-            """
-            <div class="container">
-                <h1 class="title">FLUX.1 [schnell] - Mixed Precision Edition</h1>
-                <h3 class="subtitle">12B param rectified flow transformer optimized for maximum inference speed</h3>
-                <div class="speed-info">
-                    <strong>Mixed Precision Pipeline:</strong> FP32 Text Encoders + FP16 Core for optimal speed and quality
-                </div>
-            </div>
-            """
-        )
-
-        with gr.Column(elem_id="prompt-container"):
-            prompt = gr.Text(
-                label="Enter your prompt",
-                placeholder="A futuristic cityscape with flying cars",
-                lines=2
-            )
-            run_button = gr.Button("Generate Image", variant="primary")
-
-        result = gr.Image(label="Generated Image")
-
-        with gr.Accordion("Advanced Settings", open=False):
-            with gr.Column(elem_id="advanced-settings"):
-                seed = gr.Slider(
-                    label="Seed",
-                    minimum=0,
-                    maximum=MAX_SEED,
-                    step=1,
-                    value=0,
-                    info="Set to 0 for random seed"
-                )
-                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-                with gr.Row():
-                    width = gr.Slider(
-                        label="Width",
-                        minimum=256,
-                        maximum=MAX_IMAGE_SIZE,
-                        step=32,
-                        value=1024,
-                    )
-                    height = gr.Slider(
-                        label="Height",
-                        minimum=256,
-                        maximum=MAX_IMAGE_SIZE,
-                        step=32,
-                        value=1024,
-                    )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=4,
-                    info="Lower values = faster generation, higher values = potentially better quality"
-                )
-
-        gr.Markdown(
-            """
-            ### About FLUX.1 [schnell]
-            - Distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
-            - Optimized for 4-step generation
-            - Mixed precision pipeline for maximum speed
-
-            [[Blog]](https://blackforestlabs.ai/announcing-black-forest-labs/) | [[Model]](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
-            """
-        )
-
-        gr.Examples(
-            examples=examples,
-            fn=infer,
-            inputs=[prompt],
-            outputs=[result, seed],
-            cache_examples="lazy"
-        )

-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed]
    )
 
import torch
from diffusers import FluxPipeline

+# Enable cuDNN benchmarking for potential performance improvement
+torch.backends.cudnn.benchmark = True
+
+# Set up device and data types
device = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.float16

# Load the model
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    torch_dtype=torch.bfloat16,
+)

+# Configure the pipeline
+pipe.enable_sequential_cpu_offload()
+pipe.vae.enable_tiling()
+pipe = pipe.to(DTYPE)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    image = pipe(
+        prompt,
+        num_inference_steps=num_inference_steps,
+        num_images_per_prompt=1,
+        guidance_scale=0.0,
+        height=height,
+        width=width,
+        generator=generator,
+    ).images[0]
+
+    return image, seed

+# Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# FLUX.1 [schnell] Image Generator")
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt")
+            run_button = gr.Button("Generate")
+        with gr.Column():
+            result = gr.Image(label="Generated Image")
+    with gr.Accordion("Advanced Settings", open=False):
+        seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, label="Seed", randomize=True)
+        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+        width = gr.Slider(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, label="Width")
+        height = gr.Slider(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, label="Height")
+        num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, value=4, label="Number of inference steps")

+    run_button.click(
+        infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed]
    )
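
For local verification outside the Space, a minimal smoke-test sketch follows. It is not part of the commit: the prompt, seed, and output filename are illustrative, and it only assumes a diffusers version that exposes FluxPipeline, enable_sequential_cpu_offload(), and vae.enable_tiling() as used in app.py above. It mirrors the new loading and memory-saving configuration but omits the Gradio UI and the final float16 cast.

# smoke_test.py -- hypothetical sketch, not part of this commit.
# Runs the bfloat16 FLUX.1 [schnell] setup once without Gradio.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,  # same load dtype as the new app.py
)
# Reduce peak VRAM as in app.py: offload submodules to CPU between forward
# passes and decode latents in tiles.
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()

# A fixed CPU generator keeps the run reproducible regardless of offloading.
generator = torch.Generator("cpu").manual_seed(42)
image = pipe(
    "a cat holding a sign that says hello world",  # example prompt from the previous app.py
    num_inference_steps=4,   # schnell is distilled for few-step sampling
    guidance_scale=0.0,      # schnell is used without classifier-free guidance
    height=1024,
    width=1024,
    generator=generator,
).images[0]
image.save("flux_schnell_smoke_test.png")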