|
import os |
|
import random |
|
import json |
|
|
|
|
|
|
|
# Input: JSON list of valid MegaDepth entries that is split into train/val/test
# below. Presumably a list of strings whose first four characters identify the
# scene — confirm against the actual list file.
valid_list_json_path = './megadepth_valid_list.json'

# Raise explicitly instead of using `assert`: asserts are stripped under
# `python -O`, which would silently skip this check and defer the failure to
# an unhelpful error from open().
if not os.path.isfile(valid_list_json_path):
    raise FileNotFoundError(
        f'Valid list JSON not found at {valid_list_json_path!r}; '
        'update valid_list_json_path to point at megadepth_valid_list.json'
    )

# JSON text is UTF-8 by specification; don't rely on the platform's locale
# default encoding when reading it.
with open(valid_list_json_path, 'r', encoding='utf-8') as f:
    all_list = json.load(f)
|
|
|
|
|
# Group entries by scene. The scene key is the first four characters of each
# entry (presumably a zero-padded MegaDepth scene ID such as '0204' — confirm
# against the list file). Entries keep their original relative order within
# each scene.
scene_img_dict = {}
for item in all_list:
    # setdefault creates the per-scene list on first sight, replacing the
    # manual membership test and the repeated item[:4] slicing.
    scene_img_dict.setdefault(item[:4], []).append(item)
|
|
|
# Split policy by numeric scene ID:
#   - VAL_SCENE_ID                    -> validation
#   - any other ID <= MAX_TRAIN_SCENE_ID -> training
#   - everything above MAX_TRAIN_SCENE_ID -> test
VAL_SCENE_ID = 204
MAX_TRAIN_SCENE_ID = 240

train_split = []
val_split = []
test_split = []

for k in sorted(scene_img_dict):
    # Convert once per key; int() raises ValueError if a key is not numeric.
    scene_id = int(k)
    if scene_id == VAL_SCENE_ID:
        val_split += scene_img_dict[k]
    elif scene_id <= MAX_TRAIN_SCENE_ID:
        # VAL_SCENE_ID was already routed by the branch above, so it does not
        # need to be excluded again here.
        train_split += scene_img_dict[k]
    else:
        test_split += scene_img_dict[k]
|
|
|
|
|
# Emit each split as a sorted JSON list (4-space indent), writing the train,
# val and test files in that order.
for out_path, split_items in (
    ('megadepth_train.json', train_split),
    ('megadepth_val.json', val_split),
    ('megadepth_test.json', test_split),
):
    with open(out_path, 'w') as outfile:
        json.dump(sorted(split_items), outfile, indent=4)
|
|