import sys
from functools import partial

import gradio as gr
import torch
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)

from algorithm.watermark_engine import LogitsProcessorWithWatermark, WatermarkAnalyzer
from algorithm.extended_watermark_engine import LogitsProcessorWithWatermarkExtended, WatermarkAnalyzerExtended
from components.utils import (process_args,
                              get_default_prompt,
                              display_prompt,
                              display_results,
                              parse_args,
                              list_format_scores,
                              get_default_args)


def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the Gradio demo."""
    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(analyze, device=device, tokenizer=tokenizer)

    with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"),
                   css="footer{display:none !important}") as demo:
        with gr.Row():
            with gr.Column(scale=9):
                gr.Markdown(
                    """
                    ## Plagiarism detection for Large Language Models through watermarking
                    """
                )
            with gr.Column(scale=2):
                algorithm = gr.Radio(label="Watermark Algorithm",
                                     info="Which algorithm would you like to use?",
                                     choices=["basic", "advanced"],
                                     value=("advanced" if args.run_extended else "basic"))

        gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")

        # Pop the default prompt off the args namespace so it is not carried
        # around in the per-session state.
        default_prompt = args.__dict__.pop("default_prompt")
        session_args = gr.State(value=args)

        with gr.Tab("Generate and Detect"):
            with gr.Row():
                prompt = gr.Textbox(label="Prompt", interactive=True,
                                    lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=2):
                    output_without_watermark = gr.Textbox(label="Output Without Watermark",
                                                          interactive=False, lines=14, max_lines=14)
                with gr.Column(scale=1):
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],
                                                                      interactive=False,
                                                                      row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=2):
                    output_with_watermark = gr.Textbox(label="Output With Watermark",
                                                       interactive=False, lines=14, max_lines=14)
                with gr.Column(scale=1):
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],
                                                                   interactive=False,
                                                                   row_count=7, col_count=2)

            # Hidden components used to carry the (possibly truncated) prompt back to the UI.
            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)

            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                if truncation_warning:
                    return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]", args
                else:
                    return orig_prompt, args

        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    detection_input = gr.Textbox(label="Text to Analyze", interactive=True,
                                                 lines=14, max_lines=14)
                with gr.Column(scale=1):
                    detection_result = gr.Dataframe(headers=["Metric", "Value"],
                                                    interactive=False, row_count=7, col_count=2)
            with gr.Row():
                detect_btn = gr.Button("Detect")

        gr.HTML("""

Built with 🤍 by Charles Apochi
Reach out

""") generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args]) redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args]) output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args]) output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args]) detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args]) # State management logic def update_algorithm(session_state, value): if value == "advance": session_state.run_extended = True # args.run_extended = True elif value == "basic": session_state.run_extended = False # args.run_extended = False return session_state algorithm.change(update_algorithm,inputs=[session_args, algorithm], outputs=[session_args]) demo.launch(share=args.demo_public) def load_model(args): """Load and return the model and tokenizer""" args.is_decoder_only_model = True model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, # device_map="auto", # torch_dtype=torch.float16, ) if args.use_gpu: device = "cuda" if torch.cuda.is_available() else "cpu" if args.load_fp16: pass else: model = model.to(device) else: device = "cpu" #"mps" if args.run_extended else "cpu" model.eval() tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) return model, tokenizer, device def generate(prompt, args, model=None, device=None, tokenizer=None): print(f"Generating with {args}") if args.run_extended: watermark_processor = LogitsProcessorWithWatermarkExtended(vocab=list(tokenizer.get_vocab().values()), gamma=args.gamma, delta=args.delta, seeding_scheme=args.seeding_scheme, select_green_tokens=args.select_green_tokens) else: watermark_processor = LogitsProcessorWithWatermark(vocab=list(tokenizer.get_vocab().values()), gamma=args.gamma, delta=args.delta, seeding_scheme=args.seeding_scheme, select_green_tokens=args.select_green_tokens) gen_kwargs = dict(max_new_tokens=args.max_new_tokens) if args.use_sampling: gen_kwargs.update(dict( do_sample=True, top_k=0, temperature=args.sampling_temp, )) else: gen_kwargs.update(dict( num_beams=args.n_beams, )) generate_without_watermark = partial( model.generate, **gen_kwargs ) generate_with_watermark = partial( model.generate, logits_processor=LogitsProcessorList([watermark_processor]), **gen_kwargs ) if args.prompt_max_length: pass elif hasattr(model.config,"max_position_embedding"): args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens else: args.prompt_max_length = 2048-args.max_new_tokens tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device) truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0] torch.manual_seed(args.generation_seed) output_without_watermark = generate_without_watermark(**tokd_input) if args.seed_separately: torch.manual_seed(args.generation_seed) output_with_watermark = generate_with_watermark(**tokd_input) if args.is_decoder_only_model: # need to isolate the newly generated tokens output_without_watermark = 
output_without_watermark[:,tokd_input["input_ids"].shape[-1]:] output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:] decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0] decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0] return (redecoded_input, int(truncation_warning), decoded_output_without_watermark, decoded_output_with_watermark, args) # decoded_output_with_watermark) def analyze(input_text, args, device=None, tokenizer=None): detector_args = { "vocab": list(tokenizer.get_vocab().values()), "gamma": args.gamma, "delta": args.delta, "seeding_scheme": args.seeding_scheme, "select_green_tokens": args.select_green_tokens, "device": device, "tokenizer": tokenizer, "z_threshold": args.detection_z_threshold, "normalizers": args.normalizers, } if args.run_extended: detector_args["ignore_repeated_ngrams"] = args.ignore_repeated_ngrams else: detector_args["skip_repeated_bigrams"] = args.skip_repeated_bigrams if args.run_extended: watermark_detector = WatermarkAnalyzerExtended(**detector_args) else: watermark_detector = WatermarkAnalyzer(**detector_args) if args.run_extended: score_dict = watermark_detector.analyze(input_text) output = list_format_scores(score_dict, watermark_detector.z_threshold) else: if len(input_text)-1 > watermark_detector.min_prefix_len: score_dict = watermark_detector.analyze(input_text) # output = str_format_scores(score_dict, watermark_detector.z_threshold) output = list_format_scores(score_dict, watermark_detector.z_threshold) else: # output = (f"Error: string not long enough to compute watermark presence.") output = [["Error","string too short to compute metrics"]] output += [["",""] for _ in range(6)] return output, args if __name__ == "__main__": args = parse_args() # args = get_default_args() # args = process_args(args) input_text = get_default_prompt() args.default_prompt = input_text if not args.skip_model_load: model, tokenizer, device = load_model(args) else: model, tokenizer, device = None, None, None if not args.skip_model_load: display_prompt(input_text) _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate( input_text, args, model=model, device=device, tokenizer=tokenizer ) without_watermark_detection_result = analyze( decoded_output_without_watermark, args, device=device, tokenizer=tokenizer ) with_watermark_detection_result = analyze( decoded_output_with_watermark, args, device=device, tokenizer=tokenizer ) display_results(decoded_output_without_watermark, without_watermark_detection_result, args, with_watermark=False) display_results(decoded_output_with_watermark, with_watermark_detection_result, args, with_watermark=True) if args.run_gradio: run_gradio(args, model=model, tokenizer=tokenizer, device=device)
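# Detector-only usage sketch (illustrative comments only, so importing this module
# has no side effects). It assumes parse_args() from components.utils returns a
# namespace with the attribute names referenced above (gamma, delta, seeding_scheme,
# detection_z_threshold, ...); the suspect text below is a placeholder.
#
#   args = parse_args()
#   model, tokenizer, device = load_model(args)   # analyze() only needs the tokenizer and device
#   suspect_text = "..."                          # text whose provenance you want to check
#   scores, args = analyze(suspect_text, args, device=device, tokenizer=tokenizer)
#   print(scores)                                 # list of [metric, value] rows, as shown in the UI dataframes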