File size: 539 Bytes
a3cb5cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import numpy as np


class TokenWeighter:
    """Build a per-token-id validity mask from a tokenizer's vocabulary.

    The constructor stores the tokenizer as ``tokenizer_`` and the computed
    mask as ``proba``. Only the tokenizer's ``vocab`` attribute is read; it
    must be a mapping from token string to integer token id.
    """

    def __init__(self, tokenizer):
        """Store ``tokenizer`` and compute the mask once, up front."""
        self.tokenizer_ = tokenizer
        self.proba = self.get_token_proba()

    def get_token_proba(self):
        """Return the validity mask for ``self.tokenizer_.vocab``.

        Despite the name, the returned values are a 0/1 mask, not
        normalised probabilities.
        """
        return self._filter_short_partial(self.tokenizer_.vocab)

    def _filter_short_partial(self, vocab):
        """Return a float array with 1.0 at the ids of the kept tokens.

        A token is kept when it is longer than one character and does not
        contain ``#``. Every other position is 0.0.

        Args:
            vocab: mapping from token string to integer token id.

        Returns:
            A 1-D ``numpy`` array of length ``len(vocab)``, or longer when a
            kept token id is greater than or equal to ``len(vocab)``.
        """
        valid_token_ids = [
            token_id
            for token, token_id in vocab.items()
            if len(token) > 1 and "#" not in token
        ]
        # The mask is indexed by token id, so it must reach the largest kept
        # id. Sizing it by len(vocab) alone raises IndexError when the ids
        # have gaps (e.g. {"ab": 0, "cd": 5}).
        mask_size = max(len(vocab), max(valid_token_ids, default=-1) + 1)
        is_valid = np.zeros(mask_size)
        is_valid[valid_token_ids] = 1
        return is_valid