import os
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

from PIL import Image
from datasets import load_dataset
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor

import t5  # local helper module providing t5_encode_text

class Collator:
    """Collate function that loads images from disk and T5-encodes captions."""

    def __init__(self, image_size, text_label, image_label, name, channels):
        self.text_label = text_label    # dataset column holding the caption
        self.image_label = image_label  # dataset column holding the image path
        self.name = name                # T5 checkpoint name, e.g. "google-t5/t5-small"
        self.channels = channels        # PIL image mode, e.g. "L" or "RGB"
        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
        ])

    def __call__(self, batch):
        texts = []
        masks = []
        images = []
        for item in batch:
            # Build the path before the try block so the error message below
            # can always reference it.
            image_path = 'data/' + item[self.image_label]
            try:
                with Image.open(image_path) as img:
                    image = self.transform(img.convert(self.channels))
            except Exception as e:
                print(f"Failed to process image {image_path}: {e}")
                continue

            text, mask = t5.t5_encode_text(
                [item[self.text_label]],
                name=self.name,
                return_attn_mask=True,
            )
            texts.append(torch.squeeze(text))
            masks.append(torch.squeeze(mask))
            images.append(image)

        # Every image in the batch may have failed to load.
        if len(texts) == 0:
            return None

        # Pad the variable-length token sequences to the longest in the batch.
        texts = pad_sequence(texts, batch_first=True)
        masks = pad_sequence(masks, batch_first=True)

        newbatch = [(images[i], texts[i], masks[i]) for i in range(len(texts))]
        return torch.utils.data.dataloader.default_collate(newbatch)
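
# Example usage (illustrative sketch, not part of the training script): wiring
# the Collator into a DataLoader. The data file and column names below are
# placeholders; adjust them to your dataset.
#
#   dataset = load_dataset("json", data_files="data/metadata.jsonl", split="train")
#   collator = Collator(
#       image_size=128,
#       text_label="text",
#       image_label="image",
#       name="google-t5/t5-small",
#       channels="L",
#   )
#   dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collator)
#
# Note that batches where every image failed to load come back as None, so a
# training loop should skip those.
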
class GlyffuserPipeline(DiffusionPipeline):
    r'''
    Pipeline for text-to-image generation with the glyffuser model.

    Parameters:
        unet (`UNet2DConditionModel`): U-Net that denoises the image,
            conditioned on T5 text embeddings via cross-attention.
        scheduler (`SchedulerMixin`): noise scheduler used for the reverse
            diffusion process.

    The T5 text encoder is not registered as a module; prompts are encoded
    on the fly in `__call__` through the local `t5` helper.
    '''
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

    @torch.no_grad()
    def __call__(
        self,
        texts: List[str],
        text_encoder: str = "google-t5/t5-small",
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        '''
        Generate images from text prompts.

        Args:
            texts: prompts to condition on; one image is generated per prompt.
            text_encoder: name of the T5 checkpoint used to encode the prompts.
            batch_size: unused; the batch size is inferred from `len(texts)`.
            generator: torch generator(s) for reproducible sampling.
            num_inference_steps: number of denoising steps; more steps usually
                give higher quality at the cost of slower inference.
            output_type: "pil" for PIL images, otherwise a numpy array.
            return_dict: if True, return an `ImagePipelineOutput`; else a tuple.
        '''
        # One image per prompt; the batch_size argument is overridden.
        batch_size = len(texts)

        # Encode the prompts, keeping the attention mask so the U-Net's
        # cross-attention can ignore padding tokens.
        text_embeddings, masks = t5.t5_encode_text(texts, name=text_encoder, return_attn_mask=True)

        # sample_size may be an int (square images) or a (height, width) tuple.
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        # Start from pure Gaussian noise.
        image = randn_tensor(image_shape, generator=generator, device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # Predict the noise residual at the current timestep.
            model_output = self.unet(
                image,
                t,
                encoder_hidden_states=text_embeddings,
                encoder_attention_mask=masks,
                return_dict=False,
            )[0]

            # Compute the previous, less noisy image: x_t -> x_{t-1}.
            image = self.scheduler.step(model_output, t, image, generator=generator, return_dict=False)[0]

        # Map from [-1, 1] to [0, 1] and convert to channels-last numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
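
# Example usage (illustrative sketch): loading a saved pipeline and sampling.
# The checkpoint path and prompts are placeholders; from_pretrained restores
# the unet and scheduler registered above, provided the pipeline was saved
# with save_pretrained.
#
#   pipeline = GlyffuserPipeline.from_pretrained("path/to/glyffuser-checkpoint")
#   output = pipeline(
#       ["to shine; to glitter", "a kind of fish"],
#       num_inference_steps=50,
#       generator=torch.Generator(device="cpu").manual_seed(0),
#   )
#   output.images[0].save("sample.png")
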
def make_grid(images, rows, cols):
    """Tile a list of equally sized PIL images into a single rows x cols grid."""
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid
def evaluate(config, epoch, texts, pipeline):
    # Sample images with a fixed seed so grids from successive epochs are
    # directly comparable.
    images = pipeline(
        texts,
        batch_size=config.eval_batch_size,
        generator=torch.Generator(device='cpu').manual_seed(config.seed),
    ).images

    # Tile the samples into a grid; rows * cols should match len(texts).
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the grid to {output_dir}/samples/{epoch:04d}.png.
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")
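
# Example usage (illustrative sketch): sampling a grid once per epoch inside a
# training loop. `config`, `model`, `noise_scheduler`, and `sample_texts` are
# assumed to be defined by the surrounding training code.
#
#   pipeline = GlyffuserPipeline(unet=model, scheduler=noise_scheduler)
#   evaluate(config, epoch, sample_texts, pipeline)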