# [page banner captured with this file; not part of the script] Spaces: Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from https://github.com/MendelXu/zsseg.baseline/blob/master/datasets/prepare_voc_sem_seg.py
# Modified by Heeseong Shin from https://github.com/facebookresearch/ov-seg/blob/main/datasets/prepare_voc_sem_seg.py
import os | |
import os.path as osp | |
from pathlib import Path | |
import tqdm | |
import numpy as np | |
from PIL import Image | |
# Maps the class IDs found in the annotation PNGs to training IDs.
# ID 0 is sent to 255, IDs 1..20 are shifted down to 0..19, and 255 stays 255.
# 255 is also the fill value convert_to_trainID uses for any unmapped pixel.
clsID_to_trID = {
    0: 255,
    **{cls_id: cls_id - 1 for cls_id in range(1, 21)},
    255: 255,
}

# Variant of the table above in which ID 0 becomes its own training class (20)
# instead of being mapped to 255. A copy is used so the base table is untouched.
clsID_to_trID_bg = clsID_to_trID.copy()
clsID_to_trID_bg[0] = 20
def convert_to_trainID(
    maskpath, out_mask_dir, is_train, clsID_to_trID=clsID_to_trID, suffix=""
):
    """Remap one annotation mask to training IDs and save it as a PNG.

    Args:
        maskpath: Path of the source annotation image.
        out_mask_dir: Root output directory; the file is written into the
            ``"train" + suffix`` or ``"val" + suffix`` sub-directory.
        is_train: Selects the ``train`` (True) or ``val`` (False) sub-directory.
        clsID_to_trID: Mapping from source pixel value to training ID.
        suffix: Optional string appended to the sub-directory name.

    Returns:
        None. Nothing is written when every remapped pixel equals 255.
    """
    # Close the file handle promptly; np.array() forces the pixel data to load.
    with Image.open(maskpath) as mask_img:
        mask = np.array(mask_img)

    # Start from all-255 so any value missing from the mapping stays 255.
    mask_copy = np.full_like(mask, 255, dtype=np.uint8)
    for clsID, trID in clsID_to_trID.items():
        mask_copy[mask == clsID] = trID

    # Masks that contain nothing but 255 carry no labels: skip them.
    if (mask_copy == 255).all():
        return

    split_dir = osp.join(out_mask_dir, ("train" if is_train else "val") + suffix)
    # The caller may not have created this sub-directory (e.g. non-empty suffix).
    os.makedirs(split_dir, exist_ok=True)
    seg_filename = osp.join(split_dir, osp.basename(maskpath))
    Image.fromarray(mask_copy).save(seg_filename, "PNG")
if __name__ == "__main__":
    # Dataset root comes from DETECTRON2_DATASETS, falling back to ./datasets.
    dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets"))
    print('Caution: we only generate the validation set!')
    voc_path = dataset_dir / "VOCdevkit" / "VOC2012"
    out_mask_dir = voc_path / "annotations_detectron2"
    out_mask_dir_bg = voc_path / "annotations_detectron2_bg"
    # Only remapped masks are produced; no image directory is created here.
    for name in ["val"]:
        os.makedirs(out_mask_dir / name, exist_ok=True)
        os.makedirs(out_mask_dir_bg / name, exist_ok=True)

    # Read the image IDs with plain text I/O. The previous
    # np.loadtxt(..., dtype=np.str) fails on NumPy >= 1.24 (alias removed) and
    # returns a bare string rather than a list for a one-line split file.
    split_file = voc_path / "ImageSets" / "Segmentation" / "val.txt"
    val_ids = split_file.read_text().split()
    val_list = [
        osp.join(voc_path, "SegmentationClassAug", f + ".png") for f in val_ids
    ]

    for file in tqdm.tqdm(val_list):
        # Write each mask twice: once with ID 0 mapped to 255, once with ID 0
        # kept as training class 20.
        convert_to_trainID(file, out_mask_dir, is_train=False)
        convert_to_trainID(
            file, out_mask_dir_bg, is_train=False, clsID_to_trID=clsID_to_trID_bg
        )