# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import argparse
import os
import pickle

import lmdb
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# Maximum size (in bytes) that each LMDB is allowed to grow to (~1 TB).
MAX_SIZE = 1e12


def load_and_resize(root, path, imscale):
    """Load an image from the Recipe1M layout, resize it and center-crop it to imscale."""
    transf_list = []
    transf_list.append(transforms.Resize(imscale))
    transf_list.append(transforms.CenterCrop(imscale))
    transform = transforms.Compose(transf_list)

    # Recipe1M stores images in nested folders named after the first four characters
    # of the image filename, e.g. 'abcd1234.jpg' lives under a/b/c/d/.
    img = Image.open(os.path.join(root, path[0], path[1], path[2], path[3], path)).convert('RGB')
    img = transform(img)

    return img


def main(args):

    parts = {}
    datasets = {}
    imname2pos = {'train': {}, 'val': {}, 'test': {}}

    for split in ['train', 'val', 'test']:

        # Load the preprocessed recipe list for this split.
        datasets[split] = pickle.load(
            open(os.path.join(args.save_dir, args.suff + 'recipe1m_' + split + '.pkl'), 'rb'))

        # One LMDB per split; keys already present are skipped, so the script can be resumed.
        parts[split] = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + split), map_size=int(MAX_SIZE))
        with parts[split].begin() as txn:
            present_entries = [key for key, _ in txn.cursor()]

        j = 0
        for i, entry in tqdm(enumerate(datasets[split])):
            impaths = entry['images'][0:5]
            for n, p in enumerate(impaths):
                if n == args.maxnumims:
                    break
                if p.encode() not in present_entries:
                    im = load_and_resize(os.path.join(args.root, 'images', split), p, args.imscale)
                    im = np.array(im).astype(np.uint8)
                    with parts[split].begin(write=True) as txn:
                        txn.put(p.encode(), im)
                imname2pos[split][p] = j
                j += 1

    # Save the image-name -> position mapping for all splits.
    pickle.dump(imname2pos, open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'wb'))


def test(args):
    imname2pos = pickle.load(open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'rb'))
    paths = imname2pos['val']

    # Grab the first image key from the validation split.
    for k, v in paths.items():
        path = k
        break

    image_file = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + 'val'), max_readers=1, readonly=True,
                           lock=False, readahead=False, meminit=False)
    with image_file.begin(write=False) as txn:
        image = txn.get(path.encode())
        image = np.frombuffer(image, dtype=np.uint8)
        image = np.reshape(image, (args.imscale, args.imscale, 3))
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    print(np.shape(image))


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='path/to/recipe1m',
                        help='path to the recipe1m dataset')
    parser.add_argument('--save_dir', type=str, default='../data',
                        help='path where the lmdbs will be saved')
    parser.add_argument('--imscale', type=int, default=256,
                        help='size of images (will be rescaled and center cropped)')
    parser.add_argument('--maxnumims', type=int, default=5,
                        help='maximum number of images to allow for each sample')
    parser.add_argument('--suff', type=str, default='',
                        help='id of the vocabulary to use')
    parser.add_argument('--test_only', dest='test_only', action='store_true')
    parser.set_defaults(test_only=False)
    args = parser.parse_args()

    if not args.test_only:
        main(args)
    test(args)
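
# Example usage (a sketch, not part of the repo; the script name and paths below are
# placeholder assumptions, and the <suff>recipe1m_{train,val,test}.pkl files produced
# by the preprocessing step must already exist in --save_dir):
#
#   python ims2file.py --root /path/to/recipe1m --save_dir ../data --imscale 256
#
# Once built, a stored image can be read back from a split's LMDB by its filename key,
# e.g. (illustrative only; 'some_image.jpg' stands for any key recorded in imname2pos.pkl):
#
#   env = lmdb.open('../data/lmdb_val', readonly=True, lock=False)
#   with env.begin() as txn:
#       buf = txn.get('some_image.jpg'.encode())
#   img = np.frombuffer(buf, dtype=np.uint8).reshape(256, 256, 3)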