venite's picture
initial
f670afc
raw
history blame
7.6 kB
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import glob
import os
import lmdb
from tqdm import tqdm
from imaginaire.utils import path
def construct_file_path(root, data_type, sequence, filename, ext):
    r"""Build the path of one file inside the dataset layout.

    The layout is ``<root>/<data_type>/<sequence>/<filename>.<ext>``; the
    components are always joined with forward slashes.

    Args:
        root (str): Dataset root directory.
        data_type (str): Name of the data type folder.
        sequence (str): Name of the sequence folder.
        filename (str): File name without extension.
        ext (str): File extension without the leading dot.
    Returns:
        (str): The assembled file path.
    """
    stem = '/'.join(map(str, (root, data_type, sequence, filename)))
    return stem + '.' + str(ext)
def check_and_add(filepath, key, filepaths, keys, remove_missing=False):
    r"""Register an existing file and its LMDB key in the running lists.

    When ``filepath`` exists, it is appended to ``filepaths`` and ``key`` is
    appended to ``keys`` (both lists are modified in place). When it does not
    exist, a message is printed and neither list is touched.

    Args:
        filepath (str): Filepath to add.
        key (str): LMDB key for this filepath.
        filepaths (list): List of filepaths added so far.
        keys (list): List of keys added so far.
        remove_missing (bool): If ``True``, a missing file is skipped,
            otherwise a missing file raises an error.
    Returns:
        (int): Size of the file at filepath as reported by
            ``os.path.getsize``, or ``-1`` if the file is missing and
            ``remove_missing`` is set.
    Raises:
        FileNotFoundError: If the file is missing and ``remove_missing`` is
            not set.
    """
    if os.path.exists(filepath):
        filepaths.append(filepath)
        keys.append(key)
        return os.path.getsize(filepath)

    # Missing file: always report it, then either skip or fail.
    message = filepath + ' does not exist.'
    print(message)
    if not remove_missing:
        raise FileNotFoundError(message)
    return -1
def write_entry(txn, key, filepath):
    r"""Store the raw bytes of a file in LMDB under the given key.

    The file is read in binary mode in one go, and the key is ASCII-encoded
    before being handed to ``txn.put``.

    Args:
        txn: Handle to LMDB; only its ``put(key, value)`` method is used.
        key (str): LMDB key for this filepath (must be ASCII-encodable).
        filepath (str): Path of the file whose contents are stored.
    """
    with open(filepath, 'rb') as stream:
        payload = stream.read()
    encoded_key = key.encode('ascii')
    txn.put(encoded_key, payload)
def build_lmdb(filepaths, keys, output_filepath, map_size, large):
    r"""Write out lmdb containing (key, contents of filepath) to file.

    All entries are written inside a single write transaction. The
    transaction is committed only if every entry was written; if any entry
    fails, the transaction is aborted and the exception propagates. The LMDB
    environment is always closed before returning.

    Args:
        filepaths (list): List of filepath strings.
        keys (list): List of key strings associated with filepaths.
        output_filepath (str): Location to write LMDB to.
        map_size (int): Size of LMDB.
        large (bool): Is the dataset large? When truthy, the environment is
            opened with ``writemap=True``.
    """
    db = lmdb.open(output_filepath, map_size=map_size, writemap=bool(large))
    try:
        print('Writing LMDB to:', output_filepath)
        # The transaction context manager commits on success and aborts on
        # error, so a failed run never leaves a dangling write transaction.
        with db.begin(write=True) as txn:
            for filepath, key in tqdm(zip(filepaths, keys), total=len(keys)):
                write_entry(txn, key, filepath)
    finally:
        # Release the environment handle (the original never closed it).
        db.close()
def get_all_filenames_from_list(list_name):
    r"""Get all filenames from list.

    Each non-empty line of the list file is a ``/``-separated relative path
    to an image. The last component is the image name (a trailing ``.jpg``
    extension is removed); the preceding components form the folder name.
    Lines without a ``/`` are grouped under the folder ``'.'``. Blank lines
    are ignored.

    Args:
        list_name (str): Path to filename list.
    Returns:
        all_filenames (dict): Folder name for key, and list of image names
            (in file order) for values.
    """
    jpg_suffix = '.jpg'
    all_filenames = {}
    with open(list_name, 'rt') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # A blank line would otherwise register an empty image name.
                continue
            parts = line.split('/')
            if len(parts) > 1:
                folder_name = os.path.join(*parts[:-1])
            else:
                folder_name = '.'
            image_name = parts[-1]
            # Strip only the trailing extension; a '.jpg' that appears
            # earlier in the name is part of the filename and must be kept.
            if image_name.endswith(jpg_suffix):
                image_name = image_name[:-len(jpg_suffix)]
            all_filenames.setdefault(folder_name, []).append(image_name)
    return all_filenames
def get_lmdb_data_types(cfg):
    r"""Get the data types which should be put in LMDB.

    Walks ``cfg.data.input_types``, where every entry is a single-key
    mapping ``{name: info}``. Entries whose ``info['computed_on_the_fly']``
    is truthy are skipped; the names and ``info['ext']`` values of the
    remaining entries are stored on ``cfg.data.data_types`` and
    ``cfg.data.extensions`` respectively.

    As a side effect, any ``info`` lacking ``'computed_on_the_fly'`` gets
    that key set to ``False``.

    Args:
        cfg: Configuration object.
    Returns:
        cfg: The same configuration object, updated in place.
    """
    selected = []
    for entry in cfg.data.input_types:
        names = list(entry.keys())
        assert len(names) == 1
        type_name = names[0]
        info = entry[type_name]
        # Fill in the default so later readers can rely on the key.
        if 'computed_on_the_fly' not in info:
            info['computed_on_the_fly'] = False
        if info['computed_on_the_fly']:
            # Computed at load time, so it is not stored in LMDB.
            continue
        selected.append((type_name, info['ext']))
    cfg.data.data_types = [type_name for type_name, _ in selected]
    cfg.data.extensions = [ext for _, ext in selected]
    return cfg
def _list_sequence_filenames(data_root, data_type, sequence, ext):
    r"""List extension-less basenames of the ``ext`` files in one sequence.

    Args:
        data_root (str): Location of dataset root.
        data_type (str): Name of the data type folder.
        sequence (str): Name of the sequence folder.
        ext (str): File extension to look for, without the leading dot.
    Returns:
        (list): Basenames without extension, ordered by sorted full path.
    """
    # Escape the directory part so that glob metacharacters in folder names
    # are matched literally; only the '*.<ext>' part acts as a pattern.
    folder = glob.escape('%s/%s/%s' % (data_root, data_type, sequence))
    matches = sorted(glob.glob('%s/*.%s' % (folder, ext)))
    return [os.path.splitext(os.path.basename(match))[0] for match in matches]


def create_metadata(data_root=None, cfg=None, paired=None, input_list=''):
    r"""Main function.

    Determines which data types go into LMDB (via ``get_lmdb_data_types``),
    checks that each of them has a folder directly under ``data_root``, and
    collects the filenames to be stored.

    Args:
        data_root (str): Location of dataset root.
        cfg (object): Loaded config object.
        paired (bool): Paired or unpaired data.
        input_list (str): Path to filename containing list of inputs. Only
            used for paired data; ``''`` means scan the directories instead.
    Returns:
        (tuple):
          - all_filenames (dict): For paired data, maps each sequence to its
            list of filenames. For unpaired data, maps each data type to a
            dict of sequence -> list of filenames.
          - extensions (dict): Extension of each data type.
    Raises:
        AssertionError: If a required data type has no folder under
            ``data_root``.
    """
    cfg = get_lmdb_data_types(cfg)

    # Get list of all data_types in the dataset.
    available_data_types = path.get_immediate_subdirectories(data_root)
    print(available_data_types)
    required_data_types = cfg.data.data_types
    data_exts = cfg.data.extensions

    # Explicit check instead of ``assert cond, print(...)``: the message now
    # actually reaches the exception and the check survives ``python -O``.
    missing = set(required_data_types) - set(available_data_types)
    if missing:
        raise AssertionError(
            'Data types missing from %s: %s' %
            (data_root, sorted(missing, key=str)))

    # Find extensions for each data type.
    extensions = dict(zip(required_data_types, data_exts))
    print('Data file extensions:', extensions)

    if paired:
        if input_list != '':
            all_filenames = get_all_filenames_from_list(input_list)
        else:
            # Pick the data type whose folder is scanned for sequences;
            # paired data shares the same filenames across data types.
            if 'data_keypoint' in required_data_types:
                search_dir = 'data_keypoint'
            elif 'data_segmaps' in required_data_types:
                search_dir = 'data_segmaps'
            else:
                search_dir = required_data_types[0]
            print('Searching in dir: %s' % search_dir)
            sequences = path.get_recursive_subdirectories(
                os.path.join(data_root, search_dir),
                extensions[search_dir])
            print('Found %d sequences' % (len(sequences)))

            # Get filenames in each sequence.
            all_filenames = {}
            for sequence in sequences:
                all_filenames[sequence] = _list_sequence_filenames(
                    data_root, search_dir, sequence, extensions[search_dir])
            total_filenames = sum(
                len(filenames) for filenames in all_filenames.values())
            print('Found %d files' % total_filenames)
    else:
        # Get sequences in each data type.
        all_filenames = {}
        for data_type in required_data_types:
            all_filenames[data_type] = {}
            sequences = path.get_recursive_subdirectories(
                os.path.join(data_root, data_type), extensions[data_type])

            # Get filenames in each sequence.
            total_filenames = 0
            for sequence in sequences:
                filenames = _list_sequence_filenames(
                    data_root, data_type, sequence, extensions[data_type])
                all_filenames[data_type][sequence] = filenames
                total_filenames += len(filenames)
            print('Data type: %s, Found %d sequences, Found %d files' %
                  (data_type, len(sequences), total_filenames))

    return all_filenames, extensions