# Source: vidimatch/third_party/lanet/data_loader.py (2.38 kB)
# Author: Vincentqyw, commit e430362 ("fix: cpu running")
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from augmentations import ha_augment_sample, resize_sample, spatial_augment_sample
from lanet_utils import to_tensor_sample
def image_transforms(shape, jittering):
    """Build the per-sample transform pipelines, keyed by split name.

    Args:
        shape: target image shape, forwarded to ``resize_sample`` as
            ``image_shape``.
        jittering: jitter parameters, forwarded to ``ha_augment_sample``.

    Returns:
        A dict with a single ``"train"`` entry: a callable that takes a sample
        dict and returns the transformed sample.
    """

    def _train(sample):
        # Order matters: resize and spatially augment first, convert to
        # tensors, then apply ha_augment_sample to the tensor sample.
        resized = resize_sample(sample, image_shape=shape)
        augmented = spatial_augment_sample(resized)
        as_tensor = to_tensor_sample(augmented)
        # NOTE: `jitter_paramters` (sic) is the keyword ha_augment_sample expects.
        return ha_augment_sample(as_tensor, jitter_paramters=jittering)

    return dict(train=_train)
class GetData(Dataset):
    """Image dataset built from a text file listing image paths.

    Each non-blank line of ``config.train_txt`` is split on whitespace and only
    its first token (the image path) is kept; any further columns are ignored.
    Paths are resolved by plain string concatenation with
    ``config.train_root``, so the root should end with a path separator.
    """

    def __init__(self, config, transforms=None):
        """Read the list of image paths.

        Args:
            config: object exposing ``train_txt`` (path of the list file) and
                ``train_root`` (prefix prepended to every listed path).
            transforms: optional callable applied to each sample dict in
                ``__getitem__``.
        """
        dataset = []
        # Use a context manager so the list file is always closed.
        with open(config.train_txt, "r") as datafile:
            for line in datafile:
                data = line.split()
                # Skip blank / whitespace-only lines (e.g. a trailing empty
                # line) instead of failing with IndexError on data[0].
                if not data:
                    continue
                dataset.append(data[0])
        self.config = config
        self.dataset = dataset
        self.root = config.train_root
        self.transforms = transforms

    def __getitem__(self, index):
        """Load one image and wrap it in a sample dict.

        Args:
            index: position in the image list.

        Returns:
            A dict ``{"image": <PIL image>, "idx": index}``, passed through
            ``self.transforms`` when one is set.
        """
        img_path = self.dataset[index]
        img_file = self.root + img_path
        img = Image.open(img_file)
        # image.mode == 'L' means the image is grayscale: paste it onto an RGB
        # canvas so it becomes 3-channel like RGB inputs. Other modes are
        # passed through unchanged.
        if img.mode == "L":
            img_new = Image.new("RGB", img.size)
            img_new.paste(img)
            sample = {"image": img_new, "idx": index}
        else:
            sample = {"image": img, "idx": index}
        if self.transforms:
            sample = self.transforms(sample)
        return sample

    def __len__(self):
        """Return the number of listed images."""
        return len(self.dataset)
def get_data_loader(
    config,
    transforms=None,
    sampler=None,
    drop_last=True,
):
    """Build the training DataLoader described by ``config``.

    Args:
        config: object providing ``image_shape``, ``jittering``,
            ``train_txt``, ``train_root``, ``batch_size``, ``shuffle``,
            ``num_workers`` and ``pin_memory``.
        transforms: optional per-sample callable. When ``None`` (the default),
            the ``"train"`` pipeline from ``image_transforms`` is built from
            ``config.image_shape`` and ``config.jittering``.
        sampler: optional sampler passed to ``DataLoader``. When given,
            ``config.shuffle`` is ignored (shuffle is forced to False).
        drop_last: whether to drop the final incomplete batch.

    Returns:
        A ``DataLoader`` over a ``GetData`` dataset.
    """
    # Previously this argument was unconditionally overwritten; honour it now.
    if transforms is None:
        transforms = image_transforms(
            shape=config.image_shape, jittering=config.jittering
        )["train"]
    dataset = GetData(config, transforms=transforms)
    # DataLoader raises ValueError when a sampler is combined with
    # shuffle=True. With a sampler, ordering is the sampler's responsibility.
    shuffle = config.shuffle if sampler is None else False
    train_loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        drop_last=drop_last,
    )
    return train_loader