# Source: DIRECTOR-demo/src/datasets/modalities/caption_dataset.py
# Last commit: "Add app" by robin-courant (f7a5cb1, verified)
from collections import Counter
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from utils.file_utils import load_txt
class CaptionDataset(Dataset):
    """Dataset yielding caption / segment features keyed by sample filename.

    The ``name`` passed at construction selects the modality, which decides
    what ``__getitem__`` returns as ``feat_data``:

    * ``"caption"``  -- the caption feature array loaded from
      ``feat_caption_dir``.
    * ``"segments"`` -- a one-hot encoding of the padded segment indices.
    * ``"class"``    -- a one-hot encoding of the most frequent segment index.

    Data directories (``segment_dir``, ``raw_caption_dir``,
    ``feat_caption_dir``, ...) and ``max_feat_length`` are not explicit
    parameters: they are read from ``self`` in ``__getitem__`` and are
    expected to be provided through ``**kwargs``.
    """

    def __init__(
        self,
        name: str,
        dataset_dir: str,
        num_cams: int,
        num_feats: int,
        num_segments: int,
        sequential: bool,
        **kwargs,
    ):
        """Store the configuration and resolve data paths.

        Args:
            name: Modality name; stored as both ``self.modality`` and
                ``self.name``.
            dataset_dir: Root directory of the dataset.
            num_cams: Length the raw segment tensor is padded to.
            num_feats: Stored as-is; not used inside this class.
            num_segments: Number of segment classes; also used as the padding
                index for segment tensors.
            sequential: Whether features are kept as sequences. Selects the
                ``"seq"`` (True) or ``"token"`` (False) sub-directory of
                ``feat_caption_dir``.
            **kwargs: Extra attributes set verbatim on the instance. String
                values are converted to ``Path``.
        """
        super().__init__()
        self.modality = name
        self.name = name
        self.dataset_dir = Path(dataset_dir)

        # Set data paths (segments, captions, etc...)
        # Loop variables are deliberately not called ``name`` so the
        # constructor argument is not shadowed.
        for key, value in kwargs.items():
            if isinstance(value, str):
                value = Path(value)
            if key == "feat_caption_dir":
                value = value / "seq" if sequential else value / "token"
            setattr(self, key, value)

        # Left unset here; must be assigned before the dataset is indexed.
        self.filenames = None

        self.clip_seq_dir = self.dataset_dir / "caption_clip" / "seq"  # For CLaTrScore
        self.num_cams = num_cams
        self.num_feats = num_feats
        self.num_segments = num_segments
        self.sequential = sequential

    def __len__(self):
        """Return the number of samples (length of ``self.filenames``)."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.filenames``.

        Returns:
            A ``(filename, feat_data, raw_data)`` tuple. ``raw_data`` is a
            dict that always holds ``"segments"``, ``"clip_seq_caption"`` and
            ``"clip_seq_mask"``, plus ``"caption"`` for the caption modality.

        Raises:
            ValueError: If ``self.modality`` is not one of ``"caption"``,
                ``"segments"`` or ``"class"``.
        """
        filename = self.filenames[index]

        # Load data
        if hasattr(self, "segment_dir"):
            raw_segments = torch.from_numpy(
                np.load(self.segment_dir / (filename + ".npy"))
            )
            # Right-pad up to ``num_cams`` entries using ``num_segments`` as
            # the dedicated padding index.
            padded_raw_segments = F.pad(
                raw_segments,
                (0, self.num_cams - len(raw_segments)),
                value=self.num_segments,
            )
        if hasattr(self, "raw_caption_dir"):
            raw_caption = load_txt(self.raw_caption_dir / (filename + ".txt"))
        if hasattr(self, "feat_caption_dir"):
            feat_caption = torch.from_numpy(
                np.load(self.feat_caption_dir / (filename + ".npy"))
            )
            if self.sequential:
                # Zero-pad the first dimension up to ``max_feat_length``.
                feat_caption = F.pad(
                    feat_caption.to(torch.float32),
                    (0, 0, 0, self.max_feat_length - feat_caption.shape[0]),
                )

        if self.modality == "caption":
            raw_data = {"caption": raw_caption, "segments": padded_raw_segments}
            # 2-D features are transposed; anything else is passed through.
            feat_data = (
                feat_caption.permute(1, 0) if feat_caption.dim() == 2 else feat_caption
            )
        elif self.modality == "segments":
            raw_data = {"segments": padded_raw_segments}
            # Shift by one for padding
            feat_data = F.one_hot(
                padded_raw_segments, num_classes=self.num_segments + 1
            ).to(torch.float32)
            if self.sequential:
                feat_data = feat_data.permute(1, 0)
            else:
                feat_data = feat_data.reshape(-1)
        elif self.modality == "class":
            raw_data = {"segments": padded_raw_segments}
            # Count plain Python values: tensors hash by identity, so counting
            # tensor elements directly would give every element a count of 1
            # and always pick the first segment rather than the most frequent.
            segment_counts = Counter(raw_segments.tolist())
            most_frequent_segment = segment_counts.most_common(1)[0][0]
            feat_data = F.one_hot(
                torch.tensor(most_frequent_segment, dtype=torch.long),
                num_classes=self.num_segments,
            ).to(torch.float32)
        else:
            raise ValueError(f"Modality {self.modality} not supported")

        clip_seq_caption = torch.from_numpy(
            np.load(self.clip_seq_dir / (filename + ".npy"))
        )
        # Mask is 1 on real positions and 0 on the padded tail.
        padding_mask = torch.ones((self.max_feat_length))
        padding_mask[clip_seq_caption.shape[0] :] = 0
        clip_seq_caption = F.pad(
            clip_seq_caption.to(torch.float32),
            (0, 0, 0, self.max_feat_length - clip_seq_caption.shape[0]),
        )
        raw_data["clip_seq_caption"] = clip_seq_caption
        raw_data["clip_seq_mask"] = padding_mask

        return filename, feat_data, raw_data