File size: 3,160 Bytes
cd89176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch 
from torch.utils.data import Dataset

class BilingualDataset(Dataset):
    """Translation pairs encoded as fixed-length tensors for an encoder-decoder model.

    Args:
        dataset: indexable collection; ``dataset[i]['translation'][lang]`` must
            yield the sentence in language ``lang``.
        source_tokenizer: tokenizer for the source side; must provide
            ``encode(text).ids``.
        target_tokenizer: tokenizer for the target side; must provide
            ``encode(text).ids`` and ``token_to_id(token)``.
            NOTE(review): presumably HuggingFace ``tokenizers.Tokenizer`` — confirm.
        source_language: key of the source sentence in each item's ``'translation'``.
        target_language: key of the target sentence in each item's ``'translation'``.
        sequence_length: length every returned sequence is padded to.
    """

    def __init__(self, dataset, source_tokenizer, target_tokenizer, source_language, target_language, sequence_length):
        super().__init__()
        self.dataset = dataset
        self.source_tokenizer = source_tokenizer
        self.target_tokenizer = target_tokenizer
        self.source_language = source_language
        self.target_language = target_language
        self.sequence_length = sequence_length

        # One-element int64 tensors so they can be concatenated directly with torch.cat.
        # The ids come from the *target* tokenizer but are also used on the encoder side.
        # NOTE(review): this assumes both tokenizers map [SOS]/[PAD]/[EOS] to the same
        # ids (e.g. both trained with the same special-token list) — confirm.
        self.SOS_token = torch.tensor([target_tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.PAD_token = torch.tensor([target_tokenizer.token_to_id("[PAD]")], dtype=torch.int64)
        self.EOS_token = torch.tensor([target_tokenizer.token_to_id("[EOS]")], dtype=torch.int64)

    def __len__(self):
        """Return the number of sentence pairs in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Encode the pair at ``index`` into model-ready tensors.

        Returns:
            dict with keys:
              - ``"encoder_input"``: ``[SOS] src [EOS] [PAD]...``, int64, length
                ``sequence_length``.
              - ``"decoder_input"``: ``[SOS] tgt [PAD]...``, int64, length
                ``sequence_length``.
              - ``"encoder_input_mask"``: int tensor of shape
                ``(1, 1, sequence_length)``; 1 where the encoder input is not [PAD].
              - ``"decoder_input_mask"``: int tensor of shape
                ``(1, sequence_length, sequence_length)``; the non-pad mask ANDed
                with ``casual_mask``.
              - ``"Target"``: ``tgt [EOS] [PAD]...``, int64, length
                ``sequence_length``. This is the decoder input shifted left by one.
              - ``"source_text"`` / ``"target_text"``: the raw sentences.

        Raises:
            ValueError: if a tokenized sentence plus its special tokens does not fit
                in ``sequence_length``.
        """
        pair = self.dataset[index]
        source_text = pair['translation'][self.source_language]
        target_text = pair['translation'][self.target_language]

        source_ids = self.source_tokenizer.encode(source_text).ids
        target_ids = self.target_tokenizer.encode(target_text).ids

        # The encoder sequence carries both [SOS] and [EOS] (-2).
        # Each decoder-side sequence carries exactly one of them (-1):
        # [SOS] on the input, [EOS] on the target.
        source_padding = self.sequence_length - len(source_ids) - 2
        target_padding = self.sequence_length - len(target_ids) - 1

        if source_padding < 0 or target_padding < 0:
            raise ValueError("sequence is too long")

        source_tokens = torch.tensor(source_ids, dtype=torch.int64)
        target_tokens = torch.tensor(target_ids, dtype=torch.int64)

        # PAD_token.repeat(n) yields an int64 tensor of n pad ids (empty when n == 0).
        encoder_input = torch.cat(
            [
                self.SOS_token,
                source_tokens,
                self.EOS_token,
                self.PAD_token.repeat(source_padding),
            ]
        )

        decoder_input = torch.cat(
            [
                self.SOS_token,
                target_tokens,
                self.PAD_token.repeat(target_padding),
            ]
        )

        # [EOS] must immediately follow the last real target token.
        # The position whose decoder input is that token must learn to predict [EOS].
        # Padding goes after it, never before.
        target = torch.cat(
            [
                target_tokens,
                self.EOS_token,
                self.PAD_token.repeat(target_padding),
            ]
        )

        # Internal invariants: the padding arithmetic above guarantees these lengths.
        assert encoder_input.size(0) == self.sequence_length
        assert decoder_input.size(0) == self.sequence_length
        assert target.size(0) == self.sequence_length

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_input_mask": (encoder_input != self.PAD_token).unsqueeze(0).unsqueeze(0).int(),
            "decoder_input_mask": (decoder_input != self.PAD_token).unsqueeze(0).int() & casual_mask(decoder_input.size(0)),
            "Target": target,
            "source_text": source_text,
            "target_text": target_text,
        }


def casual_mask(size):
    """Return a boolean causal (look-ahead) mask of shape ``(1, size, size)``.

    Entry ``[0, i, j]`` is True exactly when ``j <= i``. Position ``i`` may
    attend to itself and to earlier positions, never to later ones.
    """
    # Keep the lower triangle (diagonal included) of an all-True square.
    # The leading axis of length 1 lets the mask broadcast over a batch/head dimension.
    allowed = torch.ones((1, size, size), dtype=torch.bool)
    return torch.tril(allowed)