File size: 1,584 Bytes
004e907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from torch.utils.data import Dataset
from pathlib import Path


class S2TDataset(Dataset):
    """Map-style dataset over a directory of ``torch.save``-d examples.

    Every regular file directly inside ``data_path`` is one example. Files are
    sorted by path so that index ``i`` refers to the same example across runs
    and machines (``Path.iterdir`` yields entries in arbitrary,
    filesystem-dependent order).
    """

    def __init__(self, data_path):
        """Index the example files in ``data_path``.

        Args:
            data_path: Directory (str or path-like) holding one serialized
                example per file.
        """
        self.path = Path(data_path)
        # Skip sub-directories: torch.load would only fail on them at access time.
        self.files = sorted(p for p in self.path.iterdir() if p.is_file())

    def __len__(self):
        """Return the number of example files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load example ``idx`` and tag it with the path it was loaded from.

        Returns:
            The object stored in the file (presumably a dict -- it must support
            item assignment) with an added ``'file_path'`` key holding the
            ``Path`` of the source file.
        """
        file_path = self.files[idx]
        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        eg = torch.load(file_path)
        eg['file_path'] = file_path
        return eg


# TODO: Using the attention masks does not work yet (it hurts performance); training also works without the mask.
def make_collate_fn(tokenizer):
    """Build a ``collate_fn`` that batches example dicts for a DataLoader.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(sentences, return_tensors='pt', padding=True)``; the
            result must expose an ``input_ids`` attribute (presumably a
            HuggingFace tokenizer -- confirm).

    Returns:
        A function mapping a list of example dicts to the tuple
        ``(encoder_hidden_states, encoder_attention_masks, input_ids)``.
    """
    def collate_fn(examples):
        """Zero-pad the wav2vec features to a common length and tokenize sentences.

        Each example must provide ``'wave2vec_features'``, a 2-D tensor
        ``(num_frames, feat_dim)``, and ``'sentence'``, a string.

        Returns:
            encoder_hidden_states: ``(batch, max_frames, feat_dim)`` tensor,
                zero-padded at the end of the time axis, with the inputs'
                dtype and device.
            encoder_attention_masks: ``(batch, max_frames)`` bool tensor,
                ``True`` for real frames and ``False`` for padding.
                NOTE(review): PyTorch ``nn.Transformer*`` ``key_padding_mask``
                uses the opposite convention (``True`` = ignore); if that is
                the consumer, pass ``~mask`` -- possibly why masks hurt training.
            input_ids: token ids from ``tokenizer``, padded to the longest sentence.

        Raises:
            ValueError: if ``examples`` is empty.
        """
        if not examples:
            raise ValueError('collate_fn received an empty batch')
        wav2vec_feats = [eg['wave2vec_features'] for eg in examples]
        max_len = max(len(feats) for feats in wav2vec_feats)

        padded_feats, attention_masks = [], []
        for feats in wav2vec_feats:
            num_frames = len(feats)
            # Pad with zeros of the SAME dtype/device as the features; a default
            # float32 pad would silently upcast half-precision inputs (or fail to
            # concatenate on older torch versions).
            padding = feats.new_zeros((max_len - num_frames, feats.shape[-1]))
            padded_feats.append(torch.cat([feats, padding]))
            mask = torch.zeros(max_len, dtype=torch.bool, device=feats.device)
            mask[:num_frames] = True
            attention_masks.append(mask)

        encoder_hidden_states = torch.stack(padded_feats, dim=0)
        encoder_attention_masks = torch.stack(attention_masks, dim=0)
        input_ids = tokenizer([eg['sentence'] for eg in examples], return_tensors='pt', padding=True).input_ids
        return encoder_hidden_states, encoder_attention_masks, input_ids
    return collate_fn