|
from __future__ import print_function, division

import os
import posixpath

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import ceph
|
|
|
class MyDataset(Dataset):
    """Landmark dataset whose images live in a Ceph S3 bucket.

    Each row of the annotations CSV holds an image key in its first column,
    followed by flattened landmark coordinates. Those coordinates are returned
    as ``(x, y)`` pairs.
    """

    def __init__(self, csv_file, root_dir):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Bucket with all the images, such as s3://faces/
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of annotation rows (one per image)."""
        return len(self.landmarks_frame)

    def _landmarks(self, idx):
        """Return row ``idx``'s landmark columns as a float array of shape (-1, 2)."""
        # ``DataFrame.as_matrix`` was removed in pandas 1.0; ``.values`` works on
        # both old and new pandas releases.
        landmarks = self.landmarks_frame.iloc[idx, 1:].values
        return landmarks.astype('float').reshape(-1, 2)

    def __getitem__(self, idx):
        """Fetch the image bytes and landmarks for row ``idx``.

        Returns:
            ``(sample, string_data)`` where ``sample`` is a dict with keys
            ``'image'`` (the object's bytes as a flat ``uint8`` array) and
            ``'landmarks'`` (float array of shape (-1, 2)), and ``string_data``
            is the same bytes as a ``bytes`` object. Returns ``None`` when the
            object is missing or empty in Ceph.
        """
        # Samplers may hand over tensor indices; pandas indexing needs Python ints.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Object keys form a URL path, so always join with '/'; os.path.join
        # would insert '\\' on Windows.
        img_name = posixpath.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])

        # NOTE(review): a new client is created on every call; caching it may be
        # cheaper, but check it is safe to share across DataLoader workers first.
        s3client = ceph.S3Client()
        value = s3client.get(img_name)
        if not value:
            # Picture doesn't exist in ceph (or is empty): handle the error here.
            # NOTE: returning None breaks the default DataLoader collate_fn unless
            # the caller filters these samples out.
            return None

        # np.fromstring is deprecated for binary data. np.frombuffer returns a
        # read-only view, so copy to keep the array writable as before.
        # The bytes are NOT decoded into pixels here (no cv2.imdecode) --
        # NOTE(review): confirm that downstream code decodes 'image'.
        img_array = np.frombuffer(value, np.uint8).copy()
        # ndarray.tostring was removed in NumPy 2.0; tobytes is the same operation.
        string_data = img_array.tobytes()

        sample = {'image': img_array, 'landmarks': self._landmarks(idx)}
        return sample, string_data
|
|