DIRECTOR-demo / src /datasets /modalities /caption_dataset.py
robin-courant's picture
Add app
f7a5cb1 verified
raw
history blame
3.81 kB
from collections import Counter
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from utils.file_utils import load_txt
class CaptionDataset(Dataset):
    """Dataset yielding caption-related features for one of three modalities.

    Depending on ``name`` (stored as ``self.modality``), ``__getitem__``
    returns:
      - ``"caption"``: precomputed caption features (+ raw caption text),
      - ``"segments"``: a one-hot encoding of per-camera segment indices,
      - ``"class"``: a one-hot encoding of the most frequent segment index.

    Extra data locations (``segment_dir``, ``raw_caption_dir``,
    ``feat_caption_dir``, ...) and scalar options (e.g. ``max_feat_length``,
    which ``__getitem__`` relies on for padding) are passed through
    ``**kwargs`` and set as attributes.
    """

    def __init__(
        self,
        name: str,
        dataset_dir: str,
        num_cams: int,
        num_feats: int,
        num_segments: int,
        sequential: bool,
        **kwargs,
    ):
        """Store configuration and resolve data directories.

        Args:
            name: Modality name ("caption", "segments" or "class").
            dataset_dir: Root directory of the dataset.
            num_cams: Number of cameras; segment sequences are padded to it.
            num_feats: Feature dimensionality (kept for downstream use).
            num_segments: Number of segment classes (padding uses index
                ``num_segments``, i.e. one extra class).
            sequential: Whether features are sequences (True) or single
                tokens (False); selects the "seq"/"token" sub-directory.
            **kwargs: Extra attributes; string values are converted to
                ``Path`` (e.g. ``segment_dir``, ``raw_caption_dir``,
                ``feat_caption_dir``), non-strings (e.g. ``max_feat_length``)
                are stored as-is.
        """
        super().__init__()
        self.modality = name
        self.name = name
        self.dataset_dir = Path(dataset_dir)
        # Set data paths (segments, captions, etc...). Use a dedicated loop
        # variable so the `name` parameter is not shadowed.
        for attr_name, field in kwargs.items():
            if isinstance(field, str):
                field = Path(field)
            if attr_name == "feat_caption_dir":
                field = field / "seq" if sequential else field / "token"
            setattr(self, attr_name, field)
        # Populated externally (e.g. by a datamodule) before indexing.
        self.filenames = None
        self.clip_seq_dir = self.dataset_dir / "caption_clip" / "seq"  # For CLaTrScore
        self.num_cams = num_cams
        self.num_feats = num_feats
        self.num_segments = num_segments
        self.sequential = sequential

    def __len__(self):
        """Number of samples; requires ``self.filenames`` to be set first."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(filename, feat_data, raw_data)`` for one sample.

        ``feat_data`` is the modality-dependent feature tensor; ``raw_data``
        is a dict always containing the CLIP sequence caption and its
        padding mask, plus modality-specific raw entries.

        Raises:
            ValueError: If ``self.modality`` is not supported.
        """
        filename = self.filenames[index]

        # Load data
        if hasattr(self, "segment_dir"):
            raw_segments = torch.from_numpy(
                np.load(self.segment_dir / (filename + ".npy"))
            )
            # Pad to num_cams entries; padding uses the extra class index
            # `num_segments` (hence num_classes = num_segments + 1 below).
            padded_raw_segments = F.pad(
                raw_segments,
                (0, self.num_cams - len(raw_segments)),
                value=self.num_segments,
            )
        if hasattr(self, "raw_caption_dir"):
            raw_caption = load_txt(self.raw_caption_dir / (filename + ".txt"))
        if hasattr(self, "feat_caption_dir"):
            feat_caption = torch.from_numpy(
                np.load(self.feat_caption_dir / (filename + ".npy"))
            )
            if self.sequential:
                # Zero-pad the time axis up to max_feat_length.
                feat_caption = F.pad(
                    feat_caption.to(torch.float32),
                    (0, 0, 0, self.max_feat_length - feat_caption.shape[0]),
                )

        if self.modality == "caption":
            raw_data = {"caption": raw_caption, "segments": padded_raw_segments}
            # Sequential features are stored (time, feat); models expect
            # (feat, time), hence the permute for 2-D features.
            feat_data = (
                feat_caption.permute(1, 0) if feat_caption.dim() == 2 else feat_caption
            )
        elif self.modality == "segments":
            raw_data = {"segments": padded_raw_segments}
            # Shift by one for padding
            feat_data = F.one_hot(
                padded_raw_segments, num_classes=self.num_segments + 1
            ).to(torch.float32)
            if self.sequential:
                feat_data = feat_data.permute(1, 0)
            else:
                feat_data = feat_data.reshape(-1)
        elif self.modality == "class":
            raw_data = {"segments": padded_raw_segments}
            # Counter must see Python ints: iterating a tensor yields 0-d
            # tensors hashed by identity, which would make every element
            # count as 1 and silently return the *first* segment instead of
            # the most frequent one.
            most_frequent_segment = Counter(raw_segments.tolist()).most_common(1)[0][0]
            feat_data = F.one_hot(
                torch.tensor(most_frequent_segment), num_classes=self.num_segments
            ).to(torch.float32)
        else:
            raise ValueError(f"Modality {self.modality} not supported")

        # CLIP sequence caption (always returned): zero-padded to
        # max_feat_length with a matching 1/0 validity mask.
        clip_seq_caption = torch.from_numpy(
            np.load(self.clip_seq_dir / (filename + ".npy"))
        )
        padding_mask = torch.ones((self.max_feat_length))
        padding_mask[clip_seq_caption.shape[0] :] = 0
        clip_seq_caption = F.pad(
            clip_seq_caption.to(torch.float32),
            (0, 0, 0, self.max_feat_length - clip_seq_caption.shape[0]),
        )
        raw_data["clip_seq_caption"] = clip_seq_caption
        raw_data["clip_seq_mask"] = padding_mask

        return filename, feat_data, raw_data