from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.utils.data import Dataset

from tokenizers import Tokenizer


class BilingualDataset(Dataset):
    """
    A bilingual dataset that follows the structure of the 'opus_books' dataset:
    each item is a dict of the form {'translation': {src_lang: text, tgt_lang: text}}.
    """

    def __init__(
        self,
        ds: List[Dict[str, Dict[str, str]]],
        src_tokenizer: Tokenizer,
        tgt_tokenizer: Tokenizer,
        src_lang: str,
        tgt_lang: str,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
    ) -> None:
        super().__init__()
        self.ds = ds
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.src_max_seq_len = src_max_seq_len
        self.tgt_max_seq_len = tgt_max_seq_len

        # Special-token ids. The names must match the special tokens the
        # tokenizers were trained with; reading them from src_tokenizer assumes
        # both tokenizers assign the same ids to [SOS], [EOS] and [PAD].
        self.sos_token = torch.tensor([src_tokenizer.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([src_tokenizer.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([src_tokenizer.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        src_tgt_pair = self.ds[index]
        src_text = src_tgt_pair['translation'][self.src_lang]
        tgt_text = src_tgt_pair['translation'][self.tgt_lang]

        encoder_input_tokens = self.src_tokenizer.encode(src_text).ids
        decoder_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids

        # The encoder input reserves two slots for [SOS] and [EOS]; the decoder
        # input reserves one slot for [SOS] only ([EOS] goes into the label).
        encoder_num_padding = self.src_max_seq_len - len(encoder_input_tokens) - 2
        decoder_num_padding = self.tgt_max_seq_len - len(decoder_input_tokens) - 1
        if encoder_num_padding < 0 or decoder_num_padding < 0:
            raise ValueError('Sentence is too long for the configured max sequence lengths')

        # [SOS] + source_text + [EOS] + [PAD]... = encoder_input
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((encoder_num_padding,), self.pad_token.item(), dtype=torch.int64),
            ]
        )

        decoder_input_tokens = torch.tensor(decoder_input_tokens, dtype=torch.int64)
        decoder_padding = torch.full((decoder_num_padding,), self.pad_token.item(), dtype=torch.int64)

        # [SOS] + target_text + [PAD]... = decoder_input
        decoder_input = torch.cat(
            [
                self.sos_token,
                decoder_input_tokens,
                decoder_padding,
            ]
        )

        # target_text + [EOS] + [PAD]... = expected decoder output (label)
        label = torch.cat(
            [
                decoder_input_tokens,
                self.eos_token,
                decoder_padding,
            ]
        )

        assert encoder_input.size(0) == self.src_max_seq_len
        assert decoder_input.size(0) == self.tgt_max_seq_len
        assert label.size(0) == self.tgt_max_seq_len

        return {
            'encoder_input': encoder_input,    # (src_seq_len)
            'decoder_input': decoder_input,    # (tgt_seq_len)
            # Padding positions are masked out of the encoder self-attention.
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, src_seq_len)
            # The decoder mask hides both padding and future positions.
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
                            & causal_mask(decoder_input.size(0)),  # (1, tgt_seq_len, tgt_seq_len)
            'label': label,                    # (tgt_seq_len)
            'src_text': src_text,
            'tgt_text': tgt_text,
        }


def causal_mask(size: int) -> Tensor:
    """Boolean mask that is True where a position may attend: itself and earlier."""
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
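

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): builds two tiny word-level
# tokenizers with the Hugging Face `tokenizers` library so BilingualDataset can
# be exercised end to end. The toy corpus, language codes and max lengths below
# are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import WordLevelTrainer

    def build_tokenizer(sentences: List[str]) -> Tokenizer:
        # Word-level tokenizer trained with the special tokens the dataset
        # expects; training both sides with the same special-token list keeps
        # the [SOS]/[EOS]/[PAD] ids aligned across the two tokenizers.
        tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'])
        tokenizer.train_from_iterator(sentences, trainer=trainer)
        return tokenizer

    # Toy data in the opus_books layout: {'translation': {lang: text}}.
    ds = [
        {'translation': {'en': 'I love books', 'it': 'Amo i libri'}},
        {'translation': {'en': 'The cat sleeps', 'it': 'Il gatto dorme'}},
    ]
    src_tok = build_tokenizer([pair['translation']['en'] for pair in ds])
    tgt_tok = build_tokenizer([pair['translation']['it'] for pair in ds])

    dataset = BilingualDataset(ds, src_tok, tgt_tok, 'en', 'it',
                               src_max_seq_len=10, tgt_max_seq_len=10)
    item = dataset[0]
    print(item['encoder_input'])       # [SOS], source ids..., [EOS], then [PAD]s
    print(item['decoder_mask'].shape)  # torch.Size([1, 10, 10])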