# glyffuser/glyffuser_utils.py
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import t5  # local helper module wrapping the T5 text encoder (provides t5_encode_text)
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
from datasets import load_dataset
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from typing import List, Optional, Tuple, Union
from diffusers.utils.torch_utils import randn_tensor
# Collator adjusted for a local dataset: loads each image from disk and
# encodes its caption with T5 on the fly.
class Collator:
def __init__(self, image_size, text_label, image_label, name, channels):
self.text_label = text_label
self.image_label = image_label
self.name = name
self.channels = channels
self.transform = T.Compose([
T.Resize((image_size, image_size)),
T.ToTensor(),
])
def __call__(self, batch):
texts = []
masks = []
images = []
        for item in batch:
            # Resolve the image path first so it is still defined if loading fails
            image_path = 'data/' + item[self.image_label]  # assumed to be a file path relative to ./data/
            try:
                with Image.open(image_path) as img:
                    image = self.transform(img.convert(self.channels))
            except Exception as e:
                print(f"Failed to process image {image_path}: {e}")
                continue
            # Encode the caption with T5, returning embeddings and an attention mask
            text, mask = t5.t5_encode_text(
                [item[self.text_label]],
                name=self.name,
                return_attn_mask=True
            )
            texts.append(torch.squeeze(text))  # drop the batch dim -> (seq_len, dim)
            masks.append(torch.squeeze(mask))  # drop the batch dim -> (seq_len,)
images.append(image)
        if len(texts) == 0:
            return None
        # Pad the variable-length embeddings and masks to a common length so
        # that default_collate can stack them into batch tensors.
        texts = pad_sequence(texts, batch_first=True)
        masks = pad_sequence(masks, batch_first=True)
newbatch = []
for i in range(len(texts)):
newbatch.append((images[i], texts[i], masks[i]))
return torch.utils.data.dataloader.default_collate(newbatch)
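
# Usage sketch (the metadata file, column names, and sizes below are
# illustrative, not fixed by this module -- adapt them to your dataset):
#
#     dataset = load_dataset("json", data_files="data/metadata.json", split="train")
#     collator = Collator(
#         image_size=128,
#         text_label="text",
#         image_label="image",
#         name="google-t5/t5-small",
#         channels="L",
#     )
#     dataloader = DataLoader(dataset, batch_size=16, collate_fn=collator, shuffle=True)
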
class GlyffuserPipeline(DiffusionPipeline):
    r'''
    Pipeline for text-to-image generation with the glyffuser model.

    Parameters:
        unet ([`UNet2DConditionModel`]): the denoising U-Net.
        scheduler ([`SchedulerMixin`]): the scheduler used to denoise the
            sampled noise.

    Text conditioning is computed at call time with a T5 encoder (T5-small by
    default), so no text encoder is registered as a module here.
    '''
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(
unet=unet,
scheduler=scheduler,
)
@torch.no_grad()
def __call__(
self,
texts: List[str],
text_encoder: str = "google-t5/t5-small",
batch_size: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
        '''
        Generate images from text prompts.

        Args:
            texts: list of prompts; the effective batch size is len(texts).
            text_encoder: name of the T5 checkpoint used to encode the prompts.
            batch_size: kept for API compatibility; overridden by len(texts).
            generator: optional torch.Generator(s) for reproducible sampling.
            num_inference_steps: number of denoising steps.
            output_type: "pil" for PIL images, otherwise a numpy array.
            return_dict: if False, return a plain tuple instead of an
                ImagePipelineOutput.
        '''
        # Encode the prompts with the T5 text encoder; the batch size follows
        # the number of prompts rather than the batch_size argument
        batch_size = len(texts)
text_embeddings, masks = t5.t5_encode_text(texts, name=text_encoder, return_attn_mask=True)
# Sample gaussian noise to begin loop
if isinstance(self.unet.config.sample_size, int):
image_shape = (
batch_size,
self.unet.config.in_channels,
self.unet.config.sample_size,
self.unet.config.sample_size,
)
else:
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
        # randn_tensor handles device placement, including the MPS (Apple
        # silicon) case where randn is not reproducible on-device
        image = randn_tensor(image_shape, generator=generator, device=self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(
image,
t,
                encoder_hidden_states=text_embeddings,  # text conditioning
                encoder_attention_mask=masks,  # mask out padded tokens
                return_dict=False
            )[0]  # with return_dict=False, [0] is the predicted noise tensor
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, generator=generator, return_dict=False)[0]
        # Training images are in [0, 1] (ToTensor without rescaling), so clamp
        # directly rather than denormalizing with (image / 2 + 0.5)
        image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
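
# Usage sketch ("path/to/glyffuser" is an illustrative checkpoint directory,
# e.g. one saved earlier with pipeline.save_pretrained):
#
#     pipeline = GlyffuserPipeline.from_pretrained("path/to/glyffuser")
#     images = pipeline(["moon", "rain"], num_inference_steps=50).images
#     images[0].save("moon.png")
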
def make_grid(images, rows, cols):
    '''Paste a list of equally sized PIL images into a single rows x cols grid.'''
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid
def evaluate(config, epoch, texts, pipeline):
    '''Sample a batch of images from the pipeline and save them as a grid.'''
    images = pipeline(
        texts,
        batch_size=config.eval_batch_size,
        generator=torch.Generator(device='cpu').manual_seed(config.seed),  # generator must be on CPU for sampling during training
    ).images
    # Arrange the samples in a 4x4 grid (assumes 16 evaluation prompts)
    image_grid = make_grid(images, rows=4, cols=4)
    # Save the grid, one file per epoch
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")
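
# Sketch of how evaluate might be called from a training loop (the config
# fields and the prompt list below are illustrative, not defined in this file):
#
#     pipeline = GlyffuserPipeline(unet=unet, scheduler=noise_scheduler)
#     if (epoch + 1) % config.save_image_epochs == 0:
#         evaluate(config, epoch, eval_texts, pipeline)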