File size: 1,560 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data import Dictionary
class MaskedLMDictionary(Dictionary):
    """
    Dictionary for Masked Language Modelling tasks.

    Extends :class:`Dictionary` with one extra special symbol: the mask token.

    Args:
        pad: padding symbol, forwarded to :class:`Dictionary`.
        eos: end-of-sentence symbol, forwarded to :class:`Dictionary`.
        unk: unknown-word symbol, forwarded to :class:`Dictionary`.
        mask: mask symbol, registered via ``add_symbol`` once the base
            class has been initialised.
    """

    def __init__(
        self,
        pad: str = "<pad>",
        eos: str = "</s>",
        unk: str = "<unk>",
        mask: str = "<mask>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk)
        # Keep the surface form of the mask token, then register it and
        # remember the id it was given.
        self.mask_word = mask
        mask_id = self.add_symbol(mask)
        self.mask_index = mask_id
        # The mask token is a special symbol as well: nspecial is reset to
        # count every symbol registered so far, mask included.
        self.nspecial = len(self.symbols)

    def mask(self):
        """Return the index of the mask symbol."""
        return self.mask_index
class BertDictionary(MaskedLMDictionary):
    """
    Dictionary for the BERT task.

    Extends :class:`MaskedLMDictionary` with two more special symbols:
    a classification (``cls``) token and a separator (``sep``) token.

    Args:
        pad, eos, unk, mask: forwarded unchanged to
            :class:`MaskedLMDictionary`.
        cls: classification symbol, registered after the mask symbol.
        sep: separator symbol, registered after ``cls``.
    """

    def __init__(
        self,
        pad: str = "<pad>",
        eos: str = "</s>",
        unk: str = "<unk>",
        mask: str = "<mask>",
        cls: str = "<cls>",
        sep: str = "<sep>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
        self.cls_word, self.sep_word = cls, sep
        # Tuple elements are evaluated left to right, so ``cls`` is still
        # registered before ``sep``, matching the original ordering.
        self.cls_index, self.sep_index = self.add_symbol(cls), self.add_symbol(sep)
        # Reset nspecial so it also counts the cls and sep symbols.
        self.nspecial = len(self.symbols)

    def cls(self):
        """Return the index of the cls symbol."""
        return self.cls_index

    def sep(self):
        """Return the index of the sep symbol."""
        return self.sep_index
|