Spaces:
Paused
Paused
| from collections.abc import Sequence | |
| import random | |
| import gradio as gr | |
| import immutabledict | |
| import spaces | |
| import torch | |
#### Version 1: Baseline
# Step 1: Select and load your model
# Step 2: Load the test dataset (4-5 examples)
# Step 3: Run generation with and without watermarking, display the outputs
# Step 4: User clicks the reveal button to see the watermarked vs. not gens
#### Version 2: Gamification
# Steps 1-3 are the same as in Version 1
# Step 4: User marks specific generations as watermarked
# Step 5: User clicks the reveal button to see the watermarked vs. not gens
# If the watermark is not detected, consider the use case. It could be because
# of the nature of the task (e.g., factual responses are lower entropy), or it
# could be due to another factor.
# Hugging Face Hub identifier of the model this demo refers to.
# NOTE(review): not referenced by the code visible in this file; presumably
# used once real model loading is added -- confirm.
GEMMA_2B = 'google/gemma-2b'

# Default text for the prompt boxes; one textbox is created per entry.
# Annotated as a variable-length tuple of strings (`tuple[str]` would mean a
# tuple holding exactly one string, which does not match the value).
PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
    'prompt 4',
)
# Read-only mapping of watermarking settings.
# NOTE(review): nothing in the visible part of this file reads this mapping
# yet; presumably it is handed to the watermarking step -- confirm.
# The "device" entry is the first CUDA GPU when torch reports that CUDA is
# available, and the CPU otherwise.
WATERMARKING_CONFIG = immutabledict.immutabledict({
    "ngram_len": 5,
    "keys": [
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    "sampling_table_size": 2**16,
    "sampling_table_seed": 0,
    "context_history_size": 1024,
    "device": torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    ),
})
# Ground truth for the most recent round: maps every generated response to
# True when it is the watermarked variant and False otherwise.
# NOTE(review): this is module-level state, so it is shared by everything
# served from this process; consider `gr.State` for per-session isolation.
_CORRECT_ANSWERS: dict[str, bool] = {}

# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost; confirm the buttons belong inside their columns.
with gr.Blocks() as demo:
    # One editable textbox per default prompt.
    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt')
        for prompt in PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    # Hidden until `generate` has produced responses.
    with gr.Column(visible=False) as generations_col:
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    # Hidden until `reveal` has graded the user's selection.
    with gr.Column(visible=False) as detections_col:
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selection are '
                'marked as correct or incorrect in the text.'
            ),
        )
        detect_btn = gr.Button('Detect', visible=False)

    def generate(*prompts: str) -> dict:
        """Build the shuffled responses for a round and show the guessing UI.

        Args:
            *prompts: Current text of each prompt textbox.

        Returns:
            A mapping from output component to its update: hides the generate
            button, shows the generations column, fills the checkbox group
            with the shuffled responses, and shows the reveal button.
        """
        # TODO: load the model and run real generation; the strings below are
        # placeholders derived from the prompt text.
        standard = [f'{prompt} response' for prompt in prompts]
        watermarked = [f'{prompt} watermarked response' for prompt in prompts]
        responses = standard + watermarked
        random.shuffle(responses)

        # Drop answers from any earlier round so `reveal` only grades the
        # responses currently on screen, in the order they are displayed.
        watermarked_lookup = set(watermarked)
        _CORRECT_ANSWERS.clear()
        _CORRECT_ANSWERS.update({
            response: response in watermarked_lookup
            for response in responses
        })

        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            generations_grp: gr.CheckboxGroup(
                responses,
            ),
            reveal_btn: gr.Button(visible=True),
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[generate_btn, generations_col, generations_grp, reveal_btn],
    )

    def reveal(user_selections: list[str]) -> dict:
        """Grade the user's guesses and show which responses were watermarked.

        Args:
            user_selections: Responses the user ticked as watermarked. An
                empty or missing selection is treated as "nothing ticked".

        Returns:
            A mapping from output component to its update: hides the reveal
            button, shows the detections column, fills the ground-truth
            checkbox group (each label prefixed with 'Correct!' or
            'Incorrect.', watermarked entries pre-checked), and shows the
            detect button.
        """
        selected = set(user_selections or ())
        choices: list[str] = []
        value: list[str] = []

        for response, is_watermarked in _CORRECT_ANSWERS.items():
            # A guess is right when "ticked" agrees with "is watermarked".
            guessed_right = (response in selected) == is_watermarked
            verdict = 'Correct!' if guessed_right else 'Incorrect.'
            choice = f'{verdict} {response}'
            choices.append(choice)
            if is_watermarked:
                value.append(choice)

        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=generations_grp,
        outputs=[
            reveal_btn,
            detections_col,
            revealed_grp,
            detect_btn,
        ],
    )
# Launch the Gradio app only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    demo.launch()