import torch
import torch.nn as nn
import os
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)
from typing import Any, Callable, Dict, List, Optional, Union


class FluxTextEncoderWithMask(nn.Module):
    """Frozen CLIP + T5 text-encoder pair that also returns the T5 attention mask.

    Components are loaded from sub-folders of ``model_path``:
    ``tokenizer`` / ``text_encoder`` (CLIP) and ``tokenizer_2`` /
    ``text_encoder_2`` (T5). All parameters are frozen at construction.

    Calling the module returns ``(prompt_embeds, prompt_attention_mask,
    pooled_prompt_embeds)``: T5 last-layer token embeddings, the matching
    T5 tokenizer attention mask, and the CLIP pooled output.
    """

    def __init__(self, model_path, torch_dtype):
        super().__init__()
        # CLIP branch: only its pooled output is used (see _get_clip_prompt_embeds).
        # Tokenizers hold no tensors, so no dtype is passed to them.
        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if self.tokenizer is not None else 77
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            os.path.join(model_path, "text_encoder"), torch_dtype=torch_dtype
        )

        # T5 branch: per-token sequence embeddings plus attention mask.
        self.tokenizer_2 = T5TokenizerFast.from_pretrained(os.path.join(model_path, "tokenizer_2"))
        self.text_encoder_2 = T5EncoderModel.from_pretrained(
            os.path.join(model_path, "text_encoder_2"), torch_dtype=torch_dtype
        )

        self._freeze()

    def _freeze(self):
        """Disable gradients for every parameter of both encoders."""
        # NOTE(review): this only stops gradient updates. If a parent module later
        # calls .train(), the encoders switch to train mode too, and any dropout in
        # them would become active — confirm callers keep this module in eval().
        for param in self.parameters():
            param.requires_grad = False

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 128,
        device: Optional[torch.device] = None,
    ):
        """Encode ``prompt`` with T5.

        Args:
            prompt: A single prompt or a list of prompts.
            num_images_per_prompt: How many times each prompt's outputs are duplicated.
            max_sequence_length: Pad/truncate length for the T5 tokenizer.
            device: Device the returned tensors are placed on.

        Returns:
            ``(prompt_embeds, prompt_attention_mask)``. The embeddings have shape
            ``(batch * num_images_per_prompt, max_sequence_length, hidden)`` in
            ``text_encoder_2.dtype``. The mask has shape
            ``(batch * num_images_per_prompt, max_sequence_length)``. Both use the
            same row order: all copies of prompt 0, then all copies of prompt 1, ...
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = self.text_encoder_2(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=False
        )[0]
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # Duplicate outputs for each generation per prompt (mps-friendly: repeat + view).
        # Embeddings and mask must be duplicated the same way so that row i of the
        # mask belongs to row i of the embeddings. `mask.repeat(n, 1)` would tile
        # (p0, p1, p0, p1) while the embeddings are (p0, p0, p1, p1), mismatching
        # whenever batch_size > 1 and num_images_per_prompt > 1.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds, prompt_attention_mask

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        """Encode ``prompt`` with CLIP and return its pooled output.

        Args:
            prompt: A single prompt or a list of prompts.
            num_images_per_prompt: How many times each prompt's embedding is duplicated.
            device: Device the returned tensor is placed on.

        Returns:
            Pooled embeddings of shape ``(batch * num_images_per_prompt, hidden)``
            in ``text_encoder.dtype``. Rows are ordered with all copies of prompt 0
            first, then all copies of prompt 1, and so on.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # Duplicate text embeddings for each generation per prompt (mps-friendly).
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    def encode_prompt(self, prompt, num_images_per_prompt=1, device=None, max_sequence_length=128):
        """Encode ``prompt`` with both encoders.

        Args:
            prompt: A single prompt or a list of prompts.
            num_images_per_prompt: How many times each prompt's outputs are duplicated.
            device: Device the returned tensors are placed on.
            max_sequence_length: T5 pad/truncate length (default 128).

        Returns:
            ``(prompt_embeds, prompt_attention_mask, pooled_prompt_embeds)``.
            The first two come from T5 and the last one from CLIP.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt

        pooled_prompt_embeds = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
        )
        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds

    def forward(self, input_prompts, device, num_images_per_prompt=1):
        """Encode ``input_prompts`` without tracking gradients.

        Returns ``(prompt_embeds, prompt_attention_mask, pooled_prompt_embeds)``;
        see :meth:`encode_prompt`.
        """
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(
                input_prompts, num_images_per_prompt, device=device
            )

        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds