from __future__ import print_function, division import os import torch import pandas as pd from skimage import io, transform import matplotlib.pyplot as plt from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils import cv2 import numpy as np import ceph class MyDataset(Dataset): def __init__(self, csv_file, root_dir): """ Args: csv_file (string): Path to the csv file with annotations. root_dir (string): Bucket with all the images, such as s3://faces/ """ self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir def __len__(self): return len(self.landmarks_frame) def __getitem__(self, idx): img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0]) s3client = ceph.S3Client() value = s3client.get(img_name) if not value: """ Picture doesn't exist in ceph, your code here to handle error """ return None img_array = np.fromstring(value, np.uint8) # load image #img = cvb.img_from_bytes(value) string_data = img_array.tostring() #print(string_data) #print(value) #image = cv2.imdecode(img_array, cv2.CV_LOAD_IMAGE_COLOR) landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix() landmarks = landmarks.astype('float').reshape(-1, 2) sample = {'image': img_array, 'landmarks': landmarks} return sample, string_data