import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms

from src.ss.datasets_signboard_detection.dataset import PoIDataset
import src.ss.datasets_signboard_detection.utils as utils


class POIDataModule(pl.LightningDataModule):
    """LightningDataModule that serves a ``PoIDataset`` for prediction.

    Only the ``"predict"`` stage is wired up: ``setup`` builds a
    ``PoIDataset`` from ``data_path`` with a ``ToTensor`` transform, and
    ``predict_dataloader`` wraps it in a non-shuffling ``DataLoader`` that
    batches with ``utils.collate_fn``.
    """

    def __init__(self, data_path: str, train_batch_size=8, test_batch_size=8,
                 seed=28, num_workers=16):
        """Store loader configuration.

        Args:
            data_path: Passed unchanged to ``PoIDataset``.
            train_batch_size: Stored on the instance; not used by any loader
                defined in this module.
            test_batch_size: Batch size of the prediction ``DataLoader``.
            seed: Stored on the instance; not used by this module.
            num_workers: Worker count for the prediction ``DataLoader``
                (defaults to 16, the value that used to be hard-coded).
        """
        super().__init__()
        self.data_path = data_path
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.seed = seed
        self.num_workers = num_workers
        # Filled in by ``setup``; initialised here so ``predict_dataloader``
        # can test it even when ``setup`` never ran for the predict stage
        # (previously this raised an AttributeError).
        self.test_dataset = None

    def prepare_data(self):
        """No download or preprocessing step is needed."""
        pass

    def setup(self, stage="fit"):
        """Build the prediction dataset when ``stage`` is ``"predict"`` or None.

        Any other stage leaves ``self.test_dataset`` untouched.
        """
        test_transform = transforms.Compose([transforms.ToTensor()])
        if stage == "predict" or stage is None:
            self.test_dataset = PoIDataset(self.data_path,
                                           transforms=test_transform)

    def predict_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the prediction dataset.

        Raises:
            RuntimeError: if ``setup`` has not built the prediction dataset.
        """
        if self.test_dataset is None:
            # Fail loudly here instead of handing ``None`` to the trainer,
            # which would error far from the actual cause.
            raise RuntimeError(
                "POIDataModule.predict_dataloader() called before "
                "setup(stage='predict') created the dataset"
            )
        return DataLoader(self.test_dataset,
                          batch_size=self.test_batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=utils.collate_fn)


def _get_name(filepath):
    """Return *filepath* unchanged.

    Despite its name, no name extraction is performed; this is currently an
    identity function.
    """
    return filepath