Commit · b248ba6
Parent(s): 80b5673
model update
Files changed:
- __pycache__/app.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/demo_watermark.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/extended_watermark_engine.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/extended_watermark_processor.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/extended_watermark_utils.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/watermark_demo.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/watermark_engine.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/watermark_processor.cpython-310.pyc +0 -0
- {src → algorithm}/__pycache__/watermark_utils.cpython-310.pyc +0 -0
- {src → algorithm}/extended_watermark_engine.py +0 -0
- {src → algorithm}/watermark_engine.py +0 -0
- app.py +254 -32
- components/__pycache__/homoglyphs.cpython-310.pyc +0 -0
- components/__pycache__/normalizers.cpython-310.pyc +0 -0
- components/__pycache__/prf_schemes.cpython-310.pyc +0 -0
- components/__pycache__/utils.cpython-310.pyc +0 -0
- {data → components/data}/__init__.py +0 -0
- {data → components/data}/categories.json +0 -0
- {data → components/data}/confusables.json +0 -0
- {data → components/data}/languages.json +0 -0
- components/utils.py +32 -0
- src/watermark_demo.py +0 -291
- test.py +0 -31
__pycache__/app.cpython-310.pyc
ADDED
Binary file (7.62 kB)

{src → algorithm}/__pycache__/demo_watermark.cpython-310.pyc
RENAMED
File without changes

{src → algorithm}/__pycache__/extended_watermark_engine.cpython-310.pyc
RENAMED
Binary files a/src/__pycache__/extended_watermark_engine.cpython-310.pyc and b/algorithm/__pycache__/extended_watermark_engine.cpython-310.pyc differ

{src → algorithm}/__pycache__/extended_watermark_processor.cpython-310.pyc
RENAMED
File without changes

{src → algorithm}/__pycache__/extended_watermark_utils.cpython-310.pyc
RENAMED
File without changes

{src → algorithm}/__pycache__/watermark_demo.cpython-310.pyc
RENAMED
File without changes

{src → algorithm}/__pycache__/watermark_engine.cpython-310.pyc
RENAMED
Binary files a/src/__pycache__/watermark_engine.cpython-310.pyc and b/algorithm/__pycache__/watermark_engine.cpython-310.pyc differ

{src → algorithm}/__pycache__/watermark_processor.cpython-310.pyc
RENAMED
File without changes

{src → algorithm}/__pycache__/watermark_utils.cpython-310.pyc
RENAMED
File without changes

{src → algorithm}/extended_watermark_engine.py
RENAMED
File without changes

{src → algorithm}/watermark_engine.py
RENAMED
File without changes

app.py
CHANGED
@@ -1,37 +1,259 @@
-from
+import sys
+from functools import partial
+import gradio as gr
+import torch
+from transformers import (AutoTokenizer,
+                          AutoModelForSeq2SeqLM,
+                          AutoModelForCausalLM,
+                          LogitsProcessorList)
+from algorithm.watermark_engine import LogitsProcessorWithWatermark, WatermarkAnalyzer
+from algorithm.extended_watermark_engine import LogitsProcessorWithWatermarkExtended, WatermarkAnalyzerExtended
+from components.utils import process_args, get_default_prompt, display_prompt, display_results, parse_args, list_format_scores, get_default_args
+
+
+def run_gradio(args, model=None, device=None, tokenizer=None):
+    """Define and launch with gradio"""
+    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
+    detect_partial = partial(analyze, device=device, tokenizer=tokenizer)
+
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css="footer{display:none !important}") as demo:
+        with gr.Row():
+            with gr.Column(scale=9):
+                gr.Markdown(
+                    """
+                    ## Plagiarism detection for Large Language Models through watermarking
+                    """
+                )
+            with gr.Column(scale=2):
+                algorithm = gr.Radio(label="Watermark Algorithm", info="which algorithm would you like to use?", choices=["basic", "advance"], value=("advance" if args.run_extended else "basic"))
+
+        gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
+
+        default_prompt = args.__dict__.pop("default_prompt")
+        session_args = gr.State(value=args)
+
+        with gr.Tab("Generate and Detect"):
+
+            with gr.Row():
+                prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt)
+            with gr.Row():
+                generate_btn = gr.Button("Generate")
+            with gr.Row():
+                with gr.Column(scale=2):
+                    output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14)
+                with gr.Column(scale=1):
+                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
+            with gr.Row():
+                with gr.Column(scale=2):
+                    output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14)
+                with gr.Column(scale=1):
+                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
+
+            redecoded_input = gr.Textbox(visible=False)
+            truncation_warning = gr.Number(visible=False)
+            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
+                if truncation_warning:
+                    return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
+                else:
+                    return orig_prompt, args
+
+        with gr.Tab("Detector Only"):
+            with gr.Row():
+                with gr.Column(scale=2):
+                    detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14)
+                with gr.Column(scale=1):
+                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
+            with gr.Row():
+                detect_btn = gr.Button("Detect")
+
+        gr.HTML("""
+                <p style="color: gray;">Built with 🤗 by Charles Apochi
+                <br/>
+                <a href="mailto:charlesapochi@gmail.com" style="text-decoration: none; color: orange;">Reach out</a>
+                <p/>
+                """)
+
+        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
+        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
+        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
+
+        # State management logic
+        def update_algorithm(session_state, value):
+            if value == "advance":
+                session_state.run_extended = True
+            elif value == "basic":
+                session_state.run_extended = False
+            return session_state,
+
+
+        algorithm.change(update_algorithm,inputs=[session_args, algorithm], outputs=[session_args])
+
+    demo.launch(share=args.demo_public)
+
+
+def load_model(args):
+    """Load and return the model and tokenizer"""
+    args.is_decoder_only_model = True
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_name_or_path,
+        # device_map="auto",
+        # torch_dtype=torch.float16,
+    )
+
+    if args.use_gpu:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if args.load_fp16:
+            pass
+        else:
+            model = model.to(device)
+    else:
+        device = "cpu" #"mps" if args.run_extended else "cpu"
+
+    model.eval()
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+
+    return model, tokenizer, device
+
+def generate(prompt, args, model=None, device=None, tokenizer=None):
+    print(f"Generating with {args}")
+
+    if args.run_extended:
+        watermark_processor = LogitsProcessorWithWatermarkExtended(vocab=list(tokenizer.get_vocab().values()),
+                                                                    gamma=args.gamma,
+                                                                    delta=args.delta,
+                                                                    seeding_scheme=args.seeding_scheme,
+                                                                    select_green_tokens=args.select_green_tokens)
+    else:
+        watermark_processor = LogitsProcessorWithWatermark(vocab=list(tokenizer.get_vocab().values()),
+                                                           gamma=args.gamma,
+                                                           delta=args.delta,
+                                                           seeding_scheme=args.seeding_scheme,
+                                                           select_green_tokens=args.select_green_tokens)
+
+    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
+
+    if args.use_sampling:
+        gen_kwargs.update(dict(
+            do_sample=True,
+            top_k=0,
+            temperature=args.sampling_temp,
+        ))
+    else:
+        gen_kwargs.update(dict(
+            num_beams=args.n_beams,
+        ))
+
+    generate_without_watermark = partial(
+        model.generate,
+        **gen_kwargs
+    )
+    generate_with_watermark = partial(
+        model.generate,
+        logits_processor=LogitsProcessorList([watermark_processor]),
+        **gen_kwargs
+    )
+    if args.prompt_max_length:
+        pass
+    elif hasattr(model.config,"max_position_embedding"):
+        args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
+    else:
+        args.prompt_max_length = 2048-args.max_new_tokens
+
+    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
+    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
+    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
+
+    torch.manual_seed(args.generation_seed)
+    output_without_watermark = generate_without_watermark(**tokd_input)
+
+    if args.seed_separately:
+        torch.manual_seed(args.generation_seed)
+    output_with_watermark = generate_with_watermark(**tokd_input)
+
+    if args.is_decoder_only_model:
+        # need to isolate the newly generated tokens
+        output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
+        output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
+
+    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
+    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
+
+    return (redecoded_input,
+            int(truncation_warning),
+            decoded_output_without_watermark,
+            decoded_output_with_watermark,
+            args)
+            # decoded_output_with_watermark)
+
+def analyze(input_text, args, device=None, tokenizer=None):
+
+    detector_args = {
+        "vocab": list(tokenizer.get_vocab().values()),
+        "gamma": args.gamma,
+        "delta": args.delta,
+        "seeding_scheme": args.seeding_scheme,
+        "select_green_tokens": args.select_green_tokens,
+        "device": device,
+        "tokenizer": tokenizer,
+        "z_threshold": args.detection_z_threshold,
+        "normalizers": args.normalizers,
     }
+    if args.run_extended:
+        detector_args["ignore_repeated_ngrams"] = args.ignore_repeated_ngrams
+    else:
+        detector_args["skip_repeated_bigrams"] = args.skip_repeated_bigrams
+
+    if args.run_extended:
+        watermark_detector = WatermarkAnalyzerExtended(**detector_args)
+    else:
+        watermark_detector = WatermarkAnalyzer(**detector_args)
+
+    if args.run_extended:
+        score_dict = watermark_detector.analyze(input_text)
+        output = list_format_scores(score_dict, watermark_detector.z_threshold)
+    else:
+        if len(input_text)-1 > watermark_detector.min_prefix_len:
+            score_dict = watermark_detector.analyze(input_text)
+            # output = str_format_scores(score_dict, watermark_detector.z_threshold)
+            output = list_format_scores(score_dict, watermark_detector.z_threshold)
+        else:
+            # output = (f"Error: string not long enough to compute watermark presence.")
+            output = [["Error","string too short to compute metrics"]]
+            output += [["",""] for _ in range(6)]
 
-    args.__dict__.update(default_arg_dict)
-    return args
+    return output, args
 
 if __name__ == "__main__":
     args = get_default_args()
+    # args = process_args(args)
+    input_text = get_default_prompt()
+    args.default_prompt = input_text
+
+    if not args.skip_model_load:
+        model, tokenizer, device = load_model(args)
+    else:
+        model, tokenizer, device = None, None, None
+
+    if not args.skip_model_load:
+        display_prompt(input_text)
+
+        _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(
+            input_text, args, model=model, device=device, tokenizer=tokenizer
+        )
+
+        without_watermark_detection_result = analyze(
+            decoded_output_without_watermark, args, device=device, tokenizer=tokenizer
+        )
+
+        with_watermark_detection_result = analyze(
+            decoded_output_with_watermark, args, device=device, tokenizer=tokenizer
+        )
+
+        display_results(decoded_output_without_watermark, without_watermark_detection_result, args, with_watermark=False)
+        display_results(decoded_output_with_watermark, with_watermark_detection_result, args, with_watermark=True)
+
+    if args.run_gradio:
+        run_gradio(args, model=model, tokenizer=tokenizer, device=device)

components/__pycache__/homoglyphs.cpython-310.pyc
CHANGED
Binary files a/components/__pycache__/homoglyphs.cpython-310.pyc and b/components/__pycache__/homoglyphs.cpython-310.pyc differ

components/__pycache__/normalizers.cpython-310.pyc
CHANGED
Binary files a/components/__pycache__/normalizers.cpython-310.pyc and b/components/__pycache__/normalizers.cpython-310.pyc differ

components/__pycache__/prf_schemes.cpython-310.pyc
CHANGED
Binary files a/components/__pycache__/prf_schemes.cpython-310.pyc and b/components/__pycache__/prf_schemes.cpython-310.pyc differ

components/__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/components/__pycache__/utils.cpython-310.pyc and b/components/__pycache__/utils.cpython-310.pyc differ

{data → components/data}/__init__.py
RENAMED
File without changes

{data → components/data}/categories.json
RENAMED
File without changes

{data → components/data}/confusables.json
RENAMED
File without changes

{data → components/data}/languages.json
RENAMED
File without changes

components/utils.py
CHANGED
@@ -1,6 +1,38 @@
 from pprint import pprint
 import argparse
 from itertools import chain, tee
+from argparse import Namespace
+
+def get_default_args():
+    """Returns the default arguments as a Namespace object."""
+    default_arg_dict = {
+        'run_gradio': True,
+        'run_extended': True,
+        'demo_public': False,
+        'model_name_or_path': 'google/gemma-2-2b-it', #'bigscience/bloom-560m', #'meta-llama/Meta-Llama-3-8B',
+        'load_fp16': False,
+        'prompt_max_length': None,
+        'max_new_tokens': 200,
+        'generation_seed': 123,
+        'use_sampling': True,
+        'n_beams': 1,
+        'sampling_temp': 0.7,
+        'use_gpu': False,
+        'seeding_scheme': 'simple_1',
+        'gamma': 0.25,
+        'delta': 2.0,
+        'normalizers': '',
+        'skip_repeated_bigrams': False,
+        'ignore_repeated_ngrams': False,
+        'detection_z_threshold': 4.0,
+        'select_green_tokens': True,
+        'skip_model_load': False,
+        'seed_separately': True,
+    }
+
+    args = Namespace()
+    args.__dict__.update(default_arg_dict)
+    return args
 
 def process_args(args):
     """Process and normalize command-line arguments."""

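For context, a minimal usage sketch (not part of this commit) of the get_default_args() helper added above, mirroring how the new app.py consumes it under __main__; the overridden field and the example prompt string are illustrative assumptions, not values from the repository:

    # Hypothetical usage of the new helper; names are taken from the diff above.
    from components.utils import get_default_args
    from app import load_model, run_gradio  # defined in the new app.py

    args = get_default_args()          # argparse.Namespace holding the defaults listed above
    args.max_new_tokens = 100          # illustrative override; any field on the Namespace can be changed
    args.default_prompt = "Write a short paragraph about watermarking."  # run_gradio() pops this key
    model, tokenizer, device = load_model(args)
    run_gradio(args, model=model, tokenizer=tokenizer, device=device)
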
src/watermark_demo.py
DELETED
@@ -1,291 +0,0 @@
-from distutils.command.config import config
-import sys
-from functools import partial
-import gradio as gr
-import torch
-from transformers import (AutoTokenizer,
-                          AutoModelForSeq2SeqLM,
-                          AutoModelForCausalLM,
-                          LogitsProcessorList)
-from src.watermark_engine import LogitsProcessorWithWatermark, WatermarkAnalyzer
-from src.extended_watermark_engine import LogitsProcessorWithWatermarkExtended, WatermarkAnalyzerExtended
-from components.utils import process_args, get_default_prompt, display_prompt, display_results, parse_args, list_format_scores
-
-
-def run_gradio(args, model=None, device=None, tokenizer=None):
-    """Define and launch with gradio"""
-    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
-    detect_partial = partial(analyze, device=device, tokenizer=tokenizer)
-
-    with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css="footer{display:none !important}") as demo:
-        with gr.Row():
-            with gr.Column(scale=9):
-                gr.Markdown(
-                    """
-                    ## Plagiarism detection for Large Language Models through watermarking
-                    """
-                )
-            with gr.Column(scale=2):
-                algorithm = gr.Radio(label="Watermark Algorithm", info="which algorithm would you like to use?", choices=["basic", "advance"], value=("advance" if args.run_extended else "basic"))
-
-
-        default_prompt = args.__dict__.pop("default_prompt")
-        session_args = gr.State(value=args)
-        initial_args = {
-            "model_name_or_path": args.model_name_or_path,
-            "load_fp16": False
-        }
-
-        session_md = gr.State(value=initial_args)
-
-        def update_display(args):
-            return f"Language model: {args['model_name_or_path']} {'(float16 mode)' if args['load_fp16'] else ''}"
-
-        display_md = gr.Markdown(update_display(initial_args))
-
-        # gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
-
-        with gr.Tab("Generate and Detect"):
-
-            with gr.Row():
-                prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt)
-            with gr.Row():
-                generate_btn = gr.Button("Generate")
-            with gr.Row():
-                with gr.Column(scale=2):
-                    output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14)
-                with gr.Column(scale=1):
-                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
-            with gr.Row():
-                with gr.Column(scale=2):
-                    output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14)
-                with gr.Column(scale=1):
-                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
-
-            redecoded_input = gr.Textbox(visible=False)
-            truncation_warning = gr.Number(visible=False)
-            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
-                if truncation_warning:
-                    return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
-                else:
-                    return orig_prompt, args
-
-        with gr.Tab("Detector Only"):
-            with gr.Row():
-                with gr.Column(scale=2):
-                    detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14)
-                with gr.Column(scale=1):
-                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
-            with gr.Row():
-                detect_btn = gr.Button("Detect")
-
-        gr.HTML("""
-                <p style="color: gray;">Built with 🤗 by Charles Apochi
-                <br/>
-                <a href="mailto:charlesapochi@gmail.com" style="text-decoration: none; color: orange;">Reach out</a>
-                <p/>
-                """)
-
-        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
-        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
-        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
-        detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
-
-        # State management logic
-        def update_algorithm(session_state, session_md, value):
-            new_md = session_md.copy()
-            if value == "advance":
-                session_state.run_extended = True
-                session_state.model_name_or_path = 'google/gemma-2-2b-it'
-                new_md['model_name_or_path'] = 'google/gemma-2-2b-it'
-            elif value == "basic":
-                session_state.run_extended = False
-                session_state.model_name_or_path = 'bigscience/bloom-560m'
-                new_md['model_name_or_path'] = 'bigscience/bloom-560m'
-            return session_state, new_md, update_display(new_md)
-
-
-        algorithm.change(update_algorithm,inputs=[session_args, session_md, algorithm], outputs=[session_args, session_md, display_md])
-
-    demo.launch(share=args.demo_public)
-
-def load_model(args):
-    """Load and return the model and tokenizer"""
-    args.is_decoder_only_model = True
-    if args.run_extended:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.model_name_or_path,
-            device_map="auto",
-            torch_dtype=torch.float16,
-        )
-    else:
-        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
-
-    if args.use_gpu:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        if args.load_fp16:
-            pass
-        else:
-            model = model.to(device)
-    else:
-        device = "mps" if args.run_extended else "cpu"
-
-    model.eval()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
-
-    return model, tokenizer, device
-
-def generate(prompt, args, model=None, device=None, tokenizer=None):
-    print(f"Generating with {args}")
-
-    model, tokenizer, device = load_model(args)
-
-    if args.run_extended:
-        watermark_processor = LogitsProcessorWithWatermarkExtended(vocab=list(tokenizer.get_vocab().values()),
-                                                                    gamma=args.gamma,
-                                                                    delta=args.delta,
-                                                                    seeding_scheme=args.seeding_scheme,
-                                                                    select_green_tokens=args.select_green_tokens)
-    else:
-        watermark_processor = LogitsProcessorWithWatermark(vocab=list(tokenizer.get_vocab().values()),
-                                                           gamma=args.gamma,
-                                                           delta=args.delta,
-                                                           seeding_scheme=args.seeding_scheme,
-                                                           select_green_tokens=args.select_green_tokens)
-
-    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
-
-    if args.use_sampling:
-        gen_kwargs.update(dict(
-            do_sample=True,
-            top_k=0,
-            temperature=args.sampling_temp,
-        ))
-    else:
-        gen_kwargs.update(dict(
-            num_beams=args.n_beams,
-        ))
-
-    generate_without_watermark = partial(
-        model.generate,
-        **gen_kwargs
-    )
-    generate_with_watermark = partial(
-        model.generate,
-        logits_processor=LogitsProcessorList([watermark_processor]),
-        **gen_kwargs
-    )
-    if args.prompt_max_length:
-        pass
-    elif hasattr(model.config,"max_position_embedding"):
-        args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
-    else:
-        args.prompt_max_length = 2048-args.max_new_tokens
-
-    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
-    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
-    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
-
-    torch.manual_seed(args.generation_seed)
-    output_without_watermark = generate_without_watermark(**tokd_input)
-
-    if args.seed_separately:
-        torch.manual_seed(args.generation_seed)
-    output_with_watermark = generate_with_watermark(**tokd_input)
-
-    if args.is_decoder_only_model:
-        # need to isolate the newly generated tokens
-        output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
-        output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
-
-    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
-    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
-
-    return (redecoded_input,
-            int(truncation_warning),
-            decoded_output_without_watermark,
-            decoded_output_with_watermark,
-            args)
-            # decoded_output_with_watermark)
-
-def analyze(input_text, args, device=None, tokenizer=None):
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
-    device = "mps" if args.run_extended else "cpu"
-
-    detector_args = {
-        "vocab": list(tokenizer.get_vocab().values()),
-        "gamma": args.gamma,
-        "delta": args.delta,
-        "seeding_scheme": args.seeding_scheme,
-        "select_green_tokens": args.select_green_tokens,
-        "device": device,
-        "tokenizer": tokenizer,
-        "z_threshold": args.detection_z_threshold,
-        "normalizers": args.normalizers,
-    }
-    if args.run_extended:
-        detector_args["ignore_repeated_ngrams"] = args.ignore_repeated_ngrams
-    else:
-        detector_args["skip_repeated_bigrams"] = args.skip_repeated_bigrams
-
-    if args.run_extended:
-        watermark_detector = WatermarkAnalyzerExtended(**detector_args)
-    else:
-        watermark_detector = WatermarkAnalyzer(**detector_args)
-
-    if args.run_extended:
-        score_dict = watermark_detector.analyze(input_text)
-        output = list_format_scores(score_dict, watermark_detector.z_threshold)
-    else:
-        if len(input_text)-1 > watermark_detector.min_prefix_len:
-            score_dict = watermark_detector.analyze(input_text)
-            # output = str_format_scores(score_dict, watermark_detector.z_threshold)
-            output = list_format_scores(score_dict, watermark_detector.z_threshold)
-        else:
-            # output = (f"Error: string not long enough to compute watermark presence.")
-            output = [["Error","string too short to compute metrics"]]
-            output += [["",""] for _ in range(6)]
-
-    return output, args
-
-
-def main(args):
-    """Run the main script for generation and detection"""
-    args = process_args(args)
-    input_text = get_default_prompt()
-    args.default_prompt = input_text
-
-    if not args.skip_model_load:
-        model, tokenizer, device = load_model(args)
-    else:
-        model, tokenizer, device = None, None, None
-
-    if not args.skip_model_load:
-        display_prompt(input_text)
-
-        _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(
-            input_text, args, model=model, device=device, tokenizer=tokenizer
-        )
-
-        without_watermark_detection_result = analyze(
-            decoded_output_without_watermark, args, device=device, tokenizer=tokenizer
-        )
-
-        with_watermark_detection_result = analyze(
-            decoded_output_with_watermark, args, device=device, tokenizer=tokenizer
-        )
-
-        display_results(decoded_output_without_watermark, without_watermark_detection_result, args, with_watermark=False)
-        display_results(decoded_output_with_watermark, with_watermark_detection_result, args, with_watermark=True)
-
-    if args.run_gradio:
-        run_gradio(args, model=model, tokenizer=tokenizer, device=device)
-
-    return
-
-if __name__ == "__main__":
-
-    args = parse_args()
-
-    main(args)

test.py
DELETED
@@ -1,31 +0,0 @@
-import torch
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
-
-tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
-model = AutoModelForCausalLM.from_pretrained(
-    "google/gemma-2-2b-it",
-    device_map="auto",
-    torch_dtype=torch.float16,
-)
-
-input_text = "Write me a poem about Machine Learning."
-input_ids = tokenizer(input_text, return_tensors="pt").to("mps")
-
-outputs = model.generate(**input_ids, max_new_tokens=32)
-print(tokenizer.decode(outputs[0]))
-
-pipe = pipeline(
-    "text-generation",
-    model= "google/gemma-2-2b-it",
-    model_kwargs={"torch_dtype": torch.float16},
-    device="mps", # replace with "mps" to run on a Mac device
-)
-
-messages = [
-    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
-]
-
-outputs = pipe(messages, max_new_tokens=256)
-assistant_response = outputs[0]["generated_text"][-1]["content"].strip()
-print(assistant_response)
-# Ahoy, matey! I be Gemma, a digital scallywag, a language-slingin' parrot of the digital seas. I be here to help ye with yer wordy woes, answer yer questions, and spin ye yarns of the digital world. So, what be yer pleasure, eh? 🦜