import random

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

def split_vocabulary(seed=42):
    """Randomly split the tokenizer vocabulary into two disjoint halves."""
    vocab = list(tokenizer.get_vocab().items())

    # Seed the RNG so the split is reproducible across runs.
    random.seed(seed)

    permissible = {}
    non_permissible = {}

    # Assign each token to one half with probability 0.5.
    for word, index in vocab:
        if random.random() < 0.5:
            permissible[word] = index
        else:
            non_permissible[word] = index

    return permissible, non_permissible
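
# Sanity check (illustrative addition, not part of the original script): the
# two halves should form a disjoint partition of the full vocabulary.
_perm, _non_perm = split_vocabulary(seed=42)
assert len(_perm) + len(_non_perm) == len(tokenizer.get_vocab())
assert not set(_perm.values()) & set(_non_perm.values())
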
def get_logits_for_mask(model, tokenizer, sentence):
    """Return the logits over the vocabulary at the [MASK] position."""
    inputs = tokenizer(sentence, return_tensors="pt")
    # Locate the [MASK] token within the tokenized input.
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    mask_token_logits = logits[0, mask_token_index, :]
    return mask_token_logits.squeeze()

def filter_logits(logits, permissible_indices):
    """Set logits of non-permissible tokens to -inf so they can never be picked."""
    filtered_logits = logits.clone()
    if filtered_logits.dim() > 1:
        filtered_logits = filtered_logits.squeeze()
    # Defensive alignment: truncate the mask if its length disagrees with the logits.
    if filtered_logits.shape != permissible_indices.shape:
        permissible_indices = permissible_indices[:filtered_logits.shape[0]]
    filtered_logits[~permissible_indices] = float('-inf')
    return filtered_logits
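
# Toy illustration (hypothetical values, not part of the original): with the
# mask [True, False, True], the middle logit is driven to -inf and can never
# be selected.
_toy = filter_logits(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([True, False, True]))
# _toy is tensor([1., -inf, 3.]): index 1 is masked out.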
permissible, non_permissible = split_vocabulary(seed=42)

# Build a boolean mask over the full vocabulary. Materializing the permissible
# ids as a set makes each membership test O(1) instead of scanning dict values.
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])

sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(model, tokenizer, sentence)
filtered_logits = filter_logits(logits, permissible_indices)
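
# Example continuation (a sketch, not in the original): decode the top
# predictions from the filtered logits; only permissible tokens can surface.
top_ids = torch.topk(filtered_logits, k=5).indices.tolist()
print(tokenizer.convert_ids_to_tokens(top_ids))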