charlesapochi committed on
Commit
b248ba6
·
1 Parent(s): 80b5673

model update

__pycache__/app.cpython-310.pyc ADDED
Binary file (7.62 kB).
 
{src → algorithm}/__pycache__/demo_watermark.cpython-310.pyc RENAMED
File without changes
{src → algorithm}/__pycache__/extended_watermark_engine.cpython-310.pyc RENAMED
Binary files a/src/__pycache__/extended_watermark_engine.cpython-310.pyc and b/algorithm/__pycache__/extended_watermark_engine.cpython-310.pyc differ

{src → algorithm}/__pycache__/extended_watermark_processor.cpython-310.pyc RENAMED
File without changes
{src → algorithm}/__pycache__/extended_watermark_utils.cpython-310.pyc RENAMED
File without changes
{src → algorithm}/__pycache__/watermark_demo.cpython-310.pyc RENAMED
File without changes
{src → algorithm}/__pycache__/watermark_engine.cpython-310.pyc RENAMED
Binary files a/src/__pycache__/watermark_engine.cpython-310.pyc and b/algorithm/__pycache__/watermark_engine.cpython-310.pyc differ

{src → algorithm}/__pycache__/watermark_processor.cpython-310.pyc RENAMED
File without changes
{src → algorithm}/__pycache__/watermark_utils.cpython-310.pyc RENAMED
File without changes
{src → algorithm}/extended_watermark_engine.py RENAMED
File without changes
{src → algorithm}/watermark_engine.py RENAMED
File without changes
app.py CHANGED
@@ -1,37 +1,259 @@
1
- from argparse import Namespace
2
- from src.watermark_demo import main
3
-
4
- def get_default_args():
5
- """Returns the default arguments as a Namespace object."""
6
- default_arg_dict = {
7
- 'run_gradio': True,
8
- 'run_extended': False,
9
- 'demo_public': False,
10
- 'model_name_or_path': 'bigscience/bloom-560m', #'google/gemma-2-2b-it', #'meta-llama/Meta-Llama-3-8B',
11
- 'load_fp16': False,
12
- 'prompt_max_length': None,
13
- 'max_new_tokens': 200,
14
- 'generation_seed': 123,
15
- 'use_sampling': True,
16
- 'n_beams': 1,
17
- 'sampling_temp': 0.7,
18
- 'use_gpu': False,
19
- 'seeding_scheme': 'simple_1',
20
- 'gamma': 0.25,
21
- 'delta': 2.0,
22
- 'normalizers': '',
23
- 'skip_repeated_bigrams': False,
24
- 'ignore_repeated_ngrams': False,
25
- 'detection_z_threshold': 4.0,
26
- 'select_green_tokens': True,
27
- 'skip_model_load': True,
28
- 'seed_separately': True,
29
  }
30
 
31
- args = Namespace()
32
- args.__dict__.update(default_arg_dict)
33
- return args
34
 
35
  if __name__ == "__main__":
36
  args = get_default_args()
37
- main(args)
1
+ import sys
2
+ from functools import partial
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import (AutoTokenizer,
6
+ AutoModelForSeq2SeqLM,
7
+ AutoModelForCausalLM,
8
+ LogitsProcessorList)
9
+ from algorithm.watermark_engine import LogitsProcessorWithWatermark, WatermarkAnalyzer
10
+ from algorithm.extended_watermark_engine import LogitsProcessorWithWatermarkExtended, WatermarkAnalyzerExtended
11
+ from components.utils import process_args, get_default_prompt, display_prompt, display_results, parse_args, list_format_scores, get_default_args
12
+
13
+
14
+ def run_gradio(args, model=None, device=None, tokenizer=None):
15
+ """Define and launch with gradio"""
16
+ generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
17
+ detect_partial = partial(analyze, device=device, tokenizer=tokenizer)
18
+
19
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css="footer{display:none !important}") as demo:
20
+ with gr.Row():
21
+ with gr.Column(scale=9):
22
+ gr.Markdown(
23
+ """
24
+ ## Plagiarism detection for Large Language Models through watermarking
25
+ """
26
+ )
27
+ with gr.Column(scale=2):
28
+ algorithm = gr.Radio(label="Watermark Algorithm", info="which algorithm would you like to use?", choices=["basic", "advance"], value=("advance" if args.run_extended else "basic"))
29
+
30
+ gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
31
+
32
+ default_prompt = args.__dict__.pop("default_prompt")
33
+ session_args = gr.State(value=args)
34
+
35
+ with gr.Tab("Generate and Detect"):
36
+
37
+ with gr.Row():
38
+ prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt)
39
+ with gr.Row():
40
+ generate_btn = gr.Button("Generate")
41
+ with gr.Row():
42
+ with gr.Column(scale=2):
43
+ output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14)
44
+ with gr.Column(scale=1):
45
+ without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
46
+ with gr.Row():
47
+ with gr.Column(scale=2):
48
+ output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14)
49
+ with gr.Column(scale=1):
50
+ with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
51
+
52
+ redecoded_input = gr.Textbox(visible=False)
53
+ truncation_warning = gr.Number(visible=False)
54
+ def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
55
+ if truncation_warning:
56
+ return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
57
+ else:
58
+ return orig_prompt, args
59
+
60
+ with gr.Tab("Detector Only"):
61
+ with gr.Row():
62
+ with gr.Column(scale=2):
63
+ detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14)
64
+ with gr.Column(scale=1):
65
+ detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
66
+ with gr.Row():
67
+ detect_btn = gr.Button("Detect")
68
+
69
+ gr.HTML("""
70
+ <p style="color: gray;">Built with 🤍 by Charles Apochi
71
+ <br/>
72
+ <a href="mailto:charlesapochi@gmail.com" style="text-decoration: none; color: orange;">Reach out</a>
73
+ </p>
74
+ """)
75
+
76
+ generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
77
+ redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
78
+ output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
79
+ output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
80
+ detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
81
+
82
+ # State management logic
83
+ def update_algorithm(session_state, value):
84
+ if value == "advance":
85
+ session_state.run_extended = True
86
+ elif value == "basic":
87
+ session_state.run_extended = False
88
+ return session_state,
89
+
90
+
91
+ algorithm.change(update_algorithm,inputs=[session_args, algorithm], outputs=[session_args])
92
+
93
+ demo.launch(share=args.demo_public)
94
+
95
+
96
+ def load_model(args):
97
+ """Load and return the model and tokenizer"""
98
+ args.is_decoder_only_model = True
99
+
100
+ model = AutoModelForCausalLM.from_pretrained(
101
+ args.model_name_or_path,
102
+ # device_map="auto",
103
+ # torch_dtype=torch.float16,
104
+ )
105
+
106
+ if args.use_gpu:
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+ if args.load_fp16:
109
+ pass
110
+ else:
111
+ model = model.to(device)
112
+ else:
113
+ device = "cpu" #"mps" if args.run_extended else "cpu"
114
+
115
+ model.eval()
116
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
117
+
118
+ return model, tokenizer, device
119
+
120
+ def generate(prompt, args, model=None, device=None, tokenizer=None):
121
+ print(f"Generating with {args}")
122
+
123
+ if args.run_extended:
124
+ watermark_processor = LogitsProcessorWithWatermarkExtended(vocab=list(tokenizer.get_vocab().values()),
125
+ gamma=args.gamma,
126
+ delta=args.delta,
127
+ seeding_scheme=args.seeding_scheme,
128
+ select_green_tokens=args.select_green_tokens)
129
+ else:
130
+ watermark_processor = LogitsProcessorWithWatermark(vocab=list(tokenizer.get_vocab().values()),
131
+ gamma=args.gamma,
132
+ delta=args.delta,
133
+ seeding_scheme=args.seeding_scheme,
134
+ select_green_tokens=args.select_green_tokens)
135
+
136
+ gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
137
+
138
+ if args.use_sampling:
139
+ gen_kwargs.update(dict(
140
+ do_sample=True,
141
+ top_k=0,
142
+ temperature=args.sampling_temp,
143
+ ))
144
+ else:
145
+ gen_kwargs.update(dict(
146
+ num_beams=args.n_beams,
147
+ ))
148
+
149
+ generate_without_watermark = partial(
150
+ model.generate,
151
+ **gen_kwargs
152
+ )
153
+ generate_with_watermark = partial(
154
+ model.generate,
155
+ logits_processor=LogitsProcessorList([watermark_processor]),
156
+ **gen_kwargs
157
+ )
158
+ if args.prompt_max_length:
159
+ pass
160
+ elif hasattr(model.config,"max_position_embeddings"):
161
+ args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
162
+ else:
163
+ args.prompt_max_length = 2048-args.max_new_tokens
164
+
165
+ tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
166
+ truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
167
+ redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
168
+
169
+ torch.manual_seed(args.generation_seed)
170
+ output_without_watermark = generate_without_watermark(**tokd_input)
171
+
172
+ if args.seed_separately:
173
+ torch.manual_seed(args.generation_seed)
174
+ output_with_watermark = generate_with_watermark(**tokd_input)
175
+
176
+ if args.is_decoder_only_model:
177
+ # need to isolate the newly generated tokens
178
+ output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
179
+ output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
180
+
181
+ decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
182
+ decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
183
+
184
+ return (redecoded_input,
185
+ int(truncation_warning),
186
+ decoded_output_without_watermark,
187
+ decoded_output_with_watermark,
188
+ args)
189
+ # decoded_output_with_watermark)
190
+
191
+ def analyze(input_text, args, device=None, tokenizer=None):
192
+
193
+ detector_args = {
194
+ "vocab": list(tokenizer.get_vocab().values()),
195
+ "gamma": args.gamma,
196
+ "delta": args.delta,
197
+ "seeding_scheme": args.seeding_scheme,
198
+ "select_green_tokens": args.select_green_tokens,
199
+ "device": device,
200
+ "tokenizer": tokenizer,
201
+ "z_threshold": args.detection_z_threshold,
202
+ "normalizers": args.normalizers,
203
  }
204
+ if args.run_extended:
205
+ detector_args["ignore_repeated_ngrams"] = args.ignore_repeated_ngrams
206
+ else:
207
+ detector_args["skip_repeated_bigrams"] = args.skip_repeated_bigrams
208
+
209
+ if args.run_extended:
210
+ watermark_detector = WatermarkAnalyzerExtended(**detector_args)
211
+ else:
212
+ watermark_detector = WatermarkAnalyzer(**detector_args)
213
+
214
+ if args.run_extended:
215
+ score_dict = watermark_detector.analyze(input_text)
216
+ output = list_format_scores(score_dict, watermark_detector.z_threshold)
217
+ else:
218
+ if len(input_text)-1 > watermark_detector.min_prefix_len:
219
+ score_dict = watermark_detector.analyze(input_text)
220
+ # output = str_format_scores(score_dict, watermark_detector.z_threshold)
221
+ output = list_format_scores(score_dict, watermark_detector.z_threshold)
222
+ else:
223
+ # output = (f"Error: string not long enough to compute watermark presence.")
224
+ output = [["Error","string too short to compute metrics"]]
225
+ output += [["",""] for _ in range(6)]
226
 
227
+ return output, args
 
 
228
 
229
  if __name__ == "__main__":
230
  args = get_default_args()
231
+ # args = process_args(args)
232
+ input_text = get_default_prompt()
233
+ args.default_prompt = input_text
234
+
235
+ if not args.skip_model_load:
236
+ model, tokenizer, device = load_model(args)
237
+ else:
238
+ model, tokenizer, device = None, None, None
239
+
240
+ if not args.skip_model_load:
241
+ display_prompt(input_text)
242
+
243
+ _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(
244
+ input_text, args, model=model, device=device, tokenizer=tokenizer
245
+ )
246
+
247
+ without_watermark_detection_result = analyze(
248
+ decoded_output_without_watermark, args, device=device, tokenizer=tokenizer
249
+ )
250
+
251
+ with_watermark_detection_result = analyze(
252
+ decoded_output_with_watermark, args, device=device, tokenizer=tokenizer
253
+ )
254
+
255
+ display_results(decoded_output_without_watermark, without_watermark_detection_result, args, with_watermark=False)
256
+ display_results(decoded_output_with_watermark, with_watermark_detection_result, args, with_watermark=True)
257
+
258
+ if args.run_gradio:
259
+ run_gradio(args, model=model, tokenizer=tokenizer, device=device)
components/__pycache__/homoglyphs.cpython-310.pyc CHANGED
Binary files a/components/__pycache__/homoglyphs.cpython-310.pyc and b/components/__pycache__/homoglyphs.cpython-310.pyc differ
 
components/__pycache__/normalizers.cpython-310.pyc CHANGED
Binary files a/components/__pycache__/normalizers.cpython-310.pyc and b/components/__pycache__/normalizers.cpython-310.pyc differ
 
components/__pycache__/prf_schemes.cpython-310.pyc CHANGED
Binary files a/components/__pycache__/prf_schemes.cpython-310.pyc and b/components/__pycache__/prf_schemes.cpython-310.pyc differ
 
components/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/components/__pycache__/utils.cpython-310.pyc and b/components/__pycache__/utils.cpython-310.pyc differ
 
{data → components/data}/__init__.py RENAMED
File without changes
{data → components/data}/categories.json RENAMED
File without changes
{data → components/data}/confusables.json RENAMED
File without changes
{data → components/data}/languages.json RENAMED
File without changes
components/utils.py CHANGED
@@ -1,6 +1,38 @@
1
  from pprint import pprint
2
  import argparse
3
  from itertools import chain, tee
4
 
5
  def process_args(args):
6
  """Process and normalize command-line arguments."""
 
1
  from pprint import pprint
2
  import argparse
3
  from itertools import chain, tee
4
+ from argparse import Namespace
5
+
6
+ def get_default_args():
7
+ """Returns the default arguments as a Namespace object."""
8
+ default_arg_dict = {
9
+ 'run_gradio': True,
10
+ 'run_extended': True,
11
+ 'demo_public': False,
12
+ 'model_name_or_path': 'google/gemma-2-2b-it', #'bigscience/bloom-560m', #'meta-llama/Meta-Llama-3-8B',
13
+ 'load_fp16': False,
14
+ 'prompt_max_length': None,
15
+ 'max_new_tokens': 200,
16
+ 'generation_seed': 123,
17
+ 'use_sampling': True,
18
+ 'n_beams': 1,
19
+ 'sampling_temp': 0.7,
20
+ 'use_gpu': False,
21
+ 'seeding_scheme': 'simple_1',
22
+ 'gamma': 0.25,
23
+ 'delta': 2.0,
24
+ 'normalizers': '',
25
+ 'skip_repeated_bigrams': False,
26
+ 'ignore_repeated_ngrams': False,
27
+ 'detection_z_threshold': 4.0,
28
+ 'select_green_tokens': True,
29
+ 'skip_model_load': False,
30
+ 'seed_separately': True,
31
+ }
32
+
33
+ args = Namespace()
34
+ args.__dict__.update(default_arg_dict)
35
+ return args
36
 
37
  def process_args(args):
38
  """Process and normalize command-line arguments."""
src/watermark_demo.py DELETED
@@ -1,291 +0,0 @@
1
- from distutils.command.config import config
2
- import sys
3
- from functools import partial
4
- import gradio as gr
5
- import torch
6
- from transformers import (AutoTokenizer,
7
- AutoModelForSeq2SeqLM,
8
- AutoModelForCausalLM,
9
- LogitsProcessorList)
10
- from src.watermark_engine import LogitsProcessorWithWatermark, WatermarkAnalyzer
11
- from src.extended_watermark_engine import LogitsProcessorWithWatermarkExtended, WatermarkAnalyzerExtended
12
- from components.utils import process_args, get_default_prompt, display_prompt, display_results, parse_args, list_format_scores
13
-
14
-
15
- def run_gradio(args, model=None, device=None, tokenizer=None):
16
- """Define and launch with gradio"""
17
- generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
18
- detect_partial = partial(analyze, device=device, tokenizer=tokenizer)
19
-
20
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css="footer{display:none !important}") as demo:
21
- with gr.Row():
22
- with gr.Column(scale=9):
23
- gr.Markdown(
24
- """
25
- ## Plagiarism detection for Large Language Models through watermarking
26
- """
27
- )
28
- with gr.Column(scale=2):
29
- algorithm = gr.Radio(label="Watermark Algorithm", info="which algorithm would you like to use?", choices=["basic", "advance"], value=("advance" if args.run_extended else "basic"))
30
-
31
-
32
- default_prompt = args.__dict__.pop("default_prompt")
33
- session_args = gr.State(value=args)
34
- initial_args = {
35
- "model_name_or_path": args.model_name_or_path,
36
- "load_fp16": False
37
- }
38
-
39
- session_md = gr.State(value=initial_args)
40
-
41
- def update_display(args):
42
- return f"Language model: {args['model_name_or_path']} {'(float16 mode)' if args['load_fp16'] else ''}"
43
-
44
- display_md = gr.Markdown(update_display(initial_args))
45
-
46
- # gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
47
-
48
- with gr.Tab("Generate and Detect"):
49
-
50
- with gr.Row():
51
- prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt)
52
- with gr.Row():
53
- generate_btn = gr.Button("Generate")
54
- with gr.Row():
55
- with gr.Column(scale=2):
56
- output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14)
57
- with gr.Column(scale=1):
58
- without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
59
- with gr.Row():
60
- with gr.Column(scale=2):
61
- output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14)
62
- with gr.Column(scale=1):
63
- with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
64
-
65
- redecoded_input = gr.Textbox(visible=False)
66
- truncation_warning = gr.Number(visible=False)
67
- def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
68
- if truncation_warning:
69
- return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
70
- else:
71
- return orig_prompt, args
72
-
73
- with gr.Tab("Detector Only"):
74
- with gr.Row():
75
- with gr.Column(scale=2):
76
- detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14)
77
- with gr.Column(scale=1):
78
- detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
79
- with gr.Row():
80
- detect_btn = gr.Button("Detect")
81
-
82
- gr.HTML("""
83
- <p style="color: gray;">Built with 🀍 by Charles Apochi
84
- <br/>
85
- <a href="mailto:charlesapochi@gmail.com" style="text-decoration: none; color: orange;">Reach out</a>
86
- <p/>
87
- """)
88
-
89
- generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
90
- redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
91
- output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
92
- output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
93
- detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
94
-
95
- # State management logic
96
- def update_algorithm(session_state, session_md, value):
97
- new_md = session_md.copy()
98
- if value == "advance":
99
- session_state.run_extended = True
100
- session_state.model_name_or_path = 'google/gemma-2-2b-it'
101
- new_md['model_name_or_path'] = 'google/gemma-2-2b-it'
102
- elif value == "basic":
103
- session_state.run_extended = False
104
- session_state.model_name_or_path = 'bigscience/bloom-560m'
105
- new_md['model_name_or_path'] = 'bigscience/bloom-560m'
106
- return session_state, new_md, update_display(new_md)
107
-
108
-
109
- algorithm.change(update_algorithm,inputs=[session_args, session_md, algorithm], outputs=[session_args, session_md, display_md])
110
-
111
- demo.launch(share=args.demo_public)
112
-
113
- def load_model(args):
114
- """Load and return the model and tokenizer"""
115
- args.is_decoder_only_model = True
116
- if args.run_extended:
117
- model = AutoModelForCausalLM.from_pretrained(
118
- args.model_name_or_path,
119
- device_map="auto",
120
- torch_dtype=torch.float16,
121
- )
122
- else:
123
- model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
124
-
125
- if args.use_gpu:
126
- device = "cuda" if torch.cuda.is_available() else "cpu"
127
- if args.load_fp16:
128
- pass
129
- else:
130
- model = model.to(device)
131
- else:
132
- device = "mps" if args.run_extended else "cpu"
133
-
134
- model.eval()
135
- tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
136
-
137
- return model, tokenizer, device
138
-
139
- def generate(prompt, args, model=None, device=None, tokenizer=None):
140
- print(f"Generating with {args}")
141
-
142
- model, tokenizer, device = load_model(args)
143
-
144
- if args.run_extended:
145
- watermark_processor = LogitsProcessorWithWatermarkExtended(vocab=list(tokenizer.get_vocab().values()),
146
- gamma=args.gamma,
147
- delta=args.delta,
148
- seeding_scheme=args.seeding_scheme,
149
- select_green_tokens=args.select_green_tokens)
150
- else:
151
- watermark_processor = LogitsProcessorWithWatermark(vocab=list(tokenizer.get_vocab().values()),
152
- gamma=args.gamma,
153
- delta=args.delta,
154
- seeding_scheme=args.seeding_scheme,
155
- select_green_tokens=args.select_green_tokens)
156
-
157
- gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
158
-
159
- if args.use_sampling:
160
- gen_kwargs.update(dict(
161
- do_sample=True,
162
- top_k=0,
163
- temperature=args.sampling_temp,
164
- ))
165
- else:
166
- gen_kwargs.update(dict(
167
- num_beams=args.n_beams,
168
- ))
169
-
170
- generate_without_watermark = partial(
171
- model.generate,
172
- **gen_kwargs
173
- )
174
- generate_with_watermark = partial(
175
- model.generate,
176
- logits_processor=LogitsProcessorList([watermark_processor]),
177
- **gen_kwargs
178
- )
179
- if args.prompt_max_length:
180
- pass
181
- elif hasattr(model.config,"max_position_embedding"):
182
- args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
183
- else:
184
- args.prompt_max_length = 2048-args.max_new_tokens
185
-
186
- tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
187
- truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
188
- redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
189
-
190
- torch.manual_seed(args.generation_seed)
191
- output_without_watermark = generate_without_watermark(**tokd_input)
192
-
193
- if args.seed_separately:
194
- torch.manual_seed(args.generation_seed)
195
- output_with_watermark = generate_with_watermark(**tokd_input)
196
-
197
- if args.is_decoder_only_model:
198
- # need to isolate the newly generated tokens
199
- output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
200
- output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
201
-
202
- decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
203
- decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
204
-
205
- return (redecoded_input,
206
- int(truncation_warning),
207
- decoded_output_without_watermark,
208
- decoded_output_with_watermark,
209
- args)
210
- # decoded_output_with_watermark)
211
-
212
- def analyze(input_text, args, device=None, tokenizer=None):
213
- tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
214
- device = "mps" if args.run_extended else "cpu"
215
-
216
- detector_args = {
217
- "vocab": list(tokenizer.get_vocab().values()),
218
- "gamma": args.gamma,
219
- "delta": args.delta,
220
- "seeding_scheme": args.seeding_scheme,
221
- "select_green_tokens": args.select_green_tokens,
222
- "device": device,
223
- "tokenizer": tokenizer,
224
- "z_threshold": args.detection_z_threshold,
225
- "normalizers": args.normalizers,
226
- }
227
- if args.run_extended:
228
- detector_args["ignore_repeated_ngrams"] = args.ignore_repeated_ngrams
229
- else:
230
- detector_args["skip_repeated_bigrams"] = args.skip_repeated_bigrams
231
-
232
- if args.run_extended:
233
- watermark_detector = WatermarkAnalyzerExtended(**detector_args)
234
- else:
235
- watermark_detector = WatermarkAnalyzer(**detector_args)
236
-
237
- if args.run_extended:
238
- score_dict = watermark_detector.analyze(input_text)
239
- output = list_format_scores(score_dict, watermark_detector.z_threshold)
240
- else:
241
- if len(input_text)-1 > watermark_detector.min_prefix_len:
242
- score_dict = watermark_detector.analyze(input_text)
243
- # output = str_format_scores(score_dict, watermark_detector.z_threshold)
244
- output = list_format_scores(score_dict, watermark_detector.z_threshold)
245
- else:
246
- # output = (f"Error: string not long enough to compute watermark presence.")
247
- output = [["Error","string too short to compute metrics"]]
248
- output += [["",""] for _ in range(6)]
249
-
250
- return output, args
251
-
252
-
253
- def main(args):
254
- """Run the main script for generation and detection"""
255
- args = process_args(args)
256
- input_text = get_default_prompt()
257
- args.default_prompt = input_text
258
-
259
- if not args.skip_model_load:
260
- model, tokenizer, device = load_model(args)
261
- else:
262
- model, tokenizer, device = None, None, None
263
-
264
- if not args.skip_model_load:
265
- display_prompt(input_text)
266
-
267
- _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(
268
- input_text, args, model=model, device=device, tokenizer=tokenizer
269
- )
270
-
271
- without_watermark_detection_result = analyze(
272
- decoded_output_without_watermark, args, device=device, tokenizer=tokenizer
273
- )
274
-
275
- with_watermark_detection_result = analyze(
276
- decoded_output_with_watermark, args, device=device, tokenizer=tokenizer
277
- )
278
-
279
- display_results(decoded_output_without_watermark, without_watermark_detection_result, args, with_watermark=False)
280
- display_results(decoded_output_with_watermark, with_watermark_detection_result, args, with_watermark=True)
281
-
282
- if args.run_gradio:
283
- run_gradio(args, model=model, tokenizer=tokenizer, device=device)
284
-
285
- return
286
-
287
- if __name__ == "__main__":
288
-
289
- args = parse_args()
290
-
291
- main(args)
test.py DELETED
@@ -1,31 +0,0 @@
1
- import torch
2
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
3
-
4
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
5
- model = AutoModelForCausalLM.from_pretrained(
6
- "google/gemma-2-2b-it",
7
- device_map="auto",
8
- torch_dtype=torch.float16,
9
- )
10
-
11
- input_text = "Write me a poem about Machine Learning."
12
- input_ids = tokenizer(input_text, return_tensors="pt").to("mps")
13
-
14
- outputs = model.generate(**input_ids, max_new_tokens=32)
15
- print(tokenizer.decode(outputs[0]))
16
-
17
- pipe = pipeline(
18
- "text-generation",
19
- model= "google/gemma-2-2b-it",
20
- model_kwargs={"torch_dtype": torch.float16},
21
- device="mps", # replace with "mps" to run on a Mac device
22
- )
23
-
24
- messages = [
25
- {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
26
- ]
27
-
28
- outputs = pipe(messages, max_new_tokens=256)
29
- assistant_response = outputs[0]["generated_text"][-1]["content"].strip()
30
- print(assistant_response)
31
- # Ahoy, matey! I be Gemma, a digital scallywag, a language-slingin' parrot of the digital seas. I be here to help ye with yer wordy woes, answer yer questions, and spin ye yarns of the digital world. So, what be yer pleasure, eh? 🦜