# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available
    def _get_datapoint(self, idx):
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if self.training:
                    logging.warning(
                        f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                    )
                    # Retry with a randomly sampled video instead
                    idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Shouldn't fail to load a val video
                    raise e

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint
    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )
    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)
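
# Note: `repeat_factors` built in `VOSDataset.__init__` is not consumed inside
# this class; it is intended for a repeat-factor-aware batch sampler. Below is
# a minimal sketch of how such a sampler could expand dataset indices; the
# helper name is an assumption for illustration and not part of this codebase.
def _expand_indices_by_repeat_factors(repeat_factors: torch.Tensor) -> torch.Tensor:
    # With the integer multipliers set in __init__, each dataset index i simply
    # appears repeat_factors[i] times in an epoch's index list.
    counts = repeat_factors.to(torch.long)
    return torch.repeat_interleave(torch.arange(len(counts)), counts)
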
def load_images(frames):
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded
            # Convert it to a PILImage
            all_images.append(tensor_2_PIL(frame.data))

    return all_images

def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    # Assumes a float CHW tensor with values in [0, 1]
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
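
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The raw-dataset
    # and sampler classes below (and their argument names) are assumptions
    # based on the sibling modules in training/dataset/; the paths are
    # placeholders.
    from training.dataset.vos_raw_dataset import PNGRawDataset
    from training.dataset.vos_sampler import RandomUniformSampler

    raw_dataset = PNGRawDataset(
        img_folder="/path/to/JPEGImages",
        gt_folder="/path/to/Annotations",
    )
    dataset = VOSDataset(
        transforms=[],  # typically a list of callables taking (datapoint, epoch=...)
        training=True,
        video_dataset=raw_dataset,
        sampler=RandomUniformSampler(num_frames=8, max_num_objects=3),
        multiplier=1,
    )
    print(f"dataset length = {len(dataset)}")
    datapoint = dataset[0]  # a VideoDatapoint with frames and per-object segments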