import nltk
import torch
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForMaskedLM

from lcs import find_common_subsequences
from paraphraser import generate_paraphrase
from vocabulary_split import split_vocabulary, filter_logits
|

# Download the NLTK resources needed for tokenization and stopword filtering.
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)

# Whole-word-masking BERT scores candidate fills for the masked position.
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
|

# Split the vocabulary into permissible/non-permissible halves with a fixed
# seed, then build a boolean mask over the tokenizer's vocabulary. Membership
# is checked against a set so the mask builds in linear time.
permissible, _ = split_vocabulary(seed=42)
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor(
    [i in permissible_ids for i in range(len(tokenizer))], dtype=torch.bool
)
|


def get_non_melting_points(original_sentence):
    """Return subsequences shared by the sentence and its paraphrases
    ("non-melting points" that survive paraphrasing)."""
    paraphrased_sentences = generate_paraphrase(original_sentence)
    return find_common_subsequences(original_sentence, paraphrased_sentences)
|


def get_word_between_points(sentence, start_point, end_point):
    """Return the first non-stopword token strictly between two non-melting
    points, along with its token index, or (None, None) if none exists."""
    words = nltk.word_tokenize(sentence)
    stop_words = set(stopwords.words('english'))

    # Each point carries its text at index 1. Locate the points' positions in
    # the token list; str.index would return character offsets, which cannot
    # be used to slice a list of tokens.
    start_index = words.index(nltk.word_tokenize(start_point[1])[-1])
    end_index = words.index(nltk.word_tokenize(end_point[1])[0])

    for word in words[start_index + 1:end_index]:
        if word.lower() not in stop_words:
            return word, words.index(word)
    return None, None
|


def get_logits_for_mask(sentence):
    """Run the masked LM and return the logits at the [MASK] position."""
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        outputs = model(**inputs)

    # Take the logits at the mask position and squeeze to a 1-D vocab vector.
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    return mask_token_logits.squeeze()
|


def detect_watermark(sentence):
    """Detect the watermark: mask the word between the first two non-melting
    points and check whether it appears among the model's top predictions
    from the permissible half of the vocabulary."""
    non_melting_points = get_non_melting_points(sentence)

    if len(non_melting_points) < 2:
        return False, "Not enough non-melting points found."

    word_to_check, index = get_word_between_points(
        sentence, non_melting_points[0], non_melting_points[1]
    )

    if word_to_check is None:
        return False, "No suitable word found between non-melting points."

    # Rebuild the sentence with the candidate word replaced by [MASK].
    # (Joining tokens with spaces is a lossy detokenization, but adequate here.)
    words = nltk.word_tokenize(sentence)
    masked_sentence = ' '.join(words[:index] + ['[MASK]'] + words[index + 1:])

    logits = get_logits_for_mask(masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)

    # Top-5 token ids after restricting the logits to the permissible vocabulary.
    top_predictions = filtered_logits.argsort()[-5:]
    predicted_words = [tokenizer.decode([i]).strip() for i in top_predictions]

    if word_to_check in predicted_words:
        return True, f"Watermark detected: '{word_to_check}' is among the top permissible predictions."
    return False, f"No watermark detected: '{word_to_check}' is not among the top permissible predictions."
|