Switti / models /pipeline.py
realantonvoronov
remove apply_late_temperature checkbox
b5f551b
import torch
from torchvision.transforms import ToPILImage
from PIL.Image import Image as PILImage
from models.vqvae import VQVAEHF
from models.clip import FrozenCLIPEmbedder
from models.switti import SwittiHF, get_crop_condition
from models.helpers import sample_with_top_k_top_p_, gumbel_softmax_with_rng
class SwittiPipeline:
vae_path = "yresearch/VQVAE-Switti"
text_encoder_path = "openai/clip-vit-large-patch14"
text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def __init__(self, switti, vae, text_encoder, text_encoder_2, device,
dtype=torch.bfloat16,
):
self.switti = switti.to(dtype)
self.vae = vae.to(dtype)
self.text_encoder = text_encoder.to(dtype)
self.text_encoder_2 = text_encoder_2.to(dtype)
self.switti.eval()
self.vae.eval()
self.device = device
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
switti = SwittiHF.from_pretrained(pretrained_model_name_or_path, device=device).to(device)
vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
return cls(switti, vae, text_encoder, text_encoder_2, device)
@staticmethod
def to_image(tensor):
return [ToPILImage()(
(255 * img.cpu().detach()).to(torch.uint8))
for img in tensor]
def _encode_prompt(self, prompt: str | list[str]):
prompt = [prompt] if isinstance(prompt, str) else prompt
encodings = [
self.text_encoder.encode(prompt),
self.text_encoder_2.encode(prompt),
]
prompt_embeds = torch.concat(
[encoding.last_hidden_state for encoding in encodings], dim=-1
)
pooled_prompt_embeds = encodings[-1].pooler_output
attn_bias = encodings[-1].attn_bias
return prompt_embeds, pooled_prompt_embeds, attn_bias
def encode_prompt(
self,
prompt: str | list[str],
null_prompt: str = "",
encode_null: bool = True,
):
prompt_embeds, pooled_prompt_embeds, attn_bias = self._encode_prompt(prompt)
if encode_null:
B, L, hidden_dim = prompt_embeds.shape
pooled_dim = pooled_prompt_embeds.shape[1]
null_embeds, null_pooled_embeds, null_attn_bias = self._encode_prompt(null_prompt)
null_embeds = null_embeds[:, :L].expand(B, L, hidden_dim).to(prompt_embeds.device)
null_pooled_embeds = null_pooled_embeds.expand(B, pooled_dim).to(pooled_prompt_embeds.device)
null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device)
prompt_embeds = torch.cat([prompt_embeds, null_embeds], dim=0)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, null_pooled_embeds], dim=0)
attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0)
return prompt_embeds, pooled_prompt_embeds, attn_bias
@torch.inference_mode()
def __call__(
self,
prompt: str | list[str],
null_prompt: str = "",
seed: int | None = None,
cfg: float = 4.0,
top_k: int = 400,
top_p: float = 0.95,
more_smooth: bool = False,
return_pil: bool = True,
smooth_start_si: int = 0,
turn_off_cfg_start_si: int = 10,
turn_on_cfg_start_si: int = 0,
image_size: tuple[int, int] = (512, 512),
last_scale_temp: float = 1.,
) -> torch.Tensor | list[PILImage]:
"""
only used for inference, on autoregressive mode
:param prompt: text prompt to generate an image
:param null_prompt: negative prompt for CFG
:param seed: random seed
:param cfg: classifier-free guidance ratio
:param top_k: top-k sampling
:param top_p: top-p sampling
:param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
:return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
"""
assert not self.switti.training
switti = self.switti
vae = self.vae
vae_quant = self.vae.quantize
if seed is None:
rng = None
else:
rng = torch.Generator(self.device).manual_seed(seed)
context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)
B = context.shape[0] // 2
cond_vector = switti.text_pooler(cond_vector)
if switti.use_crop_cond:
crop_coords = get_crop_condition(2 * B * [image_size[0]],
2 * B * [image_size[1]],
).to(cond_vector.device)
crop_embed = switti.crop_embed(crop_coords.view(-1)).reshape(2 * B, switti.D)
crop_cond = switti.crop_proj(crop_embed)
else:
crop_cond = None
sos = cond_BD = cond_vector
lvl_pos = switti.lvl_embed(switti.lvl_1L)
if not switti.rope:
lvl_pos += switti.pos_1LC
next_token_map = (
sos.unsqueeze(1)
+ switti.pos_start.expand(2 * B, switti.first_l, -1)
+ lvl_pos[:, : switti.first_l]
)
cur_L = 0
f_hat = sos.new_zeros(B, switti.Cvae, switti.patch_nums[-1], switti.patch_nums[-1])
for b in switti.blocks:
b.attn.kv_caching(switti.use_ar) # Use KV caching if switti is in the AR mode
b.cross_attn.kv_caching(True)
for si, pn in enumerate(switti.patch_nums): # si: i-th segment
ratio = si / switti.num_stages_minus_1
x_BLC = next_token_map
if switti.rope:
freqs_cis = switti.freqs_cis[:, cur_L : cur_L + pn * pn]
else:
freqs_cis = switti.freqs_cis
if si >= turn_off_cfg_start_si:
apply_smooth = False
x_BLC = x_BLC[:B]
context = context[:B]
context_attn_bias = context_attn_bias[:B]
freqs_cis = freqs_cis[:B]
cond_BD = cond_BD[:B]
if crop_cond is not None:
crop_cond = crop_cond[:B]
for b in switti.blocks:
if b.attn.caching and b.attn.cached_k is not None:
b.attn.cached_k = b.attn.cached_k[:B]
b.attn.cached_v = b.attn.cached_v[:B]
if b.cross_attn.caching and b.cross_attn.cached_k is not None:
b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
else:
apply_smooth = more_smooth
for block in switti.blocks:
x_BLC = block(
x=x_BLC,
cond_BD=cond_BD,
attn_bias=None,
context=context,
context_attn_bias=context_attn_bias,
freqs_cis=freqs_cis,
crop_cond=crop_cond,
)
cur_L += pn * pn
logits_BlV = switti.get_logits(x_BLC, cond_BD)
# Guidance
if si < turn_on_cfg_start_si:
# t = 0, i. e. no guidance
logits_BlV = logits_BlV[:B]
elif si >= turn_on_cfg_start_si and si < turn_off_cfg_start_si:
# default const cfg
t = cfg
logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
else:
logits_BlV = logits_BlV / last_scale_temp
if apply_smooth and si >= smooth_start_si:
# not used when evaluating FID/IS/Precision/Recall
gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
idx_Bl = gumbel_softmax_with_rng(
logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng,
)
h_BChw = idx_Bl @ vae_quant.embedding.weight.unsqueeze(0)
else:
# defaul nucleus sampling
idx_Bl = sample_with_top_k_top_p_(
logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1,
)[:, :, 0]
h_BChw = vae_quant.embedding(idx_Bl)
h_BChw = h_BChw.transpose_(1, 2).reshape(B, switti.Cvae, pn, pn)
f_hat, next_token_map = vae_quant.get_next_autoregressive_input(
si, len(switti.patch_nums), f_hat, h_BChw,
)
if si != switti.num_stages_minus_1: # prepare for next stage
next_token_map = next_token_map.view(B, switti.Cvae, -1).transpose(1, 2)
next_token_map = (
switti.word_embed(next_token_map)
+ lvl_pos[:, cur_L : cur_L + switti.patch_nums[si + 1] ** 2]
)
# double the batch sizes due to CFG
next_token_map = next_token_map.repeat(2, 1, 1)
for b in switti.blocks:
b.attn.kv_caching(False)
b.cross_attn.kv_caching(False)
# de-normalize, from [-1, 1] to [0, 1]
img = vae.fhat_to_img(f_hat).add(1).mul(0.5)
if return_pil:
img = self.to_image(img)
return img