Spaces:
Running
on
Zero
Running
on
Zero
from collections import Counter | |
from pathlib import Path | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
import torch.nn.functional as F | |
from utils.file_utils import load_txt | |
class CaptionDataset(Dataset): | |
def __init__( | |
self, | |
name: str, | |
dataset_dir: str, | |
num_cams: int, | |
num_feats: int, | |
num_segments: int, | |
sequential: bool, | |
**kwargs, | |
): | |
super().__init__() | |
self.modality = name | |
self.name = name | |
self.dataset_dir = Path(dataset_dir) | |
# Set data paths (segments, captions, etc...) | |
for name, field in kwargs.items(): | |
if isinstance(field, str): | |
field = Path(field) | |
if name == "feat_caption_dir": | |
field = field / "seq" if sequential else field / "token" | |
setattr(self, name, field) | |
self.filenames = None | |
self.clip_seq_dir = self.dataset_dir / "caption_clip" / "seq" # For CLaTrScore | |
self.num_cams = num_cams | |
self.num_feats = num_feats | |
self.num_segments = num_segments | |
self.sequential = sequential | |
def __len__(self): | |
return len(self.filenames) | |
def __getitem__(self, index): | |
filename = self.filenames[index] | |
# Load data | |
if hasattr(self, "segment_dir"): | |
raw_segments = torch.from_numpy( | |
np.load((self.segment_dir / (filename + ".npy"))) | |
) | |
padded_raw_segments = F.pad( | |
raw_segments, | |
(0, self.num_cams - len(raw_segments)), | |
value=self.num_segments, | |
) | |
if hasattr(self, "raw_caption_dir"): | |
raw_caption = load_txt(self.raw_caption_dir / (filename + ".txt")) | |
if hasattr(self, "feat_caption_dir"): | |
feat_caption = torch.from_numpy( | |
np.load((self.feat_caption_dir / (filename + ".npy"))) | |
) | |
if self.sequential: | |
feat_caption = F.pad( | |
feat_caption.to(torch.float32), | |
(0, 0, 0, self.max_feat_length - feat_caption.shape[0]), | |
) | |
if self.modality == "caption": | |
raw_data = {"caption": raw_caption, "segments": padded_raw_segments} | |
feat_data = ( | |
feat_caption.permute(1, 0) if feat_caption.dim() == 2 else feat_caption | |
) | |
elif self.modality == "segments": | |
raw_data = {"segments": padded_raw_segments} | |
# Shift by one for padding | |
feat_data = F.one_hot( | |
padded_raw_segments, num_classes=self.num_segments + 1 | |
).to(torch.float32) | |
if self.sequential: | |
feat_data = feat_data.permute(1, 0) | |
else: | |
feat_data = feat_data.reshape(-1) | |
elif self.modality == "class": | |
raw_data = {"segments": padded_raw_segments} | |
most_frequent_segment = Counter(raw_segments).most_common(1)[0][0] | |
feat_data = F.one_hot( | |
torch.tensor(most_frequent_segment), num_classes=self.num_segments | |
).to(torch.float32) | |
else: | |
raise ValueError(f"Modality {self.modality} not supported") | |
clip_seq_caption = torch.from_numpy( | |
np.load((self.clip_seq_dir / (filename + ".npy"))) | |
) | |
padding_mask = torch.ones((self.max_feat_length)) | |
padding_mask[clip_seq_caption.shape[0] :] = 0 | |
clip_seq_caption = F.pad( | |
clip_seq_caption.to(torch.float32), | |
(0, 0, 0, self.max_feat_length - clip_seq_caption.shape[0]), | |
) | |
raw_data["clip_seq_caption"] = clip_seq_caption | |
raw_data["clip_seq_mask"] = padding_mask | |
return filename, feat_data, raw_data | |