import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from einops import rearrange
from taming.models.vqgan import VQModel, GumbelVQ
from taming.models.cond_transformer import Net2NetTransformer
from PIL import Image
from torchvision.utils import save_image
from math import sqrt

# Adapted from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/vae.py#L160

def load_vqgan(config, ckpt_path=None, is_gumbel=False, is_transformer=False):
    """Instantiate a (Gumbel-)VQGAN, optionally extracted from a Net2Net transformer checkpoint."""
    if is_gumbel:
        model = GumbelVQ(**config.model.params)
    elif is_transformer:
        model = Net2NetTransformer(**config.model.params)
    else:
        model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    if is_transformer:
        # Only the transformer's first-stage VQGAN is needed for (de)tokenization.
        model = model.first_stage_model
    return model.eval()

def preprocess_vqgan(x):
    # Map pixel values from [0, 1] to the [-1, 1] range the VQGAN expects.
    return 2. * x - 1.

def build_vqgan_model(args):
    """Load a VQGAN from the config and checkpoint paths carried by `args`."""
    config = OmegaConf.load(args.vqgan_config_path)
    vqgan_model = load_vqgan(config, ckpt_path=args.vqgan_model_path,
                             is_transformer=args.image_tokenizer["is_transformer"],
                             is_gumbel=args.image_tokenizer["is_gumbel"])
    return vqgan_model

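# A minimal sketch of the `args` object build_vqgan_model expects; the paths
# below are hypothetical placeholders, not files shipped with this repo.
def example_args():
    from argparse import Namespace
    return Namespace(
        vqgan_config_path="configs/vqgan.yaml",     # hypothetical config path
        vqgan_model_path="checkpoints/vqgan.ckpt",  # hypothetical checkpoint path
        image_tokenizer={"is_transformer": False, "is_gumbel": False},
    )
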
def image_tokenize(vqgan_model, image, is_gumbel=False):
    """Encode a single (C, H, W) image tensor in [0, 1] into a list of codebook indices."""
    image = torch.stack([preprocess_vqgan(image)], 0)  # add a batch dimension
    with torch.no_grad():
        _, _, [_, _, indices] = vqgan_model.encode(image)
    if is_gumbel:
        # GumbelVQ returns indices with spatial shape (b, h, w).
        image_tokens = rearrange(indices, 'b h w -> b (h w)', b=1).flatten().tolist()
    else:
        # VQModel returns a flat index tensor of length b * h * w.
        image_tokens = rearrange(indices, '(b n) -> b n', b=1).flatten().tolist()
    return image_tokens

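# A minimal sketch of preparing an input for image_tokenize: a (C, H, W) float
# tensor in [0, 1]. The 256x256 target size is an assumption; any size
# divisible by the VQGAN's downsampling factor should work.
def load_image_tensor(path, target_size=256):
    import torchvision.transforms as T
    transform = T.Compose([
        T.Resize(target_size),
        T.CenterCrop(target_size),
        T.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (C, H, W)
    ])
    return transform(Image.open(path).convert("RGB"))
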
def image_tokenize_batch(vqgan_model, images, is_gumbel=False):
    """Encode a list of (C, H, W) image tensors into per-image lists of codebook indices."""
    image_src = torch.stack([preprocess_vqgan(image) for image in images], 0)
    with torch.no_grad():
        _, _, [_, _, indices] = vqgan_model.encode(image_src)
    if is_gumbel:
        image_tokens = rearrange(indices, 'b h w -> b (h w)', b=len(images)).tolist()
    else:
        image_tokens = rearrange(indices, '(b n) -> b n', b=len(images)).tolist()
    return image_tokens

def image_detokenize(vqgan_model, image_tokens, image_vocab_size=1024, is_gumbel=False, save_path=None):
    """Decode a list of codebook indices back into an image tensor in [0, 1]."""
    with torch.no_grad():
        n = len(image_tokens)
        one_hot_indices = F.one_hot(torch.tensor([image_tokens]), num_classes=image_vocab_size).float().to(vqgan_model.device)
        # Look up the codebook embeddings; GumbelVQ stores them under a different attribute.
        z = one_hot_indices @ vqgan_model.quantize.embed.weight if is_gumbel \
            else (one_hot_indices @ vqgan_model.quantize.embedding.weight)
        # Fold the token sequence back into a square (h, w) latent grid.
        z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(n))).to(vqgan_model.device)
        img = vqgan_model.decode(z)
        img = (img.clamp(-1., 1.) + 1) * 0.5  # map [-1, 1] back to [0, 1]
        if save_path:
            save_image(img, save_path, normalize=False)
        return img

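# End-to-end sketch using the hypothetical helpers above: tokenize an image
# and reconstruct it. "input.jpg" and "reconstruction.png" are placeholder
# file names; image_vocab_size must match the checkpoint's codebook size.
if __name__ == "__main__":
    args = example_args()
    vqgan_model = build_vqgan_model(args)
    image = load_image_tensor("input.jpg")
    tokens = image_tokenize(vqgan_model, image,
                            is_gumbel=args.image_tokenizer["is_gumbel"])
    image_detokenize(vqgan_model, tokens,
                     image_vocab_size=1024,
                     is_gumbel=args.image_tokenizer["is_gumbel"],
                     save_path="reconstruction.png")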