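"""Data pipeline for tiny_clip.

Defines a map-style dataset over (image, caption) pairs, a collate function that
applies the image transform and tokenizer, and the train/validation DataLoaders.
"""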
import multiprocessing as mp
import pathlib
from typing import Any

import datasets
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from src import config
from src import tokenizer as tk


class CaptionDataset(Dataset):
    """Map-style dataset over caption rows; images are read from a local folder."""

    def __init__(self, dataset: datasets.Dataset, img_path: pathlib.Path) -> None:
        self.dataset = dataset
        self.img_path = img_path

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = self.dataset[idx]
        # Images were downloaded ahead of time; the local filename is the last
        # segment of the source URL.
        image = Image.open(self.img_path / item["url"].rsplit("/", 1)[-1]).convert("RGB")
        return {"image": image, "caption": item["short_caption"]}


class CollateFn:
    def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose):
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        stacked_images = torch.stack([self.transform(item["image"]) for item in batch])
        tokenized_text = self.tokenizer([item["caption"] for item in batch])
        return {
            "images": stacked_images,
            **tokenized_text,
        }


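# For reference, a collated batch is a flat dict of tensors. The text keys come
# from tk.Tokenizer; assuming it returns a Hugging Face-style encoding and that
# max_len is 77 (both assumptions, not checked here), a batch of 2 with the
# 128x128 transform used below would look like:
#   {
#       "images":         float32 [2, 3, 128, 128],
#       "input_ids":      int64   [2, 77],
#       "attention_mask": int64   [2, 77],
#   }

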
def _get_dataloaders(
    train_ds: Dataset,
    valid_ds: Dataset,
    training_config: config.TrainerConfig,
    collate_fn: CollateFn,
) -> tuple[DataLoader, DataLoader]:
    common_params = {
        "batch_size": training_config.batch_size,
        "pin_memory": True,
        # Use roughly a third of the cores for data loading.
        "num_workers": mp.cpu_count() // 3,
        "collate_fn": collate_fn,
    }
    train_loader = DataLoader(
        train_ds,
        shuffle=True,
        # drop_last keeps every training batch at full size, which matters for
        # contrastive losses computed over a batch_size x batch_size matrix.
        drop_last=True,
        **common_params,
    )
    valid_loader = DataLoader(
        valid_ds,
        shuffle=False,
        drop_last=False,
        **common_params,
    )
    return train_loader, valid_loader


def get_dataset(
    transform: transforms.Compose,
    tokenizer: tk.Tokenizer,
    hyper_parameters: config.TrainerConfig,
) -> tuple[DataLoader, DataLoader]:
    dataset: datasets.Dataset = datasets.load_dataset(
        hyper_parameters._data_config.dataset, split="train"
    )  # type: ignore
    # Fixed seed so the 90/10 train/validation split is reproducible across runs.
    train_test_dataset = dataset.train_test_split(seed=42, test_size=0.1)
    train_ds = CaptionDataset(train_test_dataset["train"], config.IMAGE_DOWNLOAD_PATH)
    valid_ds = CaptionDataset(train_test_dataset["test"], config.IMAGE_DOWNLOAD_PATH)
    collate_fn = CollateFn(tokenizer, transform)
    return _get_dataloaders(
        train_ds=train_ds,
        valid_ds=valid_ds,
        training_config=hyper_parameters,
        collate_fn=collate_fn,
    )


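# One way the loaders might be consumed (sketch; `model` and `loss_fn` are
# hypothetical stand-ins, not names defined in this repo):
#
#   train_dl, valid_dl = get_dataset(transform, tokenizer, config.TrainerConfig())
#   for batch in train_dl:
#       images = batch.pop("images")
#       loss = loss_fn(model(images, **batch))  # hypothetical signature
#

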
if __name__ == "__main__":
    # Imports needed only for this smoke test, so keep them out of module scope.
    import os

    from tqdm.auto import tqdm

    # Silence the Hugging Face tokenizers fork warning triggered by DataLoader workers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    hyper_parameters = config.TrainerConfig()
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
    tokenizer = tk.Tokenizer(
        hyper_parameters._model_config.text_model, hyper_parameters._model_config.max_len
    )
    train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)

    batch = next(iter(train_dl))
    print({k: v.shape for k, v in batch.items()})  # e.g. images: torch.Size([1, 3, 128, 128])
    # One full pass over the training loader to confirm every image loads.
    for batch in tqdm(train_dl):
        pass