# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from fairseq.data import encoders | |
def get_whole_word_mask(args, dictionary):
    """Build a per-index "word beginning" mask for whole-word masking.

    Uses the BPE built from ``args`` to decide, for every symbol in
    ``dictionary``, whether it starts a new word.

    Returns:
        A ``torch.ByteTensor`` of length ``len(dictionary)`` with 1 at
        indices that begin a word, or ``None`` when no BPE is configured.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        return None

    def _starts_word(idx):
        # Special symbols are always treated as word beginnings.
        if idx < dictionary.nspecial:
            return True
        token = dictionary[idx]
        # Placeholder symbols padding the vocab also count as beginnings.
        if token.startswith("madeupword"):
            return True
        try:
            return bpe.is_beginning_of_word(token)
        except ValueError:
            # Tokens the BPE cannot interpret default to word beginnings.
            return True

    return torch.ByteTensor([_starts_word(i) for i in range(len(dictionary))])