from PIL import Image
from torch.utils.data import Dataset


class ImageRetrievalDataset(Dataset):
    """Map-style dataset yielding ``(input_image, label_image)`` pairs read from disk.

    Args:
        data: 2-D array-like of image paths that supports ``data[:, idx]``
            indexing and exposes ``.shape``. Sample ``idx`` is column ``idx``:
            ``data[0, idx]`` is the input image path and ``data[1, idx]`` is
            the label image path, i.e. the layout is ``(2, num_samples)``.
            NOTE(review): this layout is derived from the indexing used in
            ``__getitem__``; confirm it matches how callers build ``data``.
        transform: Optional callable applied to each loaded image. It is
            invoked separately for the input and the label image, so a
            transform with internal randomness is not guaranteed to apply the
            same parameters to both.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of samples (columns of ``data``)."""
        # __getitem__ selects a sample with data[:, idx], so samples live on
        # axis 1. len(self.data) would report the size of axis 0 (the two path
        # rows) instead of the sample count.
        return self.data.shape[1]

    def __getitem__(self, idx):
        """Load and return the ``(input_image, label_image)`` pair for ``idx``.

        Both images are converted to RGB. When a transform is set, the
        transformed objects are returned in place of the PIL images.
        """
        input_path, label_path = self.data[:, idx]

        input_image = self._load_rgb(input_path)
        label_image = self._load_rgb(label_path)

        if self.transform:
            input_image = self.transform(input_image)
            label_image = self.transform(label_image)

        return input_image, label_image

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an in-memory RGB copy."""
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load and returns a new image, so the file handle can
        # be closed as soon as the with-block exits.
        with Image.open(path) as img:
            return img.convert("RGB")