"""Watermark detection based on non-melting points.

A sentence is treated as watermarked if the first non-stopword between its
first two non-melting points (subsequences shared with its paraphrases) is
among a masked language model's top predictions drawn from the permissible
half of the vocabulary.
"""

import nltk
import torch
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForMaskedLM

from lcs import find_common_subsequences
from paraphraser import generate_paraphrase
from vocabulary_split import split_vocabulary, filter_logits

nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)

tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")

# Boolean mask over the tokenizer vocabulary: True at permissible token ids.
# (Converting the values to a set first avoids a quadratic membership scan.)
permissible, _ = split_vocabulary(seed=42)
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])


def get_non_melting_points(original_sentence):
    """Return the subsequences shared by the sentence and its paraphrases."""
    paraphrased_sentences = generate_paraphrase(original_sentence)
    return find_common_subsequences(original_sentence, paraphrased_sentences)


def get_word_between_points(sentence, start_point, end_point):
    """Return the first non-stopword between two non-melting points, along
    with its token index in the sentence.

    Each point is expected to carry its text at position 1 (point[1]), as
    produced by find_common_subsequences.
    """
    words = nltk.word_tokenize(sentence)
    stop_words = set(stopwords.words('english'))
    # Locate the points in token space. (The original code sliced the token
    # list with character offsets from str.index(), conflating character and
    # token indices.)
    start_tokens = nltk.word_tokenize(start_point[1])
    end_tokens = nltk.word_tokenize(end_point[1])
    try:
        start_index = words.index(start_tokens[-1])
        end_index = words.index(end_tokens[0], start_index + 1)
    except ValueError:
        return None, None
    for i in range(start_index + 1, end_index):
        if words[i].lower() not in stop_words:
            return words[i], i
    return None, None


def get_logits_for_mask(sentence):
    """Return the model's logits at the [MASK] position of the sentence."""
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    with torch.no_grad():
        outputs = model(**inputs)
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    return mask_token_logits.squeeze()


def detect_watermark(sentence):
    """Decide whether the sentence carries the watermark.

    Returns a (bool, message) pair.
    """
    non_melting_points = get_non_melting_points(sentence)
    if len(non_melting_points) < 2:
        return False, "Not enough non-melting points found."

    word_to_check, index = get_word_between_points(
        sentence, non_melting_points[0], non_melting_points[1])
    if word_to_check is None:
        return False, "No suitable word found between non-melting points."

    # Mask the candidate word and ask the model for replacements restricted
    # to the permissible half of the vocabulary.
    words = nltk.word_tokenize(sentence)
    masked_sentence = ' '.join(words[:index] + ['[MASK]'] + words[index + 1:])
    logits = get_logits_for_mask(masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)

    # argsort is ascending, so the last five indices are the top predictions.
    top_predictions = filtered_logits.argsort()[-5:]
    predicted_words = [tokenizer.decode([int(i)]).strip() for i in top_predictions]

    if word_to_check in predicted_words:
        return True, f"Watermark detected. The word '{word_to_check}' is among the top permissible predictions."
    return False, f"No watermark detected. The word '{word_to_check}' is not among the top permissible predictions."


# Example usage
if __name__ == "__main__":
    test_sentence = "The quick brown fox jumps over the lazy dog."
    is_watermarked, message = detect_watermark(test_sentence)
    print(f"Is the sentence watermarked? {is_watermarked}")
    print(f"Detection message: {message}")
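
# ---------------------------------------------------------------------------
# Assumed interface (hypothetical sketch, not the real vocabulary_split
# implementation): filter_logits is used above as though it suppresses the
# logits of non-permissible token ids before ranking, roughly like:
#
#     def filter_logits(logits, permissible_mask):
#         # Send non-permissible logits to -inf so argsort only ever ranks
#         # tokens from the permissible half of the vocabulary.
#         return logits.masked_fill(~permissible_mask, float("-inf"))
# ---------------------------------------------------------------------------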