File size: 3,393 Bytes
2a3a041
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import pickle
from tqdm import tqdm
import os
import numpy as np
from PIL import Image
import argparse
import lmdb
from torchvision import transforms


# Upper bound passed (as int) to lmdb.open(map_size=...) in main(). LMDB's
# map_size is the maximum size the database may grow to, in bytes, so this
# reserves ~1 TB of address space per split database.
MAX_SIZE = 1e12


def load_and_resize(root, path, imscale):
    """Load one image, convert it to RGB, then resize and center-crop it.

    The file is looked up at ``root/path[0]/path[1]/path[2]/path[3]/path``,
    i.e. nested in sub-directories named after the first four characters of
    the file name.

    Args:
        root: directory containing the split's image tree.
        path: image file name; must be at least 4 characters long
            (shorter names raise IndexError).
        imscale: size passed to both ``transforms.Resize`` and
            ``transforms.CenterCrop``.

    Returns:
        The transformed RGB ``PIL.Image.Image``.
    """
    transform = transforms.Compose([
        transforms.Resize(imscale),
        transforms.CenterCrop(imscale),
    ])

    image_path = os.path.join(root, path[0], path[1], path[2], path[3], path)
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it deterministically. convert() loads the pixels and returns a
    # new image that no longer depends on the open file.
    with Image.open(image_path) as src:
        img = src.convert('RGB')

    return transform(img)


def main(args):
    """Store resized images in one LMDB per split and save an image index.

    For each split ('train', 'val', 'test'):
      * loads ``<save_dir>/<suff>recipe1m_<split>.pkl`` (an iterable of
        entries, each with an ``'images'`` list of file names);
      * opens, creating if necessary, ``<save_dir>/lmdb_<split>``;
      * stores every image not already present in that LMDB as the raw
        bytes of the ``uint8`` array from ``load_and_resize``, keyed by the
        encoded file name.

    Finally pickles ``{split: {image_name: position}}`` to
    ``<save_dir>/imname2pos.pkl``, where ``position`` is a running counter
    of images visited within that split.

    Args:
        args: parsed command-line namespace; reads ``save_dir``, ``suff``,
            ``root``, ``imscale`` and ``maxnumims``.
    """
    imname2pos = {'train': {}, 'val': {}, 'test': {}}
    for split in ['train', 'val', 'test']:
        with open(os.path.join(args.save_dir, args.suff + 'recipe1m_' + split + '.pkl'), 'rb') as f:
            dataset = pickle.load(f)

        env = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + split), map_size=int(MAX_SIZE))
        try:
            with env.begin() as txn:
                # A set gives O(1) membership tests (a list made the loop below
                # quadratic); iterating keys only avoids copying every stored
                # image value into Python just to discard it.
                present_entries = set(txn.cursor().iternext(keys=True, values=False))

            j = 0
            for entry in tqdm(dataset):
                # At most the first 5 images of an entry are considered,
                # further capped at args.maxnumims.
                impaths = entry['images'][0:5]
                for n, p in enumerate(impaths):
                    if n == args.maxnumims:
                        break
                    key = p.encode()
                    if key not in present_entries:
                        im = load_and_resize(os.path.join(args.root, 'images', split), p, args.imscale)
                        im = np.array(im).astype(np.uint8)
                        # Each image is committed in its own write transaction.
                        with env.begin(write=True) as txn:
                            txn.put(key, im)
                        # Remember it so a repeated path is not decoded again.
                        present_entries.add(key)
                    imname2pos[split][p] = j
                    j += 1
        finally:
            env.close()

    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'wb') as f:
        pickle.dump(imname2pos, f)


def test(args):
    """Sanity-check the validation LMDB by decoding one stored image.

    Loads ``<save_dir>/imname2pos.pkl`` and takes the first image name of the
    ``'val'`` split. It reads that image's bytes from ``<save_dir>/lmdb_val``,
    reshapes them to ``(imscale, imscale, 3)``, builds an RGB image from them
    and prints its shape.

    Args:
        args: parsed command-line namespace; reads ``save_dir`` and ``imscale``.

    Raises:
        ValueError: if the index contains no validation images.
        KeyError: if the chosen image is missing from the validation LMDB.
    """
    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'rb') as f:
        imname2pos = pickle.load(f)
    paths = imname2pos['val']
    if not paths:
        raise ValueError('imname2pos.pkl contains no validation images to test')
    path = next(iter(paths))

    image_file = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + 'val'), max_readers=1, readonly=True,
                           lock=False, readahead=False, meminit=False)
    try:
        with image_file.begin(write=False) as txn:
            # txn.get returns a bytes copy, so it stays valid after close().
            raw = txn.get(path.encode())
    finally:
        image_file.close()
    if raw is None:
        raise KeyError('image %r not found in %s' % (path, os.path.join(args.save_dir, 'lmdb_val')))

    # np.fromstring is deprecated for binary data; frombuffer is its replacement.
    image = np.frombuffer(raw, dtype=np.uint8)
    image = np.reshape(image, (args.imscale, args.imscale, 3))
    # A (H, W, 3) uint8 array is interpreted as an RGB image.
    image = Image.fromarray(image)
    print(np.shape(image))


if __name__ == "__main__":
    # Command-line entry point: build the LMDBs (skipped with --test_only),
    # then always run the read-back check on the validation LMDB.
    parser = argparse.ArgumentParser()
    # (flag, value type, default, help text), in the order shown by --help.
    option_specs = [
        ('--root', str, 'path/to/recipe1m', 'path to the recipe1m dataset'),
        ('--save_dir', str, '../data', 'path where the lmdbs will be saved'),
        ('--imscale', int, 256, 'size of images (will be rescaled and center cropped)'),
        ('--maxnumims', int, 5, 'maximum number of images to allow for each sample'),
        ('--suff', str, '', 'id of the vocabulary to use'),
    ]
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    parser.add_argument('--test_only', dest='test_only', action='store_true', default=False)
    args = parser.parse_args()

    if args.test_only:
        test(args)
    else:
        main(args)
        test(args)