# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
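    """
    A VisionDataset over video object segmentation (VOS) data. Each item is a
    VideoDatapoint built by sampling a video from `video_dataset`, picking
    frames and object ids with `sampler`, loading RGB frames and ground-truth
    segments, and applying `transforms`. Per-video repeat factors (scaled by
    `multiplier`) are exposed as `self.repeat_factors` for samplers that use
    them.
    """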
def __init__(
self,
transforms,
training: bool,
video_dataset: VOSRawDataset,
sampler: VOSSampler,
multiplier: int,
always_target=True,
target_segments_available=True,
):
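        """
        Args:
            transforms: callables applied as transform(datapoint, epoch=...).
            training: if True, failed loads are retried with new random indices.
            video_dataset: raw video source; get_video() returns (video, segment_loader).
            sampler: selects the frames and object ids for each datapoint.
            multiplier: scales the per-video repeat factors.
            always_target: if True, objects with no ground-truth segment get an
                all-zero mask; otherwise they are dropped for that frame.
            target_segments_available: flag stored for downstream consumers.
        """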
self._transforms = transforms
self.training = training
self.video_dataset = video_dataset
self.sampler = sampler
self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
self.repeat_factors *= multiplier
print(f"Raw dataset length = {len(self.video_dataset)}")
self.curr_epoch = 0 # Used in case data loader behavior changes across epochs
self.always_target = always_target
        self.target_segments_available = target_segments_available

def _get_datapoint(self, idx):
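        """
        Load one datapoint. During training, a failed load is retried up to
        MAX_RETRIES times with a freshly sampled random video index; outside
        training the exception is re-raised immediately.
        """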
for retry in range(MAX_RETRIES):
try:
if isinstance(idx, torch.Tensor):
idx = idx.item()
# sample a video
video, segment_loader = self.video_dataset.get_video(idx)
# sample frames and object indices to be used in a datapoint
sampled_frms_and_objs = self.sampler.sample(
video, segment_loader, epoch=self.curr_epoch
)
                break  # Successfully loaded video
except Exception as e:
if self.training:
logging.warning(
f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
)
idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Shouldn't fail to load a val video
                    raise e
        else:
            # All retries failed during training; raise a clear error rather than
            # falling through with stale or unbound `video` / sampler state
            raise RuntimeError(f"Failed to load a video after {MAX_RETRIES} retries")
datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
for transform in self._transforms:
datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

def construct(self, video, sampled_frms_and_objs, segment_loader):
"""
Constructs a VideoDatapoint sample to pass to transforms
"""
sampled_frames = sampled_frms_and_objs.frames
sampled_object_ids = sampled_frms_and_objs.object_ids
images = []
rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their RGB data and object segments
for frame_idx, frame in enumerate(sampled_frames):
w, h = rgb_images[frame_idx].size
images.append(
Frame(
data=rgb_images[frame_idx],
objects=[],
)
)
# We load the gt segments associated with the current frame
if isinstance(segment_loader, JSONSegmentLoader):
segments = segment_loader.load(
frame.frame_idx, obj_ids=sampled_object_ids
)
else:
segments = segment_loader.load(frame.frame_idx)
for obj_id in sampled_object_ids:
# Extract the segment
if obj_id in segments:
assert (
segments[obj_id] is not None
), "None targets are not supported"
# segment is uint8 and remains uint8 throughout the transforms
segment = segments[obj_id].to(torch.uint8)
else:
                    # No target for this object: either use a zero mask target or drop the object
if not self.always_target:
continue
segment = torch.zeros(h, w, dtype=torch.uint8)
images[frame_idx].objects.append(
Object(
object_id=obj_id,
frame_index=frame.frame_idx,
segment=segment,
)
)
return VideoDatapoint(
frames=images,
video_id=video.video_id,
size=(h, w),
        )

def __getitem__(self, idx):
        return self._get_datapoint(idx)

def __len__(self):
        return len(self.video_dataset)


def load_images(frames):
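    """
    Return the sampled frames as PIL RGB images. Frames sharing an image path
    are decoded once and deep-copied from a cache; frames whose tensor data is
    already in memory are converted back to PIL via tensor_2_PIL.
    """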
all_images = []
cache = {}
for frame in frames:
if frame.data is None:
# Load the frame rgb data from file
path = frame.image_path
if path in cache:
all_images.append(deepcopy(all_images[cache[path]]))
continue
with g_pathmgr.open(path, "rb") as fopen:
all_images.append(PILImage.open(fopen).convert("RGB"))
cache[path] = len(all_images) - 1
else:
# The frame rgb data has already been loaded
# Convert it to a PILImage
all_images.append(tensor_2_PIL(frame.data))
    return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
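    """Convert a CHW float tensor with values in [0, 1] to an HWC uint8 PIL image."""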
data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
data = data.astype(np.uint8)
return PILImage.fromarray(data)
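

# A minimal usage sketch (hypothetical wiring; the concrete transform, raw
# dataset, and sampler objects come from the surrounding training config):
#
#   from torch.utils.data import DataLoader
#
#   dataset = VOSDataset(
#       transforms=[...],           # callables taking (datapoint, epoch=...)
#       training=True,
#       video_dataset=raw_dataset,  # a VOSRawDataset instance
#       sampler=frame_sampler,      # a VOSSampler instance
#       multiplier=1,
#   )
#   loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)
#   for batch in loader:
#       datapoint = batch[0]        # a transformed VideoDatapoint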