import numpy as np
import pytorch_lightning as pl
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from transformers import SegformerFeatureExtractor, BatchFeature

from typing import Optional

class SegmentationDataset(Dataset):
    """Image Segmentation Dataset"""

    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
        """
        Dataset for image segmentation.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Tensor of shape (N, C, H, W) containing the pixel values of the images.
        labels : torch.Tensor
            Tensor of shape (N, H, W) containing the per-pixel class labels of the images.
        """
        self.pixel_values = pixel_values
        self.labels = labels
        # Every image must come with exactly one label map.
        assert pixel_values.shape[0] == labels.shape[0], "pixel_values and labels must have the same length"
        self.length = pixel_values.shape[0]
        print(f"Created dataset with {self.length} samples")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = self.pixel_values[index]
        label = self.labels[index]

        # Wrap each pair in a BatchFeature so a sample exposes the same
        # dict-like interface the Hugging Face feature extractor produces.
        encoded_inputs = BatchFeature({"pixel_values": image, "labels": label})

        return encoded_inputs
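

# A minimal usage sketch, not part of the original module; the shapes and the
# 35-class count are assumptions for illustration. SegmentationDataset expects
# (N, C, H, W) float pixel values and (N, H, W) integer label maps.
def _demo_segmentation_dataset() -> SegmentationDataset:
    pixel_values = torch.rand(4, 3, 512, 512)     # 4 RGB images
    labels = torch.randint(0, 35, (4, 512, 512))  # 4 per-pixel label maps
    return SegmentationDataset(pixel_values, labels)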


class SidewalkSegmentationDataLoader(pl.LightningDataModule):
    def __init__(
        self, hub_dir: str, batch_size: int, split: Optional[str] = None,
    ):
        super().__init__()
        self.hub_dir = hub_dir
        self.batch_size = batch_size
        # reduce_labels=True shifts every label id down by one and replaces
        # the background class (0) with 255, the ignore index, as in
        # ADE20k-style datasets.
        self.tokenizer = SegformerFeatureExtractor(reduce_labels=True)
        self.dataset = load_dataset(self.hub_dir, split=split)
        self.len = len(self.dataset)

    def tokenize_data(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)
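
    # For reference, an assumption based on SegFormer's published defaults
    # rather than anything stated here: with return_tensors="pt" the feature
    # extractor resizes and normalizes every image to 512x512, yielding
    # "pixel_values" of shape (N, 3, 512, 512) and, when segmentation_maps
    # are passed, integer "labels" of shape (N, 512, 512).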

    def setup(self, stage: Optional[str] = None):
        encoded_dataset = self.tokenize_data(
            images=self.dataset["pixel_values"],
            segmentation_maps=self.dataset["label"],
            return_tensors="pt",
        )
        dataset = SegmentationDataset(encoded_dataset["pixel_values"], encoded_dataset["labels"])

        # 80/20 train/validation split. Derive the second size from the first
        # so the two always sum to the full length; truncating both with int()
        # can drop a sample and make random_split raise a ValueError.
        indices = np.arange(self.len)
        train_size = int(self.len * 0.8)
        train_indices, val_indices = random_split(indices, [train_size, self.len - train_size])

        self.train_dataset = Subset(dataset, train_indices)
        self.val_dataset = Subset(dataset, val_indices)

    def train_dataloader(self):
        # num_workers is hardware-dependent; 12 assumes a machine with that
        # many CPU cores to spare.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=12)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=12)
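

# A minimal end-to-end sketch, not part of the original module. The hub id
# "segments/sidewalk-semantic" is an assumption about which dataset this
# DataModule targets; substitute your own dataset repository.
if __name__ == "__main__":
    datamodule = SidewalkSegmentationDataLoader(
        hub_dir="segments/sidewalk-semantic",  # assumed dataset id
        batch_size=8,
        split="train",
    )
    datamodule.setup()
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["pixel_values"].shape, batch["labels"].shape)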