|
|
|
|
|
|
|
|
|
|
|
|
|
import glob |
|
import json |
|
import os |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
|
|
from PIL import Image as PILImage |
|
|
|
try: |
|
from pycocotools import mask as mask_utils |
|
except: |
|
pass |
|
|
|
|
|
class JSONSegmentLoader:
    """Load per-frame object masks stored as RLEs in a JSON annotation file.

    The JSON payload is either a list (one entry per annotated frame, each entry
    being a per-object list of RLEs or ``None``), or a dict holding that list
    under ``"masklet"`` (falling back to ``"masks"``) with an optional ``"fps"``
    entry describing the annotation frame rate.
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        """
        Args:
            video_json_path: path of the JSON annotation file to read.
            ann_every: annotations exist for every ``ann_every``-th frame. It is
                overridden when the JSON is a dict that provides ``"fps"``.
            frames_fps: frame rate of the video frames; only used together with
                the JSON ``"fps"`` entry to derive ``ann_every``.
            valid_obj_ids: optional iterable of object ids to restrict ``load``
                to; ``None`` keeps every object.

        Raises:
            NotImplementedError: if the JSON top-level value is neither a list
                nor a dict.
        """
        self.ann_every = ann_every
        self.valid_obj_ids = valid_obj_ids

        with open(video_json_path, "r") as f:
            data = json.load(f)

        if isinstance(data, list):
            self.frame_annots = data
        elif isinstance(data, dict):
            masklet_field_name = "masklet" if "masklet" in data else "masks"
            self.frame_annots = data[masklet_field_name]
            if "fps" in data:
                fps = data["fps"]
                # "fps" may be stored either as a scalar or as a list whose
                # first element is the value.
                annotations_fps = int(fps[0] if isinstance(fps, list) else fps)
                assert frames_fps % annotations_fps == 0, (
                    f"frames_fps ({frames_fps}) must be a multiple of the "
                    f"annotation fps ({annotations_fps})"
                )
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError(
                f"unsupported JSON top-level type: {type(data).__name__}"
            )

    def load(self, frame_id, obj_ids=None):
        """Decode the masks annotated on ``frame_id``.

        Args:
            frame_id: frame index; must be a multiple of ``self.ann_every``.
            obj_ids: optional iterable of object ids to keep (intersected with
                ``self.valid_obj_ids`` and with the ids present in the frame).

        Returns:
            dict mapping each selected object id (in ascending order) to its
            decoded mask tensor, or to ``None`` when the object has no
            annotation on this frame.
        """
        assert frame_id % self.ann_every == 0, (
            f"frame {frame_id} is not annotated (ann_every={self.ann_every})"
        )
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        # Object ids are positions in the per-frame annotation list.
        valid_objs_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            valid_objs_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            valid_objs_ids &= set(obj_ids)
        valid_objs_ids = sorted(valid_objs_ids)

        # Pre-fill with None so that the dict keeps the ascending id order even
        # for objects that are absent from this frame.
        segments = {obj_id: None for obj_id in valid_objs_ids}
        present_ids = [
            obj_id for obj_id in valid_objs_ids if rle_mask[obj_id] is not None
        ]

        # Decoding an empty RLE list is not a valid pycocotools call, so only
        # decode when at least one object is actually annotated.
        if present_ids:
            # All RLEs are decoded in one batched call; the decoded array is
            # permuted so that the object axis comes first.
            raw_segments = torch.from_numpy(
                mask_utils.decode([rle_mask[obj_id] for obj_id in present_ids])
            ).permute(2, 0, 1)
            for idx, obj_id in enumerate(present_ids):
                segments[obj_id] = raw_segments[idx]
        return segments

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """List, for every object, the frame ids on which it is annotated.

        The number of objects is taken from the first annotated frame.

        Args:
            num_frames_min: when given, objects annotated on fewer than this
                many frames are dropped from the result.

        Returns:
            dict mapping object id to the list of frame ids (already multiplied
            by ``self.ann_every``) where the annotation is not ``None``. An
            empty dict is returned when there are no annotated frames.
        """
        if not self.frame_annots:
            return {}

        num_objects = len(self.frame_annots[0])
        res = {obj_id: [] for obj_id in range(num_objects)}

        for annot_idx, annot in enumerate(self.frame_annots):
            frame_id = int(annot_idx * self.ann_every)
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(frame_id)

        if num_frames_min is not None:
            res = {
                obj_id: frames
                for obj_id, frames in res.items()
                if len(frames) >= num_frames_min
            }

        return res
|
|
|
|
|
class PalettisedPNGSegmentLoader:
    """Segment loader for datasets with masks stored as palettised PNGs."""

    def __init__(self, video_png_root):
        """
        SegmentLoader for datasets with masks stored as palettised PNGs.

        Args:
            video_png_root: the folder containing all the masks stored as PNG
                files named after their integer frame id.
        """
        self.video_png_root = video_png_root

        # Index the files by integer frame id instead of assuming a fixed
        # zero-padding width, so "00000.png" and "00000000.png" both work.
        self.frame_id_to_png_filename = {}
        for filename in os.listdir(self.video_png_root):
            stem, _ = os.path.splitext(filename)
            try:
                frame_id = int(stem)
            except ValueError:
                # Stray entries (e.g. ".DS_Store") are not frame masks.
                continue
            self.frame_id_to_png_filename[frame_id] = filename

    def load(self, frame_id):
        """
        Load the single palettised mask of ``frame_id`` from disk.

        The file is looked up in ``self.frame_id_to_png_filename`` and read
        from ``self.video_png_root``.

        Args:
            frame_id: int, selects the mask file.

        Returns:
            binary_segments: dict mapping every non-zero palette index found in
            the mask to a boolean tensor that is True where the mask equals
            that index. Keys are the NumPy scalars yielded by ``pd.unique``.

        Raises:
            KeyError: if no mask file was indexed for ``frame_id``.
        """
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )

        # Read the palette indices; the context manager closes the file handle.
        with PILImage.open(mask_path) as img:
            masks = np.array(img.convert("P"))

        # pd.unique keeps first-appearance order; index 0 is excluded.
        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]

        binary_segments = {}
        for i in object_id:
            binary_segments[i] = torch.from_numpy(masks == i)

        return binary_segments

    def __len__(self):
        """Return the number of frames that have a mask file."""
        return len(self.frame_id_to_png_filename)
|
|
|
|
|
class MultiplePNGSegmentLoader:
    """Segment loader for masks stored as one PNG per object and per frame.

    In multi-object mode the root folder holds one sub-folder per object, named
    by an integer, each containing ``{frame_id:05d}.png`` files. In single
    object mode the root folder is itself such an object folder.
    """

    def __init__(self, video_png_root, single_object_mode=False):
        """
        Args:
            video_png_root: the folder containing all the masks stored as PNG.
            single_object_mode: whether to load only a single object at a time.

        Raises:
            IndexError: if no PNG mask can be found under ``video_png_root``.
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode

        # Read one arbitrary mask to learn the frame size, which is needed to
        # synthesise empty masks for frames without a PNG.
        png_paths = glob.glob(self._png_pattern())
        if not png_paths:
            raise IndexError(f"no PNG masks found under {video_png_root!r}")
        with PILImage.open(png_paths[0]) as img:
            tmp_mask = np.array(img)
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]

        if self.single_object_mode:
            self.obj_id = self._folder_to_obj_id(video_png_root)
        else:
            self.obj_id = None

    def _png_pattern(self):
        """Return the glob pattern matching every mask PNG of this loader."""
        if self.single_object_mode:
            return os.path.join(self.video_png_root, "*.png")
        return os.path.join(self.video_png_root, "*", "*.png")

    @staticmethod
    def _folder_to_obj_id(obj_folder):
        """Turn an object folder path into its object id (folder number + 1)."""
        # normpath drops a trailing separator, so "videos/3/" parses like
        # "videos/3"; basename is separator-agnostic.
        # NOTE(review): the +1 offset presumably reserves id 0 for background;
        # confirm against the consumers of these ids.
        return int(os.path.basename(os.path.normpath(obj_folder))) + 1

    def _read_mask(self, mask_path):
        """Read ``mask_path`` as a boolean tensor; all False if it is missing."""
        if os.path.exists(mask_path):
            with PILImage.open(mask_path) as img:
                mask = np.array(img)
        else:
            # A missing file means the object has no mask on that frame.
            mask = np.zeros((self.H, self.W), dtype=bool)
        return torch.from_numpy(mask > 0)

    def load(self, frame_id):
        """Load the masks of ``frame_id`` as a dict of boolean tensors."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        else:
            return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        Load a single PNG from disk
        (path: f'{self.video_png_root}/{frame_id:05d}.png').

        Args:
            frame_id: int, defines the mask path.

        Returns:
            binary_segments: dict mapping ``self.obj_id`` to a boolean tensor
            (all False when the PNG does not exist).
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        return {self.obj_id: self._read_mask(mask_path)}

    def _load_multiple_pngs(self, frame_id):
        """
        Load multiple PNG masks from disk
        (path: f'{obj_folder}/{frame_id:05d}.png' for each object folder).

        Args:
            frame_id: int, defines the mask path.

        Returns:
            binary_segments: dict mapping each object id (folder number + 1) to
            a boolean tensor (all False when the PNG does not exist).
        """
        # Only directories are object folders; stray files in the root are
        # ignored instead of failing the integer parse below.
        all_objects = sorted(
            path
            for path in glob.glob(os.path.join(self.video_png_root, "*"))
            if os.path.isdir(path)
        )
        num_objects = len(all_objects)
        assert num_objects > 0, (
            f"no object folders found under {self.video_png_root!r}"
        )

        binary_segments = {}
        for obj_folder in all_objects:
            obj_id = self._folder_to_obj_id(obj_folder)
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            binary_segments[obj_id] = self._read_mask(mask_path)

        return binary_segments

    def __len__(self):
        """Return the number of distinct frame PNG names found on disk."""
        return len({os.path.basename(p) for p in glob.glob(self._png_pattern())})
|
|
|
|
|
class LazySegments:
    """
    Dict-like container of RLE-encoded segments that decodes a segment only
    when it is first read, and remembers the decoded mask afterwards.
    """

    def __init__(self):
        # key -> RLE as stored by the caller
        self.segments = {}
        # key -> mask tensor decoded on a previous lookup
        self.cache = {}

    def __setitem__(self, key, item):
        """Store the RLE ``item`` under ``key`` (the cache is left untouched)."""
        self.segments[key] = item

    def __getitem__(self, key):
        """Return the decoded mask for ``key``, decoding it on first access."""
        if key not in self.cache:
            # Look the RLE up first so that an unknown key raises KeyError
            # before any decoding is attempted.
            rle = self.segments[key]
            decoded = torch.from_numpy(mask_utils.decode([rle]))
            # Move the object axis first and take the only mask.
            self.cache[key] = decoded.permute(2, 0, 1)[0]
        return self.cache[key]

    def __contains__(self, key):
        """Tell whether an RLE was stored under ``key``."""
        return key in self.segments

    def __len__(self):
        """Return the number of stored RLEs."""
        return len(self.segments)

    def keys(self):
        """Return the keys view of the stored RLEs."""
        return self.segments.keys()
|
|
|
|
|
class SA1BSegmentLoader:
    """Load the RLE annotations of one image from a JSON file.

    The JSON file must hold a dict with an ``"annotations"`` list whose entries
    provide ``"area"`` and ``"segmentation"`` (an RLE), and optionally
    ``"uncertain_iou"``. Retained RLEs are exposed through a ``LazySegments``
    keyed by their position among the retained annotations.
    """

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Args:
            video_mask_path: path of the JSON annotation file.
            mask_area_frac_thresh: when ``<= 1.0``, annotations whose area
                divided by the image area (width * height of the image at
                ``video_frame_path``) reaches this value are dropped. Larger
                values disable the area filter.
            video_frame_path: path of the image; only needed when the area
                filter is enabled.
            uncertain_iou: annotations carrying an ``"uncertain_iou"`` field
                below this value are dropped.

        Raises:
            ValueError: if the area filter is enabled without
                ``video_frame_path``.
        """
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        filter_by_area = mask_area_frac_thresh <= 1.0
        if filter_by_area:
            if video_frame_path is None:
                raise ValueError(
                    "video_frame_path is required when mask_area_frac_thresh <= 1.0"
                )
            # Only the image size is needed; the context manager makes sure the
            # file handle is released right away.
            with PILImage.open(video_frame_path) as img:
                orig_w, orig_h = img.size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        rle_masks = []
        for frame_annot in self.frame_annots:
            # Skip annotations whose area is not strictly positive.
            if not frame_annot["area"] > 0:
                continue
            if ("uncertain_iou" in frame_annot) and (
                frame_annot["uncertain_iou"] < uncertain_iou
            ):
                continue
            if (
                filter_by_area
                and (frame_annot["area"] / area) >= mask_area_frac_thresh
            ):
                continue
            rle_masks.append(frame_annot["segmentation"])

        # RLEs are decoded lazily, on first access to each segment.
        self.segments = LazySegments()
        for i, rle in enumerate(rle_masks):
            self.segments[i] = rle

    def load(self, frame_idx):
        """Return the lazily decoded segments; ``frame_idx`` is not used."""
        return self.segments
|
|