AlekseyCalvin
commited on
Update pipeline.py
Browse files- pipeline.py +7 -14
pipeline.py
CHANGED
@@ -266,17 +266,17 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
266 |
" the batch size of `prompt`."
|
267 |
)
|
268 |
|
269 |
-
negative_clip_prompt_embed
|
270 |
prompt=negative_prompt,
|
271 |
device=device,
|
272 |
num_images_per_prompt=num_images_per_prompt,
|
273 |
)
|
274 |
|
275 |
-
t5_negative_prompt_embed
|
276 |
prompt=negative_prompt_2,
|
|
|
277 |
num_images_per_prompt=num_images_per_prompt,
|
278 |
max_sequence_length=max_sequence_length,
|
279 |
-
device=device,
|
280 |
)
|
281 |
|
282 |
negative_clip_prompt_embed = torch.nn.functional.pad(
|
@@ -286,7 +286,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
286 |
|
287 |
negative_prompt_embeds = torch.cat([negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-2)
|
288 |
negative_pooled_prompt_embeds = torch.cat(
|
289 |
-
[
|
290 |
)
|
291 |
|
292 |
if self.text_encoder is not None:
|
@@ -305,6 +305,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
305 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
306 |
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
307 |
|
|
|
|
|
|
|
308 |
return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
|
309 |
|
310 |
def check_inputs(
|
@@ -347,16 +350,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
347 |
raise ValueError(
|
348 |
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
349 |
)
|
350 |
-
if negative_prompt is not None and negative_prompt_embeds is not None:
|
351 |
-
raise ValueError(
|
352 |
-
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
353 |
-
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
354 |
-
)
|
355 |
-
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
356 |
-
raise ValueError(
|
357 |
-
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
358 |
-
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
359 |
-
)
|
360 |
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
361 |
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
362 |
|
|
|
266 |
" the batch size of `prompt`."
|
267 |
)
|
268 |
|
269 |
+
negative_clip_prompt_embed = self._get_clip_prompt_embeds(
|
270 |
prompt=negative_prompt,
|
271 |
device=device,
|
272 |
num_images_per_prompt=num_images_per_prompt,
|
273 |
)
|
274 |
|
275 |
+
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
276 |
prompt=negative_prompt_2,
|
277 |
+
device=device,
|
278 |
num_images_per_prompt=num_images_per_prompt,
|
279 |
max_sequence_length=max_sequence_length,
|
|
|
280 |
)
|
281 |
|
282 |
negative_clip_prompt_embed = torch.nn.functional.pad(
|
|
|
286 |
|
287 |
negative_prompt_embeds = torch.cat([negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-2)
|
288 |
negative_pooled_prompt_embeds = torch.cat(
|
289 |
+
[negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-1
|
290 |
)
|
291 |
|
292 |
if self.text_encoder is not None:
|
|
|
305 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
306 |
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
307 |
|
308 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
309 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
310 |
+
|
311 |
return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
|
312 |
|
313 |
def check_inputs(
|
|
|
350 |
raise ValueError(
|
351 |
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
352 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
354 |
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
355 |
|