# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import argparse
import os
import pickle

import lmdb
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

MAX_SIZE = 1e12


def load_and_resize(root, path, imscale):
    """Load one image as RGB, resize it and center-crop it to ``imscale``.

    The file is looked up under a nested directory built from the first four
    characters of its name: ``root/p[0]/p[1]/p[2]/p[3]/p``.

    Args:
        root: directory containing the nested image folders.
        path: image file name (must be at least four characters long).
        imscale: target size passed to ``transforms.Resize`` and
            ``transforms.CenterCrop``.

    Returns:
        The transformed PIL image.
    """
    transform = transforms.Compose([
        transforms.Resize(imscale),
        transforms.CenterCrop(imscale),
    ])
    full_path = os.path.join(root, path[0], path[1], path[2], path[3], path)
    # Close the underlying file handle; convert() returns an independent image.
    with Image.open(full_path) as raw:
        img = raw.convert('RGB')
    return transform(img)


def _load_split_dataset(save_dir, suff, split):
    """Unpickle the preprocessed entry list ``<save_dir>/<suff>recipe1m_<split>.pkl``."""
    # NOTE: pickle.load can execute arbitrary code; only point --save_dir at
    # files produced by a trusted preprocessing step.
    with open(os.path.join(save_dir, suff + 'recipe1m_' + split + '.pkl'), 'rb') as f:
        return pickle.load(f)


def _write_split_images(env, dataset, split, args):
    """Store every not-yet-present image of one split in ``env``.

    Returns:
        dict mapping image file name -> running position within the split.
        The position is assigned to every considered image, whether it was
        written now or was already present in the LMDB.
    """
    # Collect keys only (values=False avoids copying every stored image) into a
    # set, so each membership test below is O(1).
    with env.begin() as txn:
        present = set(txn.cursor().iternext(keys=True, values=False))

    image_root = os.path.join(args.root, 'images', split)
    positions = {}
    pos = 0
    for entry in tqdm(dataset):
        # At most the first 5 images of an entry are considered, further capped
        # by --maxnumims (values above 5 have no additional effect).
        for n, name in enumerate(entry['images'][0:5]):
            if n == args.maxnumims:
                break
            key = name.encode()
            if key not in present:
                im = load_and_resize(image_root, name, args.imscale)
                im = np.array(im).astype(np.uint8)
                # One write transaction per image: each image is committed as
                # soon as it is stored, and a rerun skips keys already present.
                with env.begin(write=True) as txn:
                    txn.put(key, im.tobytes())
                present.add(key)
            positions[name] = pos
            pos += 1
    return positions


def main(args):
    """Build one LMDB of resized images per split and save a name->position index.

    For each split in ('train', 'val', 'test') this reads
    ``<save_dir>/<suff>recipe1m_<split>.pkl`` and opens (creating if needed)
    ``<save_dir>/lmdb_<split>``. Each image not already stored is written as
    raw uint8 bytes keyed by its file name. Finally
    ``{split: {image_name: position}}`` is pickled to
    ``<save_dir>/imname2pos.pkl``.
    """
    imname2pos = {'train': {}, 'val': {}, 'test': {}}
    for split in ['train', 'val', 'test']:
        dataset = _load_split_dataset(args.save_dir, args.suff, split)
        env = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + split),
                        map_size=int(MAX_SIZE))
        try:
            imname2pos[split] = _write_split_images(env, dataset, split, args)
        finally:
            env.close()

    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'wb') as f:
        pickle.dump(imname2pos, f)


def test(args):
    """Sanity check: decode the first indexed 'val' image from its LMDB and print its shape.

    Raises:
        RuntimeError: if the index contains no 'val' images.
        KeyError: if the indexed image is missing from the 'val' LMDB.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'rb') as f:
        imname2pos = pickle.load(f)

    paths = imname2pos['val']
    if not paths:
        raise RuntimeError('imname2pos.pkl contains no val images to test')
    path = next(iter(paths))

    image_file = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + 'val'),
                           max_readers=1, readonly=True, lock=False,
                           readahead=False, meminit=False)
    try:
        with image_file.begin(write=False) as txn:
            raw = txn.get(path.encode())
    finally:
        image_file.close()

    if raw is None:
        raise KeyError('image %r not found in lmdb_val' % path)

    # frombuffer replaces the deprecated binary mode of np.fromstring.
    image = np.frombuffer(raw, dtype=np.uint8)
    image = np.reshape(image, (args.imscale, args.imscale, 3))
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    print(np.shape(image))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='path/to/recipe1m',
                        help='path to the recipe1m dataset')
    parser.add_argument('--save_dir', type=str, default='../data',
                        help='path where the lmdbs will be saved')
    parser.add_argument('--imscale', type=int, default=256,
                        help='size of images (will be rescaled and center cropped)')
    parser.add_argument('--maxnumims', type=int, default=5,
                        help='maximum number of images to allow for each sample')
    parser.add_argument('--suff', type=str, default='',
                        help='id of the vocabulary to use')
    parser.add_argument('--test_only', dest='test_only', action='store_true')
    parser.set_defaults(test_only=False)
    args = parser.parse_args()

    if not args.test_only:
        main(args)
    test(args)