from pathlib import Path

import torch
from torch.utils.data import Dataset
class S2TDataset(Dataset):
    """Speech-to-text dataset over per-utterance files saved with torch.save.

    Each file is expected to hold a dict with at least a 'wave2vec_features'
    tensor and a 'sentence' transcript string, as consumed by collate_fn below.
    """

    def __init__(self, data_path):
        self.path = Path(data_path)
        # Sort so the example order is deterministic across runs and filesystems.
        self.files = sorted(self.path.iterdir())

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        eg = torch.load(file_path)
        # Keep the source path with the example for debugging and traceability.
        eg['file_path'] = file_path
        return eg
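# A minimal sketch of how the per-utterance files read above might be produced
# (an assumption inferred from the keys this module consumes; `feats`,
# `transcript`, `out_dir`, and `utt_id` are hypothetical names):
#
#     torch.save(
#         {'wave2vec_features': feats,   # FloatTensor, (num_frames, feat_dim)
#          'sentence': transcript},      # str
#         out_dir / f'{utt_id}.pt',
#     )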
def make_collate_fn(tokenizer):
    def collate_fn(examples):
        # Note: the key spelling 'wave2vec_features' matches what the dataset
        # files store.
        wav2vec_feats = [eg['wave2vec_features'] for eg in examples]
        # Right-pad every feature sequence to the longest one in the batch.
        max_len = max(len(feats) for feats in wav2vec_feats)
        padded_feats, attention_masks = [], []
        for feats in wav2vec_feats:
            num_pads = max_len - len(feats)
            padded_feats.append(
                torch.cat([feats, torch.zeros((num_pads, feats.shape[-1]), device=feats.device)])
            )
            # 1 marks real frames, 0 marks padding; slicing with len(feats)
            # also handles the unpadded (num_pads == 0) case.
            mask = torch.zeros(max_len, dtype=torch.long, device=feats.device)
            mask[:len(feats)] = 1
            attention_masks.append(mask)

        encoder_hidden_states = torch.stack(padded_feats, dim=0)
        encoder_attention_masks = torch.stack(attention_masks, dim=0).bool()
        input_ids = tokenizer(
            [eg['sentence'] for eg in examples], return_tensors='pt', padding=True
        ).input_ids
        return encoder_hidden_states, encoder_attention_masks, input_ids

    return collate_fn
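# Usage sketch (assumptions: a Hugging Face tokenizer and a directory of .pt
# example files; the checkpoint name and data path below are hypothetical):
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained('t5-small')
#     dataset = S2TDataset('data/train')
#     loader = DataLoader(dataset, batch_size=8, shuffle=True,
#                         collate_fn=make_collate_fn(tokenizer))
#     encoder_hidden_states, encoder_attention_masks, input_ids = next(iter(loader))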