KingNish commited on
Commit
03e077d
·
verified ·
1 Parent(s): deb928c

Update custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +4 -21
custom_pipeline.py CHANGED
@@ -42,7 +42,7 @@ def prepare_timesteps(
42
  return timesteps, num_inference_steps
43
 
44
  # FLUX pipeline function
45
- class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
46
  """
47
  Extends the FluxPipeline to yield intermediate images during the denoising process
48
  with progressively increasing resolution for faster generation.
@@ -56,7 +56,6 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
56
  width: Optional[int] = None,
57
  num_inference_steps: int = 4,
58
  timesteps: List[int] = None,
59
- guidance_scale: float = 3.5,
60
  num_images_per_prompt: Optional[int] = 1,
61
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
62
  latents: Optional[torch.FloatTensor] = None,
@@ -64,8 +63,7 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
64
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
65
  output_type: Optional[str] = "pil",
66
  return_dict: bool = True,
67
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
68
- max_sequence_length: int = 300,
69
  ):
70
  """Generates images and yields intermediate results during the denoising process."""
71
  height = height or self.default_sample_size * self.vae_scale_factor
@@ -82,16 +80,10 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
82
  max_sequence_length=max_sequence_length,
83
  )
84
 
85
- self._guidance_scale = guidance_scale
86
- self._joint_attention_kwargs = joint_attention_kwargs
87
- self._interrupt = False
88
-
89
  # 2. Define call parameters
90
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
91
  device = self._execution_device
92
 
93
- # 3. Encode prompt
94
- lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
95
  prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
96
  prompt=prompt,
97
  prompt_2=prompt_2,
@@ -100,7 +92,6 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
100
  device=device,
101
  num_images_per_prompt=num_images_per_prompt,
102
  max_sequence_length=max_sequence_length,
103
- lora_scale=lora_scale,
104
  )
105
  # 4. Prepare latent variables
106
  num_channels_latents = self.transformer.config.in_channels // 4
@@ -128,29 +119,21 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
128
  )
129
  self._num_timesteps = len(timesteps)
130
 
131
- # Handle guidance
132
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
133
-
134
  # 6. Denoising loop
135
  for i, t in enumerate(timesteps):
136
- if self.interrupt:
137
- continue
138
-
139
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
 
141
  noise_pred = self.transformer(
142
  hidden_states=latents,
143
  timestep=timestep / 1000,
144
- guidance=guidance,
145
  pooled_projections=pooled_prompt_embeds,
146
  encoder_hidden_states=prompt_embeds,
147
  txt_ids=text_ids,
148
  img_ids=latent_image_ids,
149
- joint_attention_kwargs=self.joint_attention_kwargs,
150
  return_dict=False,
151
  )[0]
152
 
153
- # Yield intermediate result
154
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
155
  torch.cuda.empty_cache()
156
 
@@ -165,4 +148,4 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
165
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
166
  latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
167
  image = vae.decode(latents, return_dict=False)[0]
168
- return self.image_processor.postprocess(image, output_type=output_type)[0]
 
42
  return timesteps, num_inference_steps
43
 
44
  # FLUX pipeline function
45
+ class HighSpeedFluxPipeline(FluxPipeline):
46
  """
47
  Extends the FluxPipeline to yield intermediate images during the denoising process
48
  with progressively increasing resolution for faster generation.
 
56
  width: Optional[int] = None,
57
  num_inference_steps: int = 4,
58
  timesteps: List[int] = None,
 
59
  num_images_per_prompt: Optional[int] = 1,
60
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
61
  latents: Optional[torch.FloatTensor] = None,
 
63
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
64
  output_type: Optional[str] = "pil",
65
  return_dict: bool = True,
66
+ max_sequence_length: int = 128,
 
67
  ):
68
  """Generates images and yields intermediate results during the denoising process."""
69
  height = height or self.default_sample_size * self.vae_scale_factor
 
80
  max_sequence_length=max_sequence_length,
81
  )
82
 
 
 
 
 
83
  # 2. Define call parameters
84
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
85
  device = self._execution_device
86
 
 
 
87
  prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
88
  prompt=prompt,
89
  prompt_2=prompt_2,
 
92
  device=device,
93
  num_images_per_prompt=num_images_per_prompt,
94
  max_sequence_length=max_sequence_length,
 
95
  )
96
  # 4. Prepare latent variables
97
  num_channels_latents = self.transformer.config.in_channels // 4
 
119
  )
120
  self._num_timesteps = len(timesteps)
121
 
 
 
 
122
  # 6. Denoising loop
123
  for i, t in enumerate(timesteps):
124
+
 
 
125
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
126
 
127
  noise_pred = self.transformer(
128
  hidden_states=latents,
129
  timestep=timestep / 1000,
 
130
  pooled_projections=pooled_prompt_embeds,
131
  encoder_hidden_states=prompt_embeds,
132
  txt_ids=text_ids,
133
  img_ids=latent_image_ids,
 
134
  return_dict=False,
135
  )[0]
136
 
 
137
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
138
  torch.cuda.empty_cache()
139
 
 
148
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
149
  latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
150
  image = vae.decode(latents, return_dict=False)[0]
151
+ return self.image_processor.postprocess(image, output_type=output_type)[0]