KingNish commited on
Commit
59efc5a
·
verified ·
1 Parent(s): 21e68e9

Update custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +53 -147
custom_pipeline.py CHANGED
@@ -1,16 +1,8 @@
1
- import numpy as np
2
  import torch
3
- from diffusers.pipelines.flux.pipeline_output import FluxPipeline, FluxPipelineOutput
4
- from typing import List, Union, Optional, Dict, Any, Callable
5
- from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
6
- from diffusers.utils import is_torch_xla_available
7
-
8
- if is_torch_xla_available():
9
- import torch_xla.core.xla_model as xm
10
-
11
- XLA_AVAILABLE = True
12
- else:
13
- XLA_AVAILABLE = False
14
 
15
  # Constants for shift calculation
16
  BASE_SEQ_LEN = 256
@@ -27,7 +19,7 @@ def calculate_timestep_shift(image_seq_len: int) -> float:
27
  return mu
28
 
29
  def prepare_timesteps(
30
- scheduler,
31
  num_inference_steps: Optional[int] = None,
32
  device: Optional[Union[str, torch.device]] = None,
33
  timesteps: Optional[List[int]] = None,
@@ -49,23 +41,24 @@ def prepare_timesteps(
49
  num_inference_steps = len(timesteps)
50
  return timesteps, num_inference_steps
51
 
52
- # FLUX pipeline with CFG and intermediate outputs
53
  class FluxWithCFGPipeline(FluxPipeline):
54
  """
55
- Flux pipeline with Classifier-Free Guidance and the ability to yield
56
- intermediate images during the denoising process with progressively
57
- increasing resolution for faster generation.
58
  """
 
 
 
 
59
  @torch.inference_mode()
60
- def __call__(
61
  self,
62
  prompt: Union[str, List[str]] = None,
63
  prompt_2: Optional[Union[str, List[str]]] = None,
64
- negative_prompt: Optional[Union[str, List[str]]] = None,
65
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
66
  height: Optional[int] = None,
67
  width: Optional[int] = None,
68
- num_inference_steps: int = 28,
69
  timesteps: List[int] = None,
70
  guidance_scale: float = 3.5,
71
  num_images_per_prompt: Optional[int] = 1,
@@ -73,21 +66,16 @@ class FluxWithCFGPipeline(FluxPipeline):
73
  latents: Optional[torch.FloatTensor] = None,
74
  prompt_embeds: Optional[torch.FloatTensor] = None,
75
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
76
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
77
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
78
  output_type: Optional[str] = "pil",
79
  return_dict: bool = True,
80
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
81
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
82
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
83
- max_sequence_length: int = 512,
84
- yield_intermediates: bool = False, # New parameter for yielding intermediates
85
  ):
86
-
87
  height = height or self.default_sample_size * self.vae_scale_factor
88
  width = width or self.default_sample_size * self.vae_scale_factor
89
 
90
- # 1. Check inputs. Raise error if not correct
91
  self.check_inputs(
92
  prompt,
93
  prompt_2,
@@ -95,7 +83,6 @@ class FluxWithCFGPipeline(FluxPipeline):
95
  width,
96
  prompt_embeds=prompt_embeds,
97
  pooled_prompt_embeds=pooled_prompt_embeds,
98
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
99
  max_sequence_length=max_sequence_length,
100
  )
101
 
@@ -104,23 +91,12 @@ class FluxWithCFGPipeline(FluxPipeline):
104
  self._interrupt = False
105
 
106
  # 2. Define call parameters
107
- if prompt is not None and isinstance(prompt, str):
108
- batch_size = 1
109
- elif prompt is not None and isinstance(prompt, list):
110
- batch_size = len(prompt)
111
- else:
112
- batch_size = prompt_embeds.shape[0]
113
-
114
  device = self._execution_device
115
 
116
- lora_scale = (
117
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
118
- )
119
- (
120
- prompt_embeds,
121
- pooled_prompt_embeds,
122
- text_ids,
123
- ) = self.encode_prompt(
124
  prompt=prompt,
125
  prompt_2=prompt_2,
126
  prompt_embeds=prompt_embeds,
@@ -130,20 +106,6 @@ class FluxWithCFGPipeline(FluxPipeline):
130
  max_sequence_length=max_sequence_length,
131
  lora_scale=lora_scale,
132
  )
133
- (
134
- negative_prompt_embeds,
135
- negative_pooled_prompt_embeds,
136
- negative_text_ids,
137
- ) = self.encode_prompt(
138
- prompt=negative_prompt,
139
- prompt_2=negative_prompt_2,
140
- prompt_embeds=negative_prompt_embeds,
141
- pooled_prompt_embeds=negative_pooled_prompt_embeds,
142
- device=device,
143
- num_images_per_prompt=num_images_per_prompt,
144
- max_sequence_length=max_sequence_length,
145
- lora_scale=lora_scale,
146
- )
147
 
148
  # 4. Prepare latent variables
149
  num_channels_latents = self.transformer.config.in_channels // 4
@@ -170,97 +132,41 @@ class FluxWithCFGPipeline(FluxPipeline):
170
  sigmas,
171
  mu=mu,
172
  )
173
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
174
  self._num_timesteps = len(timesteps)
175
 
176
- # 6. Denoising loop
177
- with self.progress_bar(total=num_inference_steps) as progress_bar:
178
- for i, t in enumerate(timesteps):
179
- if self.interrupt:
180
- continue
181
-
182
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
183
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
184
-
185
- # handle guidance
186
- if self.transformer.config.guidance_embeds:
187
- guidance = torch.tensor([guidance_scale], device=device)
188
- guidance = guidance.expand(latents.shape[0])
189
- else:
190
- guidance = None
191
 
192
- noise_pred_text = self.transformer(
193
- hidden_states=latents,
194
- timestep=timestep / 1000,
195
- guidance=guidance,
196
- pooled_projections=pooled_prompt_embeds,
197
- encoder_hidden_states=prompt_embeds,
198
- txt_ids=text_ids,
199
- img_ids=latent_image_ids,
200
- joint_attention_kwargs=self.joint_attention_kwargs,
201
- return_dict=False,
202
- )[0]
203
-
204
- noise_pred_uncond = self.transformer(
205
- hidden_states=latents,
206
- timestep=timestep / 1000,
207
- guidance=guidance,
208
- pooled_projections=negative_pooled_prompt_embeds,
209
- encoder_hidden_states=negative_prompt_embeds,
210
- txt_ids=negative_text_ids,
211
- img_ids=latent_image_ids,
212
- joint_attention_kwargs=self.joint_attention_kwargs,
213
- return_dict=False,
214
- )[0]
215
-
216
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
217
-
218
- # compute the previous noisy sample x_t -> x_t-1
219
- latents_dtype = latents.dtype
220
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
221
-
222
- if latents.dtype != latents_dtype:
223
- if torch.backends.mps.is_available():
224
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
225
- latents = latents.to(latents_dtype)
226
-
227
- if callback_on_step_end is not None:
228
- callback_kwargs = {}
229
- for k in callback_on_step_end_tensor_inputs:
230
- callback_kwargs[k] = locals()[k]
231
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
232
-
233
- latents = callback_outputs.pop("latents", latents)
234
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
235
-
236
- # call the callback, if provided
237
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
238
- progress_bar.update()
239
-
240
- # Yield intermediate images if requested
241
- if yield_intermediates:
242
- yield self._decode_latents_to_image(latents, height, width, output_type)
243
-
244
- if XLA_AVAILABLE:
245
- xm.mark_step()
246
-
247
- # Final image decoding
248
- if output_type == "latent":
249
- image = latents
250
- else:
251
- image = self._decode_latents_to_image(latents, height, width, output_type)
252
-
253
- # Offload all models
254
- self.maybe_free_model_hooks()
255
-
256
- if not return_dict:
257
- return (image,)
258
-
259
- return FluxPipelineOutput(images=image)
260
-
261
- def _decode_latents_to_image(self, latents, height, width, output_type):
262
  """Decodes the given latents into an image."""
 
263
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
264
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
265
- image = self.vae.decode(latents, return_dict=False)[0]
266
  return self.image_processor.postprocess(image, output_type=output_type)[0]
 
 
1
  import torch
2
+ import numpy as np
3
+ from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
4
+ from typing import Any, Dict, List, Optional, Union
5
+ from PIL import Image
 
 
 
 
 
 
 
6
 
7
  # Constants for shift calculation
8
  BASE_SEQ_LEN = 256
 
19
  return mu
20
 
21
  def prepare_timesteps(
22
+ scheduler: FlowMatchEulerDiscreteScheduler,
23
  num_inference_steps: Optional[int] = None,
24
  device: Optional[Union[str, torch.device]] = None,
25
  timesteps: Optional[List[int]] = None,
 
41
  num_inference_steps = len(timesteps)
42
  return timesteps, num_inference_steps
43
 
44
+ # FLUX pipeline function
45
  class FluxWithCFGPipeline(FluxPipeline):
46
  """
47
+ Extends the FluxPipeline to yield intermediate images during the denoising process
48
+ with progressively increasing resolution for faster generation.
 
49
  """
50
+ def __init__(self, *args, **kwargs):
51
+ super().__init__(*args, **kwargs)
52
+ self.default_sample_size = 512 # Default sample size from the first pipeline
53
+
54
  @torch.inference_mode()
55
+ def generate_images(
56
  self,
57
  prompt: Union[str, List[str]] = None,
58
  prompt_2: Optional[Union[str, List[str]]] = None,
 
 
59
  height: Optional[int] = None,
60
  width: Optional[int] = None,
61
+ num_inference_steps: int = 4,
62
  timesteps: List[int] = None,
63
  guidance_scale: float = 3.5,
64
  num_images_per_prompt: Optional[int] = 1,
 
66
  latents: Optional[torch.FloatTensor] = None,
67
  prompt_embeds: Optional[torch.FloatTensor] = None,
68
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
 
 
69
  output_type: Optional[str] = "pil",
70
  return_dict: bool = True,
71
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
72
+ max_sequence_length: int = 300,
 
 
 
73
  ):
74
+ """Generates images and yields intermediate results during the denoising process."""
75
  height = height or self.default_sample_size * self.vae_scale_factor
76
  width = width or self.default_sample_size * self.vae_scale_factor
77
 
78
+ # 1. Check inputs
79
  self.check_inputs(
80
  prompt,
81
  prompt_2,
 
83
  width,
84
  prompt_embeds=prompt_embeds,
85
  pooled_prompt_embeds=pooled_prompt_embeds,
 
86
  max_sequence_length=max_sequence_length,
87
  )
88
 
 
91
  self._interrupt = False
92
 
93
  # 2. Define call parameters
94
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
 
 
 
 
 
 
95
  device = self._execution_device
96
 
97
+ # 3. Encode prompt
98
+ lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
99
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
 
 
 
 
 
100
  prompt=prompt,
101
  prompt_2=prompt_2,
102
  prompt_embeds=prompt_embeds,
 
106
  max_sequence_length=max_sequence_length,
107
  lora_scale=lora_scale,
108
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  # 4. Prepare latent variables
111
  num_channels_latents = self.transformer.config.in_channels // 4
 
132
  sigmas,
133
  mu=mu,
134
  )
 
135
  self._num_timesteps = len(timesteps)
136
 
137
+ # Handle guidance
138
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ # 6. Denoising loop
141
+ for i, t in enumerate(timesteps):
142
+ if self.interrupt:
143
+ continue
144
+
145
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
146
+
147
+ noise_pred = self.transformer(
148
+ hidden_states=latents,
149
+ timestep=timestep / 1000,
150
+ guidance=guidance,
151
+ pooled_projections=pooled_prompt_embeds,
152
+ encoder_hidden_states=prompt_embeds,
153
+ txt_ids=text_ids,
154
+ img_ids=latent_image_ids,
155
+ joint_attention_kwargs=self.joint_attention_kwargs,
156
+ return_dict=False,
157
+ )[0]
158
+
159
+ # Yield intermediate result
160
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
161
+ torch.cuda.empty_cache()
162
+
163
+ # Final image
164
+ return self._decode_latents_to_image(latents, height, width, output_type)
165
+
166
+ def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  """Decodes the given latents into an image."""
168
+ vae = vae or self.vae
169
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
170
+ latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
171
+ image = vae.decode(latents, return_dict=False)[0]
172
  return self.image_processor.postprocess(image, output_type=output_type)[0]