Spaces:
Runtime error
Runtime error
File size: 667 Bytes
733aa30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import logging
import torch.utils.data
from fairseq.data import FairseqDataset
logger = logging.getLogger(__name__)
class OFADataset(FairseqDataset):
    """Dataset base class that adds a text-to-token-id helper.

    The methods below read ``self.dataset``, ``self.bpe``, ``self.tgt_dict``,
    ``self.bos_item`` and ``self.eos_item``. None of these attributes is
    assigned inside this class.
    NOTE(review): presumably subclasses set them in ``__init__`` -- confirm.
    """

    def __len__(self):
        """Return the length of the wrapped ``self.dataset``."""
        return len(self.dataset)

    def encode_text(self, text, length=None, append_bos=False, append_eos=False):
        """Encode ``text`` into token ids using the BPE and target dictionary.

        Args:
            text: Raw text handed to ``self.bpe.encode``.
            length: If not ``None``, keep only the first ``length`` ids.
                The cut is applied before BOS/EOS are attached, so the
                returned sequence can be longer than ``length``.
            append_bos: When truthy, prepend ``self.bos_item``.
            append_eos: When truthy, append ``self.eos_item``.

        Returns:
            The result of ``encode_line(...).long()``, optionally truncated
            and optionally wrapped with the BOS/EOS items via ``torch.cat``.
        """
        bpe_line = self.bpe.encode(text)
        # The dictionary is never asked to add EOS itself; EOS handling is
        # done explicitly below so that it happens after truncation.
        token_ids = self.tgt_dict.encode_line(
            line=bpe_line,
            add_if_not_exist=False,
            append_eos=False,
        )
        token_ids = token_ids.long()

        if length is not None:
            token_ids = token_ids[:length]

        if append_bos:
            token_ids = torch.cat([self.bos_item, token_ids])

        if append_eos:
            token_ids = torch.cat([token_ids, self.eos_item])

        return token_ids
|