# aiisc-watermarking-modelv3/watermark_detector.py
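"""Detect a vocabulary-split watermark in a sentence.

Pipeline: paraphrase the input, find subsequences that survive paraphrasing
("non-melting points"), mask the first content word between the first two
points, and check whether the masked language model's top predictions,
restricted to the permissible half of the vocabulary, recover that word.
"""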
import nltk
import torch
from nltk.corpus import stopwords
from transformers import AutoModelForMaskedLM, AutoTokenizer

from lcs import find_common_subsequences
from paraphraser import generate_paraphrase
from vocabulary_split import filter_logits, split_vocabulary

nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)

tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")

# Boolean mask over the tokenizer vocabulary: True where the token id falls in
# the permissible half of the split. split_vocabulary appears to return a
# mapping whose values are token ids, hence the .values() lookup; a set makes
# the scan over the full vocabulary linear instead of quadratic.
permissible, _ = split_vocabulary(seed=42)
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])

def get_non_melting_points(original_sentence):
    """Return subsequences shared between the sentence and its paraphrases."""
    paraphrased_sentences = generate_paraphrase(original_sentence)
    return find_common_subsequences(original_sentence, paraphrased_sentences)
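
# Each non-melting point is indexed as point[1] below, so find_common_subsequences
# is assumed to yield pairs whose second element is the subsequence text.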
def get_word_between_points(sentence, start_point, end_point):
    """Return the first non-stopword between two non-melting points, plus its word index."""
    words = nltk.word_tokenize(sentence)
    stop_words = set(stopwords.words('english'))
    # Locate both points by position in the tokenized word list (a character
    # offset from str.index() would not line up with word positions).
    start_index = words.index(nltk.word_tokenize(start_point[1])[-1])
    end_index = words.index(nltk.word_tokenize(end_point[1])[0])
    for i in range(start_index + 1, end_index):
        if words[i].lower() not in stop_words:
            return words[i], i
    return None, None

def get_logits_for_mask(sentence):
    """Return the masked-LM logits at the [MASK] position of `sentence`."""
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    with torch.no_grad():
        outputs = model(**inputs)
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    return mask_token_logits.squeeze()
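
# `filter_logits` (imported from vocabulary_split) is assumed to suppress every
# logit outside the permissible vocabulary, along the lines of:
#     filtered = logits.clone()
#     filtered[~permissible_indices] = float('-inf')
# so only permissible tokens can surface among the top predictions.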
def detect_watermark(sentence):
    """Return (is_watermarked, message) for `sentence`."""
    non_melting_points = get_non_melting_points(sentence)
    if len(non_melting_points) < 2:
        return False, "Not enough non-melting points found."
    word_to_check, index = get_word_between_points(sentence, non_melting_points[0], non_melting_points[1])
    if word_to_check is None:
        return False, "No suitable word found between non-melting points."
    # Mask the candidate word and let the model fill it back in.
    words = nltk.word_tokenize(sentence)
    masked_sentence = ' '.join(words[:index] + [tokenizer.mask_token] + words[index + 1:])
    logits = get_logits_for_mask(masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    # Top five predictions, drawn from the permissible vocabulary only.
    top_predictions = torch.topk(filtered_logits, 5).indices
    predicted_words = [tokenizer.decode([i]).strip() for i in top_predictions]
    if word_to_check in predicted_words:
        return True, f"Watermark detected: '{word_to_check}' is among the top permissible-vocabulary predictions."
    return False, f"No watermark detected: '{word_to_check}' is not among the top permissible-vocabulary predictions."

if __name__ == "__main__":
    # Example usage
    test_sentence = "The quick brown fox jumps over the lazy dog."
    is_watermarked, message = detect_watermark(test_sentence)
    print(f"Is the sentence watermarked? {is_watermarked}")
    print(f"Detection message: {message}")