VQGAN-f16-16384
Model Description
This is a Flax/JAX implementation of VQGAN, which learns a codebook of context-rich visual parts by combining convolutional neural networks with transformers. It was introduced in Taming Transformers for High-Resolution Image Synthesis (CVPR paper).
The model encodes an image as a fixed-length sequence of tokens drawn from the codebook.
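As background on how images become tokens: the encoder produces one feature vector per spatial position, and the quantizer replaces each vector with the index of its nearest codebook entry, following the Taming Transformers paper. The snippet below is a minimal NumPy sketch of that nearest-neighbour lookup, not the vqgan_jax implementation. The function name `quantize` and the toy 2-D, 4-entry codebook are hypothetical; the real codebook of this model has 16,384 entries.

```python
import numpy as np

def quantize(features: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Return, for each feature vector, the index of its nearest codebook entry.

    features: (num_vectors, dim) encoder outputs, one per spatial position.
    codebook: (vocab_size, dim) learned embeddings.
    """
    # Squared Euclidean distances ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2
    # for every (feature, codebook entry) pair, shape (num_vectors, vocab_size).
    distances = (
        np.sum(features ** 2, axis=1, keepdims=True)
        - 2.0 * features @ codebook.T
        + np.sum(codebook ** 2, axis=1)
    )
    # The token id is the index of the closest entry.
    return np.argmin(distances, axis=1)

# Toy example: a 4-entry codebook in 2-D and three feature vectors.
codebook = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
features = np.array([[0.9, 0.1], [0.1, 0.8], [0.2, 0.1]])
tokens = quantize(features, codebook)
print(tokens.tolist())  # [1, 2, 0]
```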
This version of the model uses a reduction factor f = 16 and a vocabulary of 16,384 tokens.

As an example of how the reduction factor works, images of size 256x256 are encoded to sequences of 256 tokens: 256/16 * 256/16 = 16 * 16 = 256. Images of 512x512 would result in sequences of 1024 tokens (32 * 32).
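The arithmetic above can be written as a small helper. This is an illustrative sketch: `num_tokens` is a hypothetical name, and rejecting sizes that are not multiples of f is an assumption of the sketch, not a documented property of the model.

```python
def num_tokens(image_size: int, f: int = 16) -> int:
    """Sequence length for a square image_size x image_size input.

    Each spatial side is downsampled by the reduction factor f, giving an
    (image_size // f) x (image_size // f) grid with one codebook token per cell.
    """
    if image_size % f != 0:
        # Assumption of this sketch: only sizes divisible by f are accepted.
        raise ValueError(f"image_size must be a multiple of f={f}")
    side = image_size // f
    return side * side

for size in (256, 512):
    print(size, num_tokens(size))
# 256 256
# 512 1024
```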
Datasets Used for Training
- ImageNet. We didn't train this model from scratch. Instead, we started from a checkpoint pre-trained on ImageNet.
- Conceptual Captions 3M (CC3M).
- OpenAI subset of YFCC100M.
We fine-tuned on CC3M and YFCC100M to improve the encoding quality of people and faces, which are not very well represented in ImageNet. We used a subset of 2,268,720 images from CC3M and YFCC100M for this purpose.
Training Process
Fine-tuning was performed in PyTorch using taming-transformers. The full training process and model preparation include these steps:
- Pre-training on ImageNet: previously performed; we started from this checkpoint.
- Fine-tuning, Part 1.
- Fine-tuning, Part 2 – continuation from Part 1. The final checkpoint was uploaded to boris/vqgan_f16_16384.
- Conversion to JAX, which is the model described in this card.
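The card does not describe how the PyTorch checkpoint was converted to JAX. One step such a conversion typically needs is rearranging weight layouts: PyTorch stores Conv2d weights as (out, in, kh, kw) and Linear weights as (out, in), while Flax's Conv and Dense kernels use (kh, kw, in, out) and (in, out). The NumPy sketch below shows only that transposition; the helper names are hypothetical, and a complete conversion would also have to map parameter names and handle the other layer types.

```python
import numpy as np

def torch_conv_to_flax(weight: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (out, in, kh, kw) PyTorch Conv2d weight
    -> (kh, kw, in, out) Flax Conv kernel."""
    return np.transpose(weight, (2, 3, 1, 0))

def torch_linear_to_flax(weight: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (out, in) PyTorch Linear weight
    -> (in, out) Flax Dense kernel."""
    return weight.T

# A 3x3 convolution with 64 input and 128 output channels.
w = np.random.default_rng(0).normal(size=(128, 64, 3, 3))
k = torch_conv_to_flax(w)
print(k.shape)  # (3, 3, 64, 128)
```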
How to Use
The checkpoint can be loaded using Suraj Patil's implementation of VQModel.
Example notebook, heavily based on work by Suraj:
Batch encoding using JAX pmap, complete example including data loading with PyTorch:
```python
# VQGAN-JAX - pmap encoding HowTo
import numpy as np

# For data loading
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader
from torchvision.transforms import InterpolationMode

# For data saving
from pathlib import Path
import pandas as pd
from tqdm import tqdm

import jax
from jax import pmap

from vqgan_jax.modeling_flax_vqgan import VQModel

## Params and arguments

# List of paths containing images to encode
image_list = '/sddata/dalle-mini/CC12M/10k.tsv'
output_tsv = 'output.tsv'  # Encoded results
batch_size = 64
num_workers = 4  # TPU v3-8s have 96 cores, so feel free to increase this number when necessary

# Load model
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

## Data Loading.

# Simple torch Dataset to load images from paths.
# You can use your own pipeline instead.
class ImageDataset(Dataset):
    def __init__(self, image_list_path: str, image_size: int, max_items=None):
        """
        :param image_list_path: Path to a file containing a list of all images. We assume absolute paths for now.
        :param image_size: Image size. Source images will be resized and center-cropped.
        :param max_items: Limit dataset size for debugging
        """
        self.image_list = pd.read_csv(image_list_path, sep='\t', header=None)
        if max_items is not None:
            self.image_list = self.image_list[:max_items]
        self.image_size = image_size

    def __len__(self):
        return len(self.image_list)

    def _get_raw_image(self, i):
        image_path = Path(self.image_list.iloc[i][0])
        return default_loader(image_path)

    def resize_image(self, image):
        # Scale so the shorter side equals image_size, then center-crop to a square
        s = min(image.size)
        r = self.image_size / s
        s = (round(r * image.size[1]), round(r * image.size[0]))  # (height, width)
        image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
        image = TF.center_crop(image, output_size=2 * [self.image_size])
        # Add a leading axis: (1, H, W, C); removed again with squeeze(1) after batching
        image = np.expand_dims(np.array(image), axis=0)
        return image

    def __getitem__(self, i):
        image = self._get_raw_image(i)
        return self.resize_image(image)

## Encoding

# Encoding function to be parallelized with `pmap`
# Note: images have to be square
def encode(model, batch):
    _, indices = model.encode(batch)
    return indices

# Group `num_tpus` consecutive batches into one superbatch of shape
# (num_tpus, batch_size, H, W, C) so that `pmap` sends one batch to each device.
# Alternative: create a batch with num_tpus*batch_size and use `shard` to distribute.
def superbatch_generator(dataloader, num_tpus):
    iter_loader = iter(dataloader)
    for batch in iter_loader:
        superbatch = [batch.squeeze(1)]
        try:
            for _ in range(num_tpus - 1):
                batch = next(iter_loader)
                # Skip incomplete last batch
                if batch.shape[0] == dataloader.batch_size:
                    superbatch.append(batch.squeeze(1))
        except StopIteration:
            pass
        superbatch = torch.stack(superbatch, dim=0)
        yield superbatch

def encode_dataset(dataset, batch_size=32):
    num_tpus = jax.device_count()
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)
    p_encoder = pmap(lambda batch: encode(model, batch))

    # Save each superbatch to avoid reallocation of buffers as we process them.
    # Keep the file open to prevent excessive file seeks.
    with open(output_tsv, "w") as file:
        iterations = len(dataset) // (batch_size * num_tpus)
        for n in tqdm(range(iterations)):
            superbatch = next(superbatches)
            encoded = p_encoder(superbatch.numpy())
            # (num_tpus, batch_size, num_tokens) -> (num_tpus * batch_size, num_tokens),
            # copied to a host NumPy array for string formatting
            encoded = np.asarray(encoded.reshape(-1, encoded.shape[-1]))

            # Extract paths from the dataset, save paths and encodings (as string)
            start_index = n * batch_size * num_tpus
            end_index = (n + 1) * batch_size * num_tpus
            paths = dataset.image_list[start_index:end_index][0].values
            encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int': lambda x: str(x)}), encoded))
            batch_df = pd.DataFrame.from_dict({"image_file": paths, "encoding": encoded_as_string})
            batch_df.to_csv(file, sep='\t', header=(n == 0), index=None)

dataset = ImageDataset(image_list, image_size=256)
encode_dataset(dataset, batch_size=batch_size)  # Writes results to output_tsv
```
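The example builds each superbatch by pulling one DataLoader batch per device. Its comment mentions an alternative: load a single batch of num_tpus * batch_size images and split it along a new leading axis. The NumPy sketch below shows that reshape and its inverse; `shard_batch` and `unshard` are hypothetical names, and the device count is passed explicitly so the sketch runs without accelerators. Flax's `flax.jax_utils.shard` performs the equivalent reshape using `jax.local_device_count()`.

```python
import numpy as np

def shard_batch(batch: np.ndarray, num_devices: int) -> np.ndarray:
    """Split a flat batch of num_devices * per_device items into a
    (num_devices, per_device, ...) array, the leading-axis layout pmap maps over."""
    if batch.shape[0] % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    return batch.reshape((num_devices, -1) + batch.shape[1:])

def unshard(encoded: np.ndarray) -> np.ndarray:
    """Merge the device axis back: (num_devices, per_device, ...) -> (N, ...)."""
    return encoded.reshape((-1,) + encoded.shape[2:])

# 8 devices, 4 images per device, 256x256 RGB images.
images = np.zeros((32, 256, 256, 3), dtype=np.uint8)
print(shard_batch(images, num_devices=8).shape)  # (8, 4, 256, 256, 3)

# Stand-in for per-device encoder output: 256 token ids per image.
fake_indices = np.zeros((8, 4, 256), dtype=np.int32)
print(unshard(fake_indices).shape)  # (32, 256)
```

With this variant (untested), the DataLoader would be created with batch_size * jax.device_count() and each batch passed as p_encoder(shard_batch(batch.squeeze(1).numpy(), jax.device_count())) instead of iterating superbatch_generator. Because reshape preserves row-major order, device 0 receives the first per-device slice, so the flattened output stays aligned with the image paths.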
Related Models in the Hub
- PyTorch version of VQGAN, trained on the same datasets described here: boris/vqgan_f16_16384.
- DALL·E mini, a simplified Flax/JAX implementation of OpenAI's DALL·E.
Other
This model was successfully used as part of the implementation of DALL·E mini. Our report contains more details on how to leverage it in an image encoding / generation pipeline.