Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torchvision.transforms import ToPILImage | |
from PIL.Image import Image as PILImage | |
from models.vqvae import VQVAEHF | |
from models.clip import FrozenCLIPEmbedder | |
from models.switti import SwittiHF, get_crop_condition | |
from models.helpers import sample_with_top_k_top_p_, gumbel_softmax_with_rng | |
class SwittiPipeline: | |
vae_path = "yresearch/VQVAE-Switti" | |
text_encoder_path = "openai/clip-vit-large-patch14" | |
text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" | |
def __init__(self, switti, vae, text_encoder, text_encoder_2, device): | |
self.switti = switti | |
self.vae = vae | |
self.text_encoder = text_encoder | |
self.text_encoder_2 = text_encoder_2 | |
self.switti.eval() | |
self.vae.eval() | |
self.device = device | |
def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"): | |
switti = SwittiHF.from_pretrained(pretrained_model_name_or_path).to(device) | |
vae = VQVAEHF.from_pretrained(cls.vae_path).to(device) | |
text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device) | |
text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device) | |
return cls(switti, vae, text_encoder, text_encoder_2, device) | |
def to_image(tensor): | |
return [ToPILImage()( | |
(255 * img.cpu().detach()).to(torch.uint8)) | |
for img in tensor] | |
def _encode_prompt(self, prompt: str | list[str]): | |
prompt = [prompt] if isinstance(prompt, str) else prompt | |
encodings = [ | |
self.text_encoder.encode(prompt), | |
self.text_encoder_2.encode(prompt), | |
] | |
prompt_embeds = torch.concat( | |
[encoding.last_hidden_state for encoding in encodings], dim=-1 | |
) | |
pooled_prompt_embeds = encodings[-1].pooler_output | |
attn_bias = encodings[-1].attn_bias | |
return prompt_embeds, pooled_prompt_embeds, attn_bias | |
def encode_prompt( | |
self, | |
prompt: str | list[str], | |
null_prompt: str = "", | |
encode_null: bool = True, | |
): | |
prompt_embeds, pooled_prompt_embeds, attn_bias = self._encode_prompt(prompt) | |
if encode_null: | |
B, L, hidden_dim = prompt_embeds.shape | |
pooled_dim = pooled_prompt_embeds.shape[1] | |
null_embeds, null_pooled_embeds, null_attn_bias = self._encode_prompt(null_prompt) | |
null_embeds = null_embeds[:, :L].expand(B, L, hidden_dim).to(prompt_embeds.device) | |
null_pooled_embeds = null_pooled_embeds.expand(B, pooled_dim).to(pooled_prompt_embeds.device) | |
null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device) | |
prompt_embeds = torch.cat([prompt_embeds, null_embeds], dim=0) | |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, null_pooled_embeds], dim=0) | |
attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0) | |
return prompt_embeds, pooled_prompt_embeds, attn_bias | |
def __call__( | |
self, | |
prompt: str | list[str], | |
null_prompt: str = "", | |
seed: int | None = None, | |
cfg: float = 4.0, | |
top_k: int = 400, | |
top_p: float = 0.95, | |
more_smooth: bool = False, | |
return_pil: bool = True, | |
smooth_start_si: int = 0, | |
turn_off_cfg_start_si: int = 10, | |
image_size: tuple[int, int] = (512, 512), | |
) -> torch.Tensor | list[PILImage]: | |
""" | |
only used for inference, on autoregressive mode | |
:param prompt: text prompt to generate an image | |
:param null_prompt: negative prompt for CFG | |
:param seed: random seed | |
:param cfg: classifier-free guidance ratio | |
:param top_k: top-k sampling | |
:param top_p: top-p sampling | |
:param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking | |
:return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1] | |
""" | |
assert not self.switti.training | |
switti = self.switti | |
vae = self.vae | |
vae_quant = self.vae.quantize | |
if seed is None: | |
rng = None | |
else: | |
switti.rng.manual_seed(seed) | |
rng = switti.rng | |
context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt) | |
B = context.shape[0] // 2 | |
cond_vector = switti.text_pooler(cond_vector) | |
if switti.use_crop_cond: | |
crop_coords = get_crop_condition(2 * B * [image_size[0]], | |
2 * B * [image_size[1]], | |
).to(cond_vector.device) | |
crop_embed = switti.crop_embed(crop_coords.view(-1)).reshape(2 * B, switti.D) | |
crop_cond = switti.crop_proj(crop_embed) | |
else: | |
crop_cond = None | |
sos = cond_BD = cond_vector | |
lvl_pos = switti.lvl_embed(switti.lvl_1L) | |
if not switti.rope: | |
lvl_pos += switti.pos_1LC | |
next_token_map = ( | |
sos.unsqueeze(1) | |
+ switti.pos_start.expand(2 * B, switti.first_l, -1) | |
+ lvl_pos[:, : switti.first_l] | |
) | |
cur_L = 0 | |
f_hat = sos.new_zeros(B, switti.Cvae, switti.patch_nums[-1], switti.patch_nums[-1]) | |
for b in switti.blocks: | |
b.attn.kv_caching(switti.use_ar) # Use KV caching if switti is in the AR mode | |
b.cross_attn.kv_caching(True) | |
for si, pn in enumerate(switti.patch_nums): # si: i-th segment | |
ratio = si / switti.num_stages_minus_1 | |
x_BLC = next_token_map | |
if switti.rope: | |
freqs_cis = switti.freqs_cis[:, cur_L : cur_L + pn * pn] | |
else: | |
freqs_cis = switti.freqs_cis | |
if si >= turn_off_cfg_start_si: | |
x_BLC = x_BLC[:B] | |
context = context[:B] | |
context_attn_bias = context_attn_bias[:B] | |
freqs_cis = freqs_cis[:B] | |
cond_BD = cond_BD[:B] | |
if crop_cond is not None: | |
crop_cond = crop_cond[:B] | |
for b in switti.blocks: | |
if b.attn.caching: | |
b.attn.cached_k = b.attn.cached_k[:B] | |
b.attn.cached_v = b.attn.cached_v[:B] | |
if b.cross_attn.caching: | |
b.cross_attn.cached_k = b.cross_attn.cached_k[:B] | |
b.cross_attn.cached_v = b.cross_attn.cached_v[:B] | |
for block in switti.blocks: | |
x_BLC = block( | |
x=x_BLC, | |
cond_BD=cond_BD, | |
attn_bias=None, | |
context=context, | |
context_attn_bias=context_attn_bias, | |
freqs_cis=freqs_cis, | |
crop_cond=crop_cond, | |
) | |
cur_L += pn * pn | |
logits_BlV = switti.get_logits(x_BLC, cond_BD) | |
# Guidance | |
if si < turn_off_cfg_start_si: | |
t = cfg * ratio | |
logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:] | |
if more_smooth and si >= smooth_start_si: | |
# not used when evaluating FID/IS/Precision/Recall | |
gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git | |
idx_Bl = gumbel_softmax_with_rng( | |
logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng, | |
) | |
h_BChw = idx_Bl @ vae_quant.embedding.weight.unsqueeze(0) | |
else: | |
# defaul nucleus sampling | |
idx_Bl = sample_with_top_k_top_p_( | |
logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1, | |
)[:, :, 0] | |
h_BChw = vae_quant.embedding(idx_Bl) | |
h_BChw = h_BChw.transpose_(1, 2).reshape(B, switti.Cvae, pn, pn) | |
f_hat, next_token_map = vae_quant.get_next_autoregressive_input( | |
si, len(switti.patch_nums), f_hat, h_BChw, | |
) | |
if si != switti.num_stages_minus_1: # prepare for next stage | |
next_token_map = next_token_map.view(B, switti.Cvae, -1).transpose(1, 2) | |
next_token_map = ( | |
switti.word_embed(next_token_map) | |
+ lvl_pos[:, cur_L : cur_L + switti.patch_nums[si + 1] ** 2] | |
) | |
# double the batch sizes due to CFG | |
next_token_map = next_token_map.repeat(2, 1, 1) | |
for b in switti.blocks: | |
b.attn.kv_caching(False) | |
b.cross_attn.kv_caching(False) | |
# de-normalize, from [-1, 1] to [0, 1] | |
img = vae.fhat_to_img(f_hat).add(1).mul(0.5) | |
if return_pil: | |
img = self.to_image(img) | |
return img | |