AlekseyCalvin commited on
Commit
3bc5343
·
verified ·
1 Parent(s): 020882f

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -14
pipeline.py CHANGED
@@ -266,17 +266,17 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
266
  " the batch size of `prompt`."
267
  )
268
 
269
- negative_clip_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
270
  prompt=negative_prompt,
271
  device=device,
272
  num_images_per_prompt=num_images_per_prompt,
273
  )
274
 
275
- t5_negative_prompt_embed, negative_pooled_prompt_2_embed = self._get_t5_prompt_embeds(
276
  prompt=negative_prompt_2,
 
277
  num_images_per_prompt=num_images_per_prompt,
278
  max_sequence_length=max_sequence_length,
279
- device=device,
280
  )
281
 
282
  negative_clip_prompt_embed = torch.nn.functional.pad(
@@ -286,7 +286,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
286
 
287
  negative_prompt_embeds = torch.cat([negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-2)
288
  negative_pooled_prompt_embeds = torch.cat(
289
- [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
290
  )
291
 
292
  if self.text_encoder is not None:
@@ -305,6 +305,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
305
  pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
306
  pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
307
 
 
 
 
308
  return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
309
 
310
  def check_inputs(
@@ -347,16 +350,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
347
  raise ValueError(
348
  "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
349
  )
350
- if negative_prompt is not None and negative_prompt_embeds is not None:
351
- raise ValueError(
352
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
353
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
354
- )
355
- elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
356
- raise ValueError(
357
- f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
358
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
359
- )
360
  if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
361
  raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
362
 
 
266
  " the batch size of `prompt`."
267
  )
268
 
269
+ negative_clip_prompt_embed = self._get_clip_prompt_embeds(
270
  prompt=negative_prompt,
271
  device=device,
272
  num_images_per_prompt=num_images_per_prompt,
273
  )
274
 
275
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
276
  prompt=negative_prompt_2,
277
+ device=device,
278
  num_images_per_prompt=num_images_per_prompt,
279
  max_sequence_length=max_sequence_length,
 
280
  )
281
 
282
  negative_clip_prompt_embed = torch.nn.functional.pad(
 
286
 
287
  negative_prompt_embeds = torch.cat([negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-2)
288
  negative_pooled_prompt_embeds = torch.cat(
289
+ [negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-1
290
  )
291
 
292
  if self.text_encoder is not None:
 
305
  pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
306
  pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
307
 
308
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
309
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
310
+
311
  return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
312
 
313
  def check_inputs(
 
350
  raise ValueError(
351
  "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
352
  )
 
 
 
 
 
 
 
 
 
 
353
  if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
354
  raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
355