|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
from dataclasses import dataclass |
|
from typing import List |
|
|
|
from training.dataset.vos_segment_loader import LazySegments |
|
|
|
# Upper bound on how many random clips RandomUniformSampler.sample draws
# before giving up on finding one whose first frame has a visible object.
MAX_RETRIES = 1_000
|
|
|
|
|
@dataclass
class SampledFramesAndObjects:
    """Output of a VOS sampler: the selected frames and the object ids to use.

    NOTE(review): ``frames`` is annotated as ``List[int]``, but the samplers in
    this module put the video's frame objects here (they read
    ``frames[0].frame_idx``), not bare integers -- TODO confirm and tighten
    the type hint.
    """

    # Frames chosen by the sampler, in the order they should be consumed.
    frames: List[int]
    # Ids of the objects selected for this sample.
    object_ids: List[int]
|
|
|
|
|
class VOSSampler:
    """Base class for samplers that choose frames and objects from a video.

    Concrete samplers override :meth:`sample`.
    """

    def __init__(self, sort_frames=True):
        # Flag consulted by subclasses (e.g. EvalSampler) to decide whether
        # the video's frames should be ordered by ``frame_idx``.
        self.sort_frames = sort_frames

    def sample(self, video):
        """Draw a sample from ``video``; subclasses must implement this."""
        raise NotImplementedError()
|
|
|
|
|
class RandomUniformSampler(VOSSampler):
    """Sample a random window of consecutive frames and a subset of objects.

    A contiguous run of ``num_frames`` frames is drawn uniformly from
    ``video.frames`` (optionally reversed in time). Up to ``max_num_objects``
    object ids are then chosen at random among the objects visible in the
    first frame of that (possibly reversed) run.
    """

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        """
        Args:
            num_frames: number of consecutive frames in each sample.
            max_num_objects: maximum number of object ids returned per sample.
            reverse_time_prob: probability of reversing the sampled frame order.
        """
        # Initialise base-class state (``sort_frames``) so every VOSSampler
        # exposes the same attributes; this sampler does not read it.
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """Sample consecutive frames and visible object ids from ``video``.

        Args:
            video: object exposing ``frames`` (a sequence whose items have a
                ``frame_idx``) and ``video_name``.
            segment_loader: object whose ``load(frame_idx)`` returns either a
                ``LazySegments`` or a mapping of object id -> mask supporting
                ``.sum()``.
            epoch: unused; accepted for interface compatibility.

        Returns:
            SampledFramesAndObjects holding the sampled frames and a random
            subset (at most ``max_num_objects``) of the visible object ids.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or
                if none of ``MAX_RETRIES`` sampled windows has a visible
                object in its first frame.
        """
        num_available = len(video.frames)
        # The frame count cannot change between retries, so validate it once
        # up front rather than on every iteration of the retry loop.
        if num_available < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, num_available - self.num_frames + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Temporal augmentation: present the clip backwards.
                frames = frames[::-1]

            # Objects are selected from the first frame of the clip (after
            # any reversal). Load its segments once and reuse the result.
            visible_object_ids = self._visible_object_ids(
                segment_loader.load(frames[0].frame_idx)
            )
            if visible_object_ids:
                break
        else:
            # Every retry produced a clip with nothing visible in frame 0.
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)

    @staticmethod
    def _visible_object_ids(loaded_segms):
        """Return the ids of objects considered visible in ``loaded_segms``."""
        if isinstance(loaded_segms, LazySegments):
            # Every key is taken as visible without materialising the masks
            # (presumably LazySegments only holds present objects -- TODO
            # confirm).
            return list(loaded_segms.keys())
        # Eager masks: keep only objects whose mask is non-empty.
        return [
            object_id
            for object_id, segment in loaded_segms.items()
            if segment.sum()
        ]
|
|
|
|
|
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
        sort_frames=True,
    ):
        """
        Args:
            sort_frames: if True (the default, and the previously hard-coded
                behaviour), frames are returned ordered by ``frame_idx``;
                otherwise they keep the order of ``video.frames``.
        """
        super().__init__(sort_frames=sort_frames)

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects

        Args:
            video: object exposing ``frames``, whose items have a ``frame_idx``.
            segment_loader: object whose ``load(frame_idx)`` returns a mapping
                keyed by object id.
            epoch: unused; accepted for interface compatibility.

        Returns:
            SampledFramesAndObjects with every frame of the video and, as a
            list, every object id present in the first returned frame.

        Raises:
            Exception: if the first returned frame has no objects.
        """
        if self.sort_frames:
            # Order by frame index regardless of how video.frames is stored.
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # Copy so that both branches return a fresh list and mutating the
            # sample cannot alter video.frames.
            frames = list(video.frames)
        # The set of objects to evaluate is taken from the first frame. Use a
        # real list (not a live dict_keys view) to match the List[int]
        # annotation and RandomUniformSampler's output.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if not object_ids:
            raise Exception("First frame of the video has no objects")

        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
|