AlekseyCalvin commited on
Commit
f1385c1
·
verified ·
1 Parent(s): 9f7998f

Upload open_flux_pipeline.py

Browse files
Files changed (1) hide show
  1. open_flux_pipeline.py +222 -0
open_flux_pipeline.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from diffusers.pipelines.flux.pipeline_output import FluxPipeline, FluxPipelineOutput
4
+ from typing import List, Union, Optional, Dict, Any, Callable
5
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
6
+
7
+ from diffusers.utils import is_torch_xla_available
8
+
9
+ if is_torch_xla_available():
10
+ import torch_xla.core.xla_model as xm
11
+
12
+ XLA_AVAILABLE = True
13
+ else:
14
+ XLA_AVAILABLE = False
15
+
16
+ # TODO this is rough. Need to properly stack unconditional or make it optional
17
+ class FluxWithCFGPipeline(FluxPipeline):
18
+ def __call__(
19
+ self,
20
+ prompt: Union[str, List[str]] = None,
21
+ prompt_2: Optional[Union[str, List[str]]] = None,
22
+ negative_prompt: Optional[Union[str, List[str]]] = None,
23
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
24
+ height: Optional[int] = None,
25
+ width: Optional[int] = None,
26
+ num_inference_steps: int = 28,
27
+ timesteps: List[int] = None,
28
+ guidance_scale: float = 7.0,
29
+ num_images_per_prompt: Optional[int] = 1,
30
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
31
+ latents: Optional[torch.FloatTensor] = None,
32
+ prompt_embeds: Optional[torch.FloatTensor] = None,
33
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
34
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
35
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
36
+ output_type: Optional[str] = "pil",
37
+ return_dict: bool = True,
38
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
39
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
40
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
41
+ max_sequence_length: int = 512,
42
+ ):
43
+
44
+ height = height or self.default_sample_size * self.vae_scale_factor
45
+ width = width or self.default_sample_size * self.vae_scale_factor
46
+
47
+ # 1. Check inputs. Raise error if not correct
48
+ self.check_inputs(
49
+ prompt,
50
+ prompt_2,
51
+ height,
52
+ width,
53
+ prompt_embeds=prompt_embeds,
54
+ pooled_prompt_embeds=pooled_prompt_embeds,
55
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
56
+ max_sequence_length=max_sequence_length,
57
+ )
58
+
59
+ self._guidance_scale = guidance_scale
60
+ self._joint_attention_kwargs = joint_attention_kwargs
61
+ self._interrupt = False
62
+
63
+ # 2. Define call parameters
64
+ if prompt is not None and isinstance(prompt, str):
65
+ batch_size = 1
66
+ elif prompt is not None and isinstance(prompt, list):
67
+ batch_size = len(prompt)
68
+ else:
69
+ batch_size = prompt_embeds.shape[0]
70
+
71
+ device = self._execution_device
72
+
73
+ lora_scale = (
74
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
75
+ )
76
+ (
77
+ prompt_embeds,
78
+ pooled_prompt_embeds,
79
+ text_ids,
80
+ ) = self.encode_prompt(
81
+ prompt=prompt,
82
+ prompt_2=prompt_2,
83
+ prompt_embeds=prompt_embeds,
84
+ pooled_prompt_embeds=pooled_prompt_embeds,
85
+ device=device,
86
+ num_images_per_prompt=num_images_per_prompt,
87
+ max_sequence_length=max_sequence_length,
88
+ lora_scale=lora_scale,
89
+ )
90
+ (
91
+ negative_prompt_embeds,
92
+ negative_pooled_prompt_embeds,
93
+ negative_text_ids,
94
+ ) = self.encode_prompt(
95
+ prompt=negative_prompt,
96
+ prompt_2=negative_prompt_2,
97
+ prompt_embeds=negative_prompt_embeds,
98
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
99
+ device=device,
100
+ num_images_per_prompt=num_images_per_prompt,
101
+ max_sequence_length=max_sequence_length,
102
+ lora_scale=lora_scale,
103
+ )
104
+
105
+ # 4. Prepare latent variables
106
+ num_channels_latents = self.transformer.config.in_channels // 4
107
+ latents, latent_image_ids = self.prepare_latents(
108
+ batch_size * num_images_per_prompt,
109
+ num_channels_latents,
110
+ height,
111
+ width,
112
+ prompt_embeds.dtype,
113
+ device,
114
+ generator,
115
+ latents,
116
+ )
117
+
118
+ # 5. Prepare timesteps
119
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
120
+ image_seq_len = latents.shape[1]
121
+ mu = calculate_shift(
122
+ image_seq_len,
123
+ self.scheduler.config.base_image_seq_len,
124
+ self.scheduler.config.max_image_seq_len,
125
+ self.scheduler.config.base_shift,
126
+ self.scheduler.config.max_shift,
127
+ )
128
+ timesteps, num_inference_steps = retrieve_timesteps(
129
+ self.scheduler,
130
+ num_inference_steps,
131
+ device,
132
+ timesteps,
133
+ sigmas,
134
+ mu=mu,
135
+ )
136
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
137
+ self._num_timesteps = len(timesteps)
138
+
139
+ # 6. Denoising loop
140
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
141
+ for i, t in enumerate(timesteps):
142
+ if self.interrupt:
143
+ continue
144
+
145
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
146
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
147
+
148
+ # handle guidance
149
+ if self.transformer.config.guidance_embeds:
150
+ guidance = torch.tensor([guidance_scale], device=device)
151
+ guidance = guidance.expand(latents.shape[0])
152
+ else:
153
+ guidance = None
154
+
155
+ noise_pred_text = self.transformer(
156
+ hidden_states=latents,
157
+ timestep=timestep / 1000,
158
+ guidance=guidance,
159
+ pooled_projections=pooled_prompt_embeds,
160
+ encoder_hidden_states=prompt_embeds,
161
+ txt_ids=text_ids,
162
+ img_ids=latent_image_ids,
163
+ joint_attention_kwargs=self.joint_attention_kwargs,
164
+ return_dict=False,
165
+ )[0]
166
+
167
+ # todo combine these
168
+ noise_pred_uncond = self.transformer(
169
+ hidden_states=latents,
170
+ timestep=timestep / 1000,
171
+ guidance=guidance,
172
+ pooled_projections=negative_pooled_prompt_embeds,
173
+ encoder_hidden_states=negative_prompt_embeds,
174
+ txt_ids=negative_text_ids,
175
+ img_ids=latent_image_ids,
176
+ joint_attention_kwargs=self.joint_attention_kwargs,
177
+ return_dict=False,
178
+ )[0]
179
+
180
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
181
+
182
+ # compute the previous noisy sample x_t -> x_t-1
183
+ latents_dtype = latents.dtype
184
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
185
+
186
+ if latents.dtype != latents_dtype:
187
+ if torch.backends.mps.is_available():
188
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
189
+ latents = latents.to(latents_dtype)
190
+
191
+ if callback_on_step_end is not None:
192
+ callback_kwargs = {}
193
+ for k in callback_on_step_end_tensor_inputs:
194
+ callback_kwargs[k] = locals()[k]
195
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
196
+
197
+ latents = callback_outputs.pop("latents", latents)
198
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
199
+
200
+ # call the callback, if provided
201
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
202
+ progress_bar.update()
203
+
204
+ if XLA_AVAILABLE:
205
+ xm.mark_step()
206
+
207
+ if output_type == "latent":
208
+ image = latents
209
+
210
+ else:
211
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
212
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
213
+ image = self.vae.decode(latents, return_dict=False)[0]
214
+ image = self.image_processor.postprocess(image, output_type=output_type)
215
+
216
+ # Offload all models
217
+ self.maybe_free_model_hooks()
218
+
219
+ if not return_dict:
220
+ return (image,)
221
+
222
+ return FluxPipelineOutput(images=image)