import torch from torchvision.transforms import ToPILImage from PIL.Image import Image as PILImage from models.vqvae import VQVAEHF from models.clip import FrozenCLIPEmbedder from models.switti import SwittiHF, get_crop_condition from models.helpers import sample_with_top_k_top_p_, gumbel_softmax_with_rng class SwittiPipeline: vae_path = "yresearch/VQVAE-Switti" text_encoder_path = "openai/clip-vit-large-patch14" text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" def __init__(self, switti, vae, text_encoder, text_encoder_2, device): self.switti = switti self.vae = vae self.text_encoder = text_encoder self.text_encoder_2 = text_encoder_2 self.switti.eval() self.vae.eval() self.device = device @classmethod def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"): switti = SwittiHF.from_pretrained(pretrained_model_name_or_path).to(device) vae = VQVAEHF.from_pretrained(cls.vae_path).to(device) text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device) text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device) return cls(switti, vae, text_encoder, text_encoder_2, device) @staticmethod def to_image(tensor): return [ToPILImage()( (255 * img.cpu().detach()).to(torch.uint8)) for img in tensor] def _encode_prompt(self, prompt: str | list[str]): prompt = [prompt] if isinstance(prompt, str) else prompt encodings = [ self.text_encoder.encode(prompt), self.text_encoder_2.encode(prompt), ] prompt_embeds = torch.concat( [encoding.last_hidden_state for encoding in encodings], dim=-1 ) pooled_prompt_embeds = encodings[-1].pooler_output attn_bias = encodings[-1].attn_bias return prompt_embeds, pooled_prompt_embeds, attn_bias def encode_prompt( self, prompt: str | list[str], null_prompt: str = "", encode_null: bool = True, ): prompt_embeds, pooled_prompt_embeds, attn_bias = self._encode_prompt(prompt) if encode_null: B, L, hidden_dim = prompt_embeds.shape pooled_dim = pooled_prompt_embeds.shape[1] null_embeds, null_pooled_embeds, null_attn_bias = self._encode_prompt(null_prompt) null_embeds = null_embeds[:, :L].expand(B, L, hidden_dim).to(prompt_embeds.device) null_pooled_embeds = null_pooled_embeds.expand(B, pooled_dim).to(pooled_prompt_embeds.device) null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device) prompt_embeds = torch.cat([prompt_embeds, null_embeds], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, null_pooled_embeds], dim=0) attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0) return prompt_embeds, pooled_prompt_embeds, attn_bias @torch.inference_mode() def __call__( self, prompt: str | list[str], null_prompt: str = "", seed: int | None = None, cfg: float = 4.0, top_k: int = 400, top_p: float = 0.95, more_smooth: bool = False, return_pil: bool = True, smooth_start_si: int = 0, turn_off_cfg_start_si: int = 10, image_size: tuple[int, int] = (512, 512), ) -> torch.Tensor | list[PILImage]: """ only used for inference, on autoregressive mode :param prompt: text prompt to generate an image :param null_prompt: negative prompt for CFG :param seed: random seed :param cfg: classifier-free guidance ratio :param top_k: top-k sampling :param top_p: top-p sampling :param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking :return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1] """ assert not self.switti.training switti = self.switti vae = self.vae vae_quant = self.vae.quantize if seed is None: rng = None else: switti.rng.manual_seed(seed) rng = switti.rng context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt) B = context.shape[0] // 2 cond_vector = switti.text_pooler(cond_vector) if switti.use_crop_cond: crop_coords = get_crop_condition(2 * B * [image_size[0]], 2 * B * [image_size[1]], ).to(cond_vector.device) crop_embed = switti.crop_embed(crop_coords.view(-1)).reshape(2 * B, switti.D) crop_cond = switti.crop_proj(crop_embed) else: crop_cond = None sos = cond_BD = cond_vector lvl_pos = switti.lvl_embed(switti.lvl_1L) if not switti.rope: lvl_pos += switti.pos_1LC next_token_map = ( sos.unsqueeze(1) + switti.pos_start.expand(2 * B, switti.first_l, -1) + lvl_pos[:, : switti.first_l] ) cur_L = 0 f_hat = sos.new_zeros(B, switti.Cvae, switti.patch_nums[-1], switti.patch_nums[-1]) for b in switti.blocks: b.attn.kv_caching(switti.use_ar) # Use KV caching if switti is in the AR mode b.cross_attn.kv_caching(True) for si, pn in enumerate(switti.patch_nums): # si: i-th segment ratio = si / switti.num_stages_minus_1 x_BLC = next_token_map if switti.rope: freqs_cis = switti.freqs_cis[:, cur_L : cur_L + pn * pn] else: freqs_cis = switti.freqs_cis if si >= turn_off_cfg_start_si: x_BLC = x_BLC[:B] context = context[:B] context_attn_bias = context_attn_bias[:B] freqs_cis = freqs_cis[:B] cond_BD = cond_BD[:B] if crop_cond is not None: crop_cond = crop_cond[:B] for b in switti.blocks: if b.attn.caching: b.attn.cached_k = b.attn.cached_k[:B] b.attn.cached_v = b.attn.cached_v[:B] if b.cross_attn.caching: b.cross_attn.cached_k = b.cross_attn.cached_k[:B] b.cross_attn.cached_v = b.cross_attn.cached_v[:B] for block in switti.blocks: x_BLC = block( x=x_BLC, cond_BD=cond_BD, attn_bias=None, context=context, context_attn_bias=context_attn_bias, freqs_cis=freqs_cis, crop_cond=crop_cond, ) cur_L += pn * pn logits_BlV = switti.get_logits(x_BLC, cond_BD) # Guidance if si < turn_off_cfg_start_si: t = cfg * ratio logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:] if more_smooth and si >= smooth_start_si: # not used when evaluating FID/IS/Precision/Recall gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git idx_Bl = gumbel_softmax_with_rng( logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng, ) h_BChw = idx_Bl @ vae_quant.embedding.weight.unsqueeze(0) else: # defaul nucleus sampling idx_Bl = sample_with_top_k_top_p_( logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1, )[:, :, 0] h_BChw = vae_quant.embedding(idx_Bl) h_BChw = h_BChw.transpose_(1, 2).reshape(B, switti.Cvae, pn, pn) f_hat, next_token_map = vae_quant.get_next_autoregressive_input( si, len(switti.patch_nums), f_hat, h_BChw, ) if si != switti.num_stages_minus_1: # prepare for next stage next_token_map = next_token_map.view(B, switti.Cvae, -1).transpose(1, 2) next_token_map = ( switti.word_embed(next_token_map) + lvl_pos[:, cur_L : cur_L + switti.patch_nums[si + 1] ** 2] ) # double the batch sizes due to CFG next_token_map = next_token_map.repeat(2, 1, 1) for b in switti.blocks: b.attn.kv_caching(False) b.cross_attn.kv_caching(False) # de-normalize, from [-1, 1] to [0, 1] img = vae.fhat_to_img(f_hat).add(1).mul(0.5) if return_pil: img = self.to_image(img) return img