# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if self.training:
                    logging.warning(
                        f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                    )
                    idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Shouldn't fail to load a val video
                    raise e

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)
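
# A hypothetical wiring sketch (not part of the original module) showing how a
# VOSDataset is typically constructed and indexed; `my_transforms`,
# `my_video_dataset` (a VOSRawDataset), and `my_sampler` (a VOSSampler) are
# assumed to be built elsewhere, and only the constructor arguments themselves
# come from this file:
#
#   dataset = VOSDataset(
#       transforms=my_transforms,
#       training=True,
#       video_dataset=my_video_dataset,
#       sampler=my_sampler,
#       multiplier=1,
#   )
#   datapoint = dataset[0]  # a VideoDatapoint holding frames, objects, and segments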

def load_images(frames):
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded
            # Convert it to a PILImage
            all_images.append(tensor_2_PIL(frame.data))

    return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
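
# A minimal, hypothetical smoke test (not part of the original module) for
# tensor_2_PIL: it assumes frame.data tensors are float images in [0, 1] with
# shape (C, H, W), which is why tensor_2_PIL scales by 255 before casting to uint8.
if __name__ == "__main__":
    dummy = torch.rand(3, 4, 6)  # (C, H, W) float tensor in [0, 1]
    pil_img = tensor_2_PIL(dummy)
    # PIL reports size as (width, height), so a (3, 4, 6) tensor becomes a 6x4 RGB image
    assert pil_img.size == (6, 4) and pil_img.mode == "RGB"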