import torch

from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline

import random
import math

import nltk
from nltk.corpus import stopwords

from vocabulary_split import split_vocabulary, filter_logits

# The stopword corpus is needed by several of the masking strategies below.
nltk.download('stopwords', quiet=True)

tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

# Split the vocabulary into permissible / non-permissible halves and build a
# boolean mask over the tokenizer vocabulary, used to filter MLM logits below.
permissible, _ = split_vocabulary(seed=42)
permissible_indices = torch.tensor([i in permissible.values() for i in range(len(tokenizer))])

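# Quick sanity check for the setup above (assumption: illustrative only, not part
# of the original pipeline). The fill-mask pipeline returns a list of candidate
# dicts with 'token_str' and 'score' for the [MASK] position, for example:
#
#     preds = fill_mask("Paris is the [MASK] of France.")
#     print([p["token_str"] for p in preds])  # top candidates for the mask
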
def get_logits_for_mask(model, tokenizer, sentence):
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # Shape [num_masks, vocab_size]; squeeze() drops the leading dimension when
    # the sentence contains exactly one [MASK] token.
    mask_token_logits = logits[0, mask_token_index, :]
    return mask_token_logits.squeeze()


def mask_non_stopword(sentence):
    stop_words = set(stopwords.words('english'))
    words = sentence.split()
    non_stop_words = [word for word in words if word.lower() not in stop_words]
    if not non_stop_words:
        return sentence, None, None
    # Mask one randomly chosen non-stopword and score the gap with the MLM.
    word_to_mask = random.choice(non_stop_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    top_words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    return masked_sentence, filtered_logits.tolist(), top_words


def mask_non_stopword_pseudorandom(sentence):
    stop_words = set(stopwords.words('english'))
    words = sentence.split()
    non_stop_words = [word for word in words if word.lower() not in stop_words]
    if not non_stop_words:
        return sentence, None, None
    # Fixed seed so the masked word is chosen deterministically for a given sentence.
    random.seed(10)
    word_to_mask = random.choice(non_stop_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    top_words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    return masked_sentence, filtered_logits.tolist(), top_words


def mask_between_lcs(sentence, lcs_points):
    words = sentence.split()
    masked_indices = []

    # Mask one word before the first LCS point.
    if lcs_points and lcs_points[0] > 0:
        idx = random.randint(0, lcs_points[0] - 1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)

    # Mask one word between each pair of consecutive LCS points.
    for i in range(len(lcs_points) - 1):
        start, end = lcs_points[i], lcs_points[i + 1]
        if end - start > 1:
            mask_index = random.randint(start + 1, end - 1)
            words[mask_index] = '[MASK]'
            masked_indices.append(mask_index)

    # Mask one word after the last LCS point.
    if lcs_points and lcs_points[-1] < len(words) - 1:
        idx = random.randint(lcs_points[-1] + 1, len(words) - 1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)

    masked_sentence = ' '.join(words)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    # get_logits_for_mask squeezes its output, so a single [MASK] yields a 1-D
    # tensor; restore the leading mask dimension for uniform indexing below.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    top_words_list = []
    logits_list = []
    for i in range(len(masked_indices)):
        filtered_logits_i = filter_logits(logits[i], permissible_indices)
        logits_list.append(filtered_logits_i.tolist())
        top_5_indices = filtered_logits_i.topk(5).indices.tolist()
        top_words = [tokenizer.decode([idx]) for idx in top_5_indices]
        top_words_list.append(top_words)

    return masked_sentence, logits_list, top_words_list


def high_entropy_words(sentence, non_melting_points):
    stop_words = set(stopwords.words('english'))
    words = sentence.split()

    non_melting_words = set()
    for _, point in non_melting_points:
        non_melting_words.update(point.lower().split())

    candidate_words = [word for word in words
                       if word.lower() not in stop_words and word.lower() not in non_melting_words]

    if not candidate_words:
        return sentence, None, None

    max_entropy = -float('inf')
    max_entropy_word = None
    max_logits = None

    # Mask each candidate in turn and keep the one whose top-5 permissible
    # candidates have the highest entropy, i.e. where the MLM is least certain.
    for word in candidate_words:
        masked_sentence = sentence.replace(word, '[MASK]', 1)
        logits = get_logits_for_mask(model, tokenizer, masked_sentence)
        filtered_logits = filter_logits(logits, permissible_indices)

        probs = torch.softmax(filtered_logits, dim=-1)
        top_5_probs = probs.topk(5).values
        entropy = -torch.sum(top_5_probs * torch.log(top_5_probs))

        if entropy > max_entropy:
            max_entropy = entropy
            max_entropy_word = word
            max_logits = filtered_logits

    if max_entropy_word is None:
        return sentence, None, None

    masked_sentence = sentence.replace(max_entropy_word, '[MASK]', 1)
    top_words = [tokenizer.decode([i]) for i in max_logits.argsort()[-5:]]
    return masked_sentence, max_logits.tolist(), top_words


def mask_by_pos(sentence, pos_to_mask=('NOUN', 'VERB', 'ADJ')):
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
    nltk.download('universal_tagset', quiet=True)

    words = nltk.word_tokenize(sentence)
    # Tag with the universal tagset so the tags match the NOUN/VERB/ADJ defaults
    # (the default Penn Treebank tags, e.g. NN/VB/JJ, would never match them).
    pos_tags = nltk.pos_tag(words, tagset='universal')

    maskable_words = [word for word, pos in pos_tags if pos in pos_to_mask]

    if not maskable_words:
        return sentence, None, None

    word_to_mask = random.choice(maskable_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)

    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    top_words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]

    return masked_sentence, filtered_logits.tolist(), top_words


def mask_named_entity(sentence):
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
    nltk.download('maxent_ne_chunker', quiet=True)
    nltk.download('words', quiet=True)

    words = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(words)
    named_entities = nltk.ne_chunk(pos_tags)

    # Named-entity chunks appear as nltk.Tree subtrees of the chunked result;
    # plain (word, tag) leaves are not entities, so collect words from subtrees.
    maskable_words = []
    for subtree in named_entities:
        if isinstance(subtree, nltk.Tree):
            maskable_words.extend(word for word, _ in subtree.leaves())

    if not maskable_words:
        return sentence, None, None

    word_to_mask = random.choice(maskable_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)

    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    top_words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]

    return masked_sentence, filtered_logits.tolist(), top_words
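

# Minimal usage sketch (assumption: illustrative only, not part of the original
# pipeline). The example sentence, the lcs_points indices, and the
# non_melting_points pairs are made-up values chosen to match the call
# signatures above.
if __name__ == "__main__":
    example = "The quick brown fox jumps over the lazy dog"

    masked, _, top5 = mask_non_stopword(example)
    print(masked, top5)

    # Hypothetical word indices assumed to stay fixed between the masks.
    masked, _, top5_per_mask = mask_between_lcs(example, [1, 4, 7])
    print(masked, top5_per_mask)

    # non_melting_points is assumed to be a list of (index, phrase) pairs,
    # matching how high_entropy_words unpacks it.
    masked, _, top5 = high_entropy_words(example, [(0, "quick brown fox")])
    print(masked, top5)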