# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq.data import encoders


def get_whole_word_mask(args, dictionary):
    """Build a per-entry mask marking dictionary symbols that begin a word.

    Args:
        args: configuration object forwarded unchanged to
            ``encoders.build_bpe``.
        dictionary: vocabulary object; this function reads its ``nspecial``
            attribute, indexes it by integer to get the symbol string, and
            takes its ``len()``.

    Returns:
        A ``torch.ByteTensor`` with one element per dictionary index, set to
        1 where the entry is treated as the beginning of a word and 0
        otherwise; or ``None`` when ``encoders.build_bpe(args)`` returns
        ``None``.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        # Without a BPE there is no way to decide word boundaries.
        return None

    def starts_word(index):
        # Special elements are always considered beginnings.
        if index < dictionary.nspecial:
            return True
        symbol = dictionary[index]
        # Entries named "madeupword..." are likewise always flagged
        # (presumably filler symbols -- confirm against the dictionary code).
        if symbol.startswith("madeupword"):
            return True
        try:
            return bpe.is_beginning_of_word(symbol)
        except ValueError:
            # Symbols the BPE rejects fall back to "beginning of word".
            return True

    flags = [starts_word(index) for index in range(len(dictionary))]
    mask_whole_words = torch.ByteTensor(flags)
    return mask_whole_words