# copied from https://github.com/Lyken17/Efficient-PyTorch/tools
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import sys
import argparse
import csv
import json

import numpy as np
import six
import tqdm
from lmdbdict import lmdbdict
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
import torch.utils.data as data
from torch.utils.data import DataLoader

csv.field_size_limit(sys.maxsize)

FIELDNAMES = ['image_id', 'status']
class FolderLMDB(data.Dataset):
    """Dataset that reads raw .npz/.npy bytes back out of an lmdbdict database."""

    def __init__(self, db_path, fn_list=None):
        self.db_path = db_path
        self.lmdb = lmdbdict(db_path, unsafe=True)
        self.lmdb._key_dumps = DUMPS_FUNC['ascii']
        self.lmdb._value_loads = LOADS_FUNC['identity']
        if fn_list is not None:
            self.length = len(fn_list)
            self.keys = fn_list
        else:
            raise ValueError('fn_list must be provided')

    def __getitem__(self, index):
        byteflow = self.lmdb[self.keys[index]]
        # Rehydrate the stored bytes into a file-like buffer and load with numpy.
        buf = six.BytesIO()
        buf.write(byteflow)
        buf.seek(0)
        try:
            if args.extension == '.npz':
                feat = np.load(buf)['feat']
            else:
                feat = np.load(buf)
        except Exception as e:
            print(self.keys[index], e)
            return None
        return feat

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
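# A minimal read-back sketch (hypothetical ids; assumes the LMDB below was
# already built by this script, and that the module-level `args` from
# parse_args() is set, since __getitem__ consults args.extension):
#   dataset = FolderLMDB('./data/cocobu_att.lmdb', fn_list=['391895', '522418'])
#   feat = dataset[0]   # numpy array decoded from the stored .npz/.npy bytes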
def has_file_allowed_extension(filename, extensions):
    # Minimal stand-in for torchvision's helper of the same name, so the
    # script is self-contained.
    return filename.lower().endswith(tuple(ext.lower() for ext in extensions))


def make_dataset(dir, extension):
    images = []
    dir = os.path.expanduser(dir)
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if has_file_allowed_extension(fname, [extension]):
                path = os.path.join(root, fname)
                images.append(path)
    return images
def raw_reader(path):
    with open(path, 'rb') as f:
        bin_data = f.read()
    return bin_data


def raw_npz_reader(path):
    # Read the raw bytes and also try to parse them, so corrupt files are
    # detected before anything is written to the database.
    with open(path, 'rb') as f:
        bin_data = f.read()
    try:
        npz_data = np.load(six.BytesIO(bin_data))['feat']
    except Exception as e:
        print(path)
        npz_data = None
        print(e)
    return bin_data, npz_data


def raw_npy_reader(path):
    with open(path, 'rb') as f:
        bin_data = f.read()
    try:
        npy_data = np.load(six.BytesIO(bin_data))
    except Exception as e:
        print(path)
        npy_data = None
        print(e)
    return bin_data, npy_data
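# Both *_reader functions return (raw_bytes, parsed_array). Hypothetical
# example of the contract folder2lmdb relies on:
#   byte, feat = raw_npz_reader('./data/cocobu_att/391895.npz')
#   ok = feat is not None   # False => the file failed to parse; it is logged
#                           # with status False and retried on the next run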
class Folder(data.Dataset):
    """Dataset over raw feature files in a folder, read with a given loader."""

    def __init__(self, root, loader, extension, fn_list=None):
        super(Folder, self).__init__()
        self.root = root
        if fn_list:
            samples = [os.path.join(root, str(_) + extension) for _ in fn_list]
        else:
            samples = make_dataset(self.root, extension)
        self.loader = loader
        self.extension = extension
        self.samples = samples

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (file_id, raw_bytes, parsed_array) for the sample at index.
        """
        path = self.samples[index]
        sample = self.loader(path)
        # Prepend the file stem (e.g. the COCO image id) to the loader's output.
        return (path.split('/')[-1].split('.')[0],) + sample

    def __len__(self):
        return len(self.samples)
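# Usage sketch (hypothetical paths and ids). folder2lmdb wraps this dataset in
# a DataLoader with an identity collate_fn, so each "batch" is a plain list of
# one (file_id, raw_bytes, parsed_array) tuple:
#   dataset = Folder('./data/cocobu_att', loader=raw_npz_reader,
#                    extension='.npz', fn_list=['391895'])
#   name, byte, feat = dataset[0]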
def folder2lmdb(dpath, fn_list, write_frequency=5000):
    directory = osp.expanduser(osp.join(dpath))
    print("Loading dataset from %s" % directory)
    if args.extension == '.npz':
        dataset = Folder(directory, loader=raw_npz_reader, extension='.npz',
                         fn_list=fn_list)
    else:
        dataset = Folder(directory, loader=raw_npy_reader, extension='.npy',
                         fn_list=fn_list)
    data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x)

    lmdb_path = "%s.lmdb" % directory
    print("Generate LMDB to %s" % lmdb_path)
    db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity')

    # Append to the TSV so an interrupted run can resume where it left off.
    tsvfile = open(args.output_file, 'a')
    writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
    names = []
    all_keys = []
    for idx, data in enumerate(tqdm.tqdm(data_loader)):
        name, byte, npz = data[0]
        # Only store samples whose feature file actually parsed.
        if npz is not None:
            db[name] = byte
            all_keys.append(name)
        names.append({'image_id': name, 'status': str(npz is not None)})
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            print('writing')
            db.flush()
            # Record the processed ids in the TSV.
            for name in names:
                writer.writerow(name)
            names = []
            tsvfile.flush()
            print('writing finished')
    # Write the remaining rows.
    for name in names:
        writer.writerow(name)
    tsvfile.flush()
    tsvfile.close()

    print("Flushing database ...")
    db.flush()
    del db
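# Once folder2lmdb finishes, the database can also be opened directly with
# lmdbdict, mirroring what FolderLMDB.__init__ does (read-only sketch,
# hypothetical path and id):
#   db = lmdbdict('./data/cocobu_att.lmdb', unsafe=True)
#   db._key_dumps = DUMPS_FUNC['ascii']
#   db._value_loads = LOADS_FUNC['identity']
#   raw = db['391895']   # the original .npz bytes, exactly as stored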
def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(
        description='Dump a folder of .npz/.npy feature files into an LMDB database')
    parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str)
    parser.add_argument('--output_file', default='.dump_cache.tsv', type=str)
    parser.add_argument('--folder', default='./data/cocobu_att', type=str)
    parser.add_argument('--extension', default='.npz', type=str)
    args = parser.parse_args()
    return args
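# Example invocation (script name hypothetical; assumes the COCO bottom-up
# attention features and Karpathy-split json implied by the defaults above):
#   python dump_to_lmdb.py --folder ./data/cocobu_att --extension .npz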
if __name__ == "__main__":
    args = parse_args()
    # Place the TSV cache next to the input folder, named after it.
    args.output_file += args.folder.split('/')[-1]
    if args.folder.find('/') > 0:
        args.output_file = args.folder[:args.folder.rfind('/') + 1] + args.output_file
    print(args.output_file)

    img_list = json.load(open(args.input_json, 'r'))['images']
    fn_list = [str(_['cocoid']) for _ in img_list]

    # Skip ids already dumped successfully by a previous run.
    found_ids = set()
    try:
        with open(args.output_file, 'r') as tsvfile:
            reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
            for item in reader:
                if item['status'] == 'True':
                    found_ids.add(item['image_id'])
    except IOError:
        pass  # no cache yet: first run
    fn_list = [_ for _ in fn_list if _ not in found_ids]
    folder2lmdb(args.folder, fn_list)

    # Test existing: verify every key recorded as successful loads back.
    found_ids = set()
    with open(args.output_file, 'r') as tsvfile:
        reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
        for item in reader:
            if item['status'] == 'True':
                found_ids.add(item['image_id'])

    folder_dataset = FolderLMDB(args.folder + '.lmdb', list(found_ids))
    data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x)
    for data in tqdm.tqdm(data_loader):
        assert data[0] is not None