from typing import Union

import numpy as np
from transformers.utils import TensorType
from transformers.feature_extraction_utils import BatchFeature

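
# Illustrative sketch (hypothetical helper, not called by PadAndSortCollator).
# Assuming `processor.feature_extractor` is a transformers SequenceFeatureExtractor,
# its `pad` method right-pads each 2-D feature array along its first axis. That is
# why `__call__` transposes every spectrogram from (n_mels, T) to (T, n_mels)
# before padding and transposes the padded batch back afterwards. This NumPy-only
# version assumes right padding with a constant value and inputs laid out as
# (n_mels, T); it returns the padded batch and a frame-level mask.
import numpy as np  # already imported above; repeated so the sketch stands alone


def pad_time_major(mels, padding_value=0.0):
    """Pad a list of (n_mels, T_i) arrays to shape (B, n_mels, T_max).

    Also returns an int32 mask of shape (B, T_max): 1 for real frames,
    0 for padding.
    """
    # Time-major (T_i, n_mels) so that padding is applied along axis 0.
    frames = [np.asarray(m).T for m in mels]
    max_len = max(f.shape[0] for f in frames)
    n_mels = frames[0].shape[1]

    padded = np.full((len(frames), max_len, n_mels), padding_value, dtype=np.float32)
    mask = np.zeros((len(frames), max_len), dtype=np.int32)
    for i, f in enumerate(frames):
        padded[i, : f.shape[0]] = f  # copy real frames; the rest stays padding
        mask[i, : f.shape[0]] = 1

    # Back to channel-first (B, n_mels, T_max), the layout the collator returns.
    return padded.transpose(0, 2, 1), mask
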

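class PadAndSortCollator:
    """Collate text and mel-spectrogram examples into a padded, sorted batch.

    Pads `input_ids` and `mel_specgram` to the longest example in the batch,
    builds stop-token (gate) targets from the spectrogram padding mask, and
    sorts every field by text length in descending order.

    `processor` must expose `tokenizer` and `feature_extractor` attributes;
    `return_tensors` sets the tensor type of the returned batch ("pt" by default).
    """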
    def __init__(self, processor, return_tensors: Union[str, TensorType] = "pt"):
        self.processor = processor
        self.return_tensors = return_tensors

    def __call__(self, batch):
        """
        expect batch with `return_tensors=None` from processor
        batch: input_ids, length(optional), mel_specgram, mel_specgram_length(optional)
        """
        text_batch = {}
        text_batch["input_ids"] = [x["input_ids"] for x in batch]
        if "length" in batch[0]:
            text_batch["length"] = [x["length"] for x in batch]
        else:
            text_batch["length"] = [len(x["input_ids"]) for x in batch]

        audio_batch = {}
        # The feature extractor pads along the first axis, so transpose each
        # (n_mels, T) spectrogram to (T, n_mels) before padding.
        audio_batch["mel_specgram"] = [
            x["mel_specgram"][0].transpose(1, 0) for x in batch
        ]
        if "mel_specgram_length" in batch[0]:
            audio_batch["mel_specgram_length"] = [
                x["mel_specgram_length"] for x in batch
            ]
        else:
            audio_batch["mel_specgram_length"] = [
                x["mel_specgram"][0].shape[1] for x in batch
            ]

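        # Pad to the longest sequence in the batch; NumPy arrays are used here
        # and converted to `self.return_tensors` when the BatchFeature is built.
        text_batch = self.processor.tokenizer.pad(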
            text_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=False,
        )

        # The attention mask (1 = real frame, 0 = padding) is needed to build
        # the gate targets below.
        audio_batch = self.processor.feature_extractor.pad(
            audio_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=True,
        )
        audio_batch["mel_specgram"] = audio_batch["mel_specgram"].transpose(0, 2, 1)

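        # Build stop-token (gate) targets from the right-padded frame mask.
        # Inverting the mask marks padded frames; rolling it one step left also
        # marks the last real frame of each sequence. The last column receives
        # the wrapped-around value of column 0, so it is reset to 1: it is
        # either padding or the final frame of the longest sequence.
        attention_mask = audio_batch.pop("attention_mask")
        gate_padded = 1 - attention_mask
        gate_padded = np.roll(gate_padded, -1, axis=1)
        gate_padded[:, -1] = 1
        gate_padded = gate_padded.astype(np.float32)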

        output = {**text_batch, **audio_batch, "gate_padded": gate_padded}

        # Sort every field by text length, longest first (descending order).
        sort_idx = np.argsort(output["length"])[::-1]

        for key, value in output.items():
            output[key] = value[sort_idx]

        return BatchFeature(output, tensor_type=self.return_tensors)
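

# Illustrative sketch (hypothetical helper, not called by PadAndSortCollator).
# For right-padded batches, the roll-based `gate_padded` construction in
# `__call__` sets a frame's target to 1 from the last real frame onward and to 0
# before it. Assuming frame lengths are available, the same targets can be
# computed directly as:
#     gate[b, t] = 1.0 if t >= length[b] - 1 else 0.0
import numpy as np  # already imported above; repeated so the sketch stands alone


def gate_targets_from_lengths(lengths, max_len=None):
    """Return float32 stop-token targets of shape (B, max_len)."""
    lengths = np.asarray(lengths)
    if max_len is None:
        max_len = int(lengths.max())  # padding=True pads to the longest example
    frame_idx = np.arange(max_len)[None, :]  # shape (1, max_len)
    # Broadcast against (B, 1): 1 at and after each sequence's last real frame.
    return (frame_idx >= lengths[:, None] - 1).astype(np.float32)
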