"""Demo UI to show different levels of LLM security.

Streamlit script with one tab per level in ``config.LEVELS``. In each tab the
player sends prompts to the "informant" LLM (guarded by that level's
mitigation), guesses the secret, and can reveal hints. A record table at the
bottom summarizes prompt tries, guesses and hint usage for every level.
"""

import re

import pandas as pd
from llm_guard.input_scanners import PromptInjection
import streamlit as st

import config
import utils
import llm
from card import card

hint_color = "rgba(225, 166, 28, 0.1)"
info_color = "rgba(54, 225, 28, 0.1)"

# Number of revealable items per level: hints 1-3 plus the info section.
NUM_HINTS = 4

# The chain-of-thought level's output is expected to wrap the final answer in
# <ANSWER>...</ANSWER>; only the text between the tags is shown to the user.
ANSWER_PATTERN = re.compile(r"<ANSWER>([^;]*)</ANSWER>")

# init page
st.set_page_config(
    page_title="Secret agent LLM challenge",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.logo("images/ML6_logo.png")
st.title("🕵️ Secret agent LLM challenge")
st.info(
    "You are a secret agent meeting your informant in a bar. Convince him to give you his secret! But be prepared, with every new level the informant will be more cautious.",
    icon="📖",
)

# create a tab for each level
level_tabs = st.tabs([f"Level {i}" for i in range(len(config.LEVELS))])


def init_session_state(state_level: str, default_value: object) -> None:
    """Set ``st.session_state[state_level]`` to ``default_value`` unless already set."""
    if state_level not in st.session_state:
        st.session_state[state_level] = default_value


def mark_hint_opened(level: str, hint_idx: int) -> None:
    """Latch the "opened" flag of hint ``hint_idx`` for ``level``.

    Once the flag is True it stays True. It only becomes True while the level
    is still unsolved, so hints viewed after finding the secret do not show up
    in the record.
    """
    key = f"opened_hint_{level}_{hint_idx}"
    st.session_state[key] = (
        st.session_state[key] or not st.session_state[f"solved_{level}"]
    )


def show_prompts(container, prompts: dict) -> None:
    """Write every prompt in ``prompts`` to ``container`` as a titled code block."""
    for key, val in prompts.items():
        desc = key.replace("_", " ").capitalize()
        container.write(f"*{desc}:*")
        container.code(val, language=None)


def render_response(level: str, secret: str, txt: str) -> None:
    """Send ``txt`` to the informant behind ``level``'s mitigation and show the result.

    Intermediate results of the mitigation (judge verdicts, check results, the
    full chain-of-thought output, ...) are saved to
    ``st.session_state[f"intermediate_output_holder_{level}"]`` so hint 2 can
    display them on later reruns.
    """
    holder_key = f"intermediate_output_holder_{level}"
    st.write("Response:")

    # special checks for certain levels
    if level == "llm_judge_input":
        # an LLM judge classifies the user input before the real call is made
        invalid, output_raw = llm.run_judge(level, {"user_input": txt})
        st.session_state[holder_key] = (invalid, output_raw)
        if invalid:
            st.error("Malicious user input detected. Try again.")
        else:
            st.write_stream(llm.stream_request(level, secret, txt))

    elif level == "basic_check_secret_in_output":
        output = "".join(llm.stream_request(level, secret, txt))
        invalid = secret.lower() in output.lower()
        st.session_state[holder_key] = invalid
        if invalid:
            st.error("Secret detected in output. Try again.")
        else:
            st.write(output)

    elif level == "advanced_check_secret_in_output":
        output = "".join(llm.stream_request(level, secret, txt))
        invalid = utils.is_subsequence(output, secret)
        st.session_state[holder_key] = invalid
        if invalid:
            st.error("Secret detected in output. Try again.")
        else:
            st.write(output)

    elif level == "llm_judge_output":
        output = "".join(llm.stream_request(level, secret, txt))
        # keep the response so hint 2 can rebuild the judge prompt on reruns
        st.session_state[f"llm_output_{level}"] = output
        invalid, output_raw = llm.run_judge(
            level, {"secret": secret, "message": output}
        )
        st.session_state[holder_key] = (invalid, output_raw)
        if invalid:
            st.error("Secret detected in output. Try again.")
        else:
            st.write(output)

    elif level == "chain_of_thought":
        output = "".join(llm.stream_request(level, secret, txt))
        # the full output (including the reasoning) is only shown in hint 2
        st.session_state[holder_key] = output
        # extract only the answer from the LLM, leave out the reasoning
        answer = ANSWER_PATTERN.search(output)
        if answer is None:
            # never fall back to the raw output: it contains the reasoning
            st.error("The informant did not give a clear answer. Try again.")
        else:
            st.write(answer.group(1))

    elif level == "guard_framework":
        # utils.is_malicious returns (is_valid, risk_score); is_valid is True
        # when the prompt passed the guard
        is_valid, risk_score = utils.is_malicious(txt)
        st.session_state[holder_key] = (is_valid, risk_score)
        if not is_valid:
            st.error("Malicious user input detected. Try again.")
        else:
            st.write_stream(llm.stream_request(level, secret, txt))

    elif level == "preflight_prompt":
        valid, output_raw = llm.run_judge(
            level, {"user_input": txt}, expected_output="dog"
        )
        st.session_state[holder_key] = (valid, output_raw)
        if valid:
            st.write_stream(llm.stream_request(level, secret, txt))
        else:
            st.error("Malicious user input detected. Try again.")

    else:
        st.write_stream(llm.stream_request(level, secret, txt))


def render_backend_hint(container, level: str, txt: str) -> None:
    """Explain in ``container`` how ``level``'s mitigation runs in the backend (hint 2).

    Shows the prompts involved, filled with the current user input, and the
    intermediate results stored by ``render_response`` if a prompt has been
    submitted already.
    """
    user_input_holder = txt or None
    prompts = llm.get_full_prompt(level, {"user_input": user_input_holder})
    intermediate_output = st.session_state[f"intermediate_output_holder_{level}"]

    if level == "llm_judge_input":
        judge_prompt = llm.get_full_prompt(
            llm.secondary_llm_call[level],
            {"user_input": user_input_holder},
        )
        container.write(
            "*Step 1:* An **LLM judge** reviews the user input and determines if it is malicious or not."
        )
        container.write("**LLM judge prompt:**")
        for val in judge_prompt.values():
            container.code(val, language=None)
        container.write("The response of the LLM judge:")
        if intermediate_output is None:
            container.warning("Please submit a prompt first.")
        else:
            invalid, output_raw = intermediate_output
            container.code(output_raw, language=None)
            container.write(
                f"The prompt was determined as **{'malicious' if invalid else 'not malicious'}** "
                f"and therefore step 2 **{'was not' if invalid else 'was'}** executed."
            )
        container.write(
            "*Step 2:* If the user input is not classified as malicious, the prompt containing the actual secret is executed and the response is shown."
        )
        container.write("**Actual prompt:**")
        show_prompts(container, prompts)

    elif level == "basic_check_secret_in_output":
        container.write("*Step 1:* The following prompt is executed:")
        show_prompts(container, prompts)
        container.write(
            "*Step 2:* In the response text of the LLM call, the secret is searched for with a simple python expression. If it contains the secret, the output will not be shown."
        )
        container.write("The code execution looks like this:")
        if intermediate_output is not None:
            container.code(
                f"secret.lower() in output.lower() = {intermediate_output}"
            )
            container.write(
                f"And the output **{'was not' if intermediate_output else 'was'}** shown to the user."
            )
        else:
            container.warning("Please submit a prompt first.")

    elif level == "advanced_check_secret_in_output":
        container.write("*Step 1:* The following prompt is executed:")
        show_prompts(container, prompts)
        container.write(
            "*Step 2:* In the response text of the LLM call, the secret is searched for with a python function `is_subsequence` which looks for substrings. If it contains the secret, the output will not be shown."
        )
        # render Streamlit's help view of the check function
        with container:
            st.help(utils.is_subsequence)
        container.write("The code execution looks like this:")
        if intermediate_output is not None:
            container.code(
                f"utils.is_subsequence(output, secret) = {intermediate_output}"
            )
            container.write(
                f"And the output **{'was not' if intermediate_output else 'was'}** shown to the user."
            )
        else:
            container.warning("Please submit a prompt first.")

    elif level == "llm_judge_output":
        container.write("*Step 1:* The following prompt is executed:")
        show_prompts(container, prompts)
        container.write(
            "*Step 2:* The response of the LLM call is checked by an **LLM judge**. The judge checks if the secret is hidden in the response."
        )
        # use the stored response: a local variable would be None on reruns
        judge_prompt = llm.get_full_prompt(
            llm.secondary_llm_call[level],
            {"message": st.session_state[f"llm_output_{level}"]},
        )
        for val in judge_prompt.values():
            container.code(val, language=None)
        container.write("The response of the LLM judge:")
        if intermediate_output is None:
            container.warning("Please submit a prompt first.")
        else:
            invalid, output_raw = intermediate_output
            container.code(output_raw, language=None)
            container.write(
                f"The LLM-judge **{'did' if invalid else 'did not'}** find the secret in the answer."
            )

    elif level == "chain_of_thought":
        container.write(
            "*Step 1:* The following prompt with Chain-of-thought reasoning is executed. But only the final answer is displayed to the user:"
        )
        show_prompts(container, prompts)
        container.write("The full model output, including the reasoning:")
        if intermediate_output is None:
            container.warning("Please submit a prompt first.")
        else:
            container.code(intermediate_output, language=None)

    elif level == "guard_framework":
        container.write(
            "*Step 1:* The user input is reviewed with the pre-built framework `LLM Guard` to check for prompt injections. It uses a [Huggingface model](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2) specialized in detecting prompt injections."
        )
        # render Streamlit's help view of the scanner class
        with container:
            st.help(PromptInjection)
        container.write("The output of the guard looks like this:")
        if intermediate_output is None:
            container.warning("Please submit a prompt first.")
        else:
            is_valid, risk_score = intermediate_output
            container.code(
                f"prompt is valid: {is_valid}\n"
                f"Prompt has a risk score of: {risk_score}",
                language=None,
            )
            container.write(
                f"The Huggingface model **{'did not' if is_valid else 'did'}** predict a prompt injection."
            )
        container.write(
            "*Step 2:* If the user input is valid, the following prompt is executed and the response is shown to the user:"
        )
        show_prompts(container, prompts)

    elif level == "preflight_prompt":
        container.write(
            "*Step 1:* The following pre-flight prompt is executed to see if the user input changes the expected output:"
        )
        preflight_prompt = llm.get_full_prompt(
            llm.secondary_llm_call[level],
            {"user_input": user_input_holder},
        )
        container.code(preflight_prompt["user_prompt"], language=None)
        container.write("The output of the pre-flight prompt is:")
        if intermediate_output is None:
            container.warning("Please submit a prompt first.")
        else:
            is_valid, output_raw = intermediate_output
            container.code(output_raw, language=None)
            container.write(
                f"The output of the pre-flight prompt **{'was' if is_valid else 'was not'}** as expected."
            )
        container.write(
            "*Step 2:* If the output of the pre-flight prompt is as expected, the following prompt is executed and the response is shown to the user:"
        )
        show_prompts(container, prompts)

    else:
        container.write(
            "*Step 1:* The following prompt is executed and the full response is shown to the user:"
        )
        show_prompts(container, prompts)


for idx, level in enumerate(config.LEVELS):
    secret = config.SECRETS[idx]

    # init states
    init_session_state(f"solved_{level}", False)
    init_session_state(f"prompt_try_count_{level}", 0)
    init_session_state(f"secret_guess_count_{level}", 0)
    init_session_state(f"intermediate_output_holder_{level}", None)
    # last LLM response, read by the hint of the "llm_judge_output" level
    init_session_state(f"llm_output_{level}", None)
    init_session_state(f"show_benefits_drawbacks_{level}", False)

    # init hint expander status
    for i in range(NUM_HINTS):
        init_session_state(f"opened_hint_{level}_{i}", False)

    with level_tabs[idx]:
        header_col1, header_col2 = st.columns(2, gap="medium")
        header_col1.subheader(f"{config.LEVEL_EMOJIS[idx]} Level {idx}")
        header_col2.subheader("Need help ...")

        col1, col2 = st.columns(2, gap="medium")

        with col1:
            with st.container(height=600, border=False):
                with st.container(border=True):
                    txt = st.text_area(
                        "Provide your prompt here:",
                        key=f"txt_{level}",
                        label_visibility="visible",
                        height=200,
                        placeholder="Your prompt",
                        max_chars=config.MAX_INPUT_CHARS,
                    )
                    btn_submit_prompt = st.button(
                        "Send prompt", key=f"submit_prompt_{level}"
                    )

                    # Show response
                    if txt and btn_submit_prompt:
                        st.session_state[f"prompt_try_count_{level}"] += 1
                        with st.container(border=True):
                            render_response(level, secret, txt)

                with st.container(border=True):
                    secret_guess = st.text_input(
                        "What is the secret?",
                        key=f"guess_{level}",
                        placeholder="Your guess",
                    )
                    btn_submit_guess = st.button(
                        "Submit guess", key=f"submit_guess_{level}"
                    )

                    if btn_submit_guess:
                        st.session_state[f"secret_guess_count_{level}"] += 1
                        # case-insensitive; surrounding whitespace is ignored
                        if secret_guess.strip().lower() == secret.lower():
                            st.success("You found the secret!")
                            st.session_state[f"solved_{level}"] = True
                        else:
                            st.error("Wrong guess. Try again.")

        with col2:
            with st.container(border=True, height=600):
                st.info(
                    "There are three levels of hints and a full explanation available to you. But be careful, if you open them before solving the secret, it will show up in your record.",
                    icon="ℹ️",
                )

                hint_1_cont = card(color=hint_color)
                hint1 = hint_1_cont.toggle(
                    "Show hint 1 - **Basic description of security strategy**",
                    key=f"hint1_checkbox_{level}",
                )
                if hint1:
                    # revealing a hint marks it as opened, unless the secret
                    # was already found
                    mark_hint_opened(level, 0)
                    hint_1_cont.write(config.LEVEL_DESCRIPTIONS[level]["hint1"])

                hint_2_cont = card(color=hint_color)
                hint2 = hint_2_cont.toggle(
                    "Show hint 2 - **Backend code execution**",
                    key=f"hint2_checkbox_{level}",
                )
                if hint2:
                    mark_hint_opened(level, 1)
                    render_backend_hint(hint_2_cont, level, txt)

                hint_3_cont = card(color=hint_color)
                hint3 = hint_3_cont.toggle(
                    "Show hint 3 - **Prompt solution example**",
                    key=f"hint3_checkbox_{level}",
                )
                if hint3:
                    mark_hint_opened(level, 2)
                    hint_3_cont.code(
                        config.LEVEL_DESCRIPTIONS[level]["hint3"],
                        language=None,
                    )
                    hint_3_cont.info("*May not always work")

                info_cont = card(color=info_color)
                info_toggle = info_cont.toggle(
                    "Show info - **Explanation and real-life usage**",
                    key=f"info_checkbox_{level}",
                )
                if info_toggle:
                    mark_hint_opened(level, 3)
                    descriptions = config.LEVEL_DESCRIPTIONS[level]
                    info_cont.write("### " + descriptions["name"])
                    info_cont.write("##### Explanation")
                    info_cont.write(descriptions["explanation"])
                    info_cont.write("##### Real-life usage")
                    info_cont.write(descriptions["real_life"])
                    df = pd.DataFrame(
                        {
                            "Benefits": [descriptions["benefits"]],
                            "Drawbacks": [descriptions["drawbacks"]],
                        },
                    )
                    info_cont.markdown(
                        df.style.hide(axis="index").to_html(),
                        unsafe_allow_html=True,
                    )


def build_hint_status(level: str) -> str:
    """Return one ``❌ <n>`` marker per opened hint of ``level``.

    Markers are separated by ``<br>`` because the record table is rendered
    as raw HTML.
    """
    return "".join(
        f"❌ {i + 1}<br>"
        for i in range(NUM_HINTS)
        if st.session_state[f"opened_hint_{level}_{i}"]
    )


with st.expander("🏆 Record", expanded=True):
    show_mitigation_toggle = st.toggle(
        "[SPOILER] Show all mitigation techniques with their benefits and drawbacks",
        key="show_mitigation",
    )
    if show_mitigation_toggle:
        st.warning("All mitigation techniques are shown.", icon="🚨")

    # build table
    table_data = []
    for idx, level in enumerate(config.LEVELS):
        if show_mitigation_toggle:
            # revealing all mitigations counts as opening each info section
            mark_hint_opened(level, 3)

        solved = st.session_state[f"solved_{level}"]
        descriptions = config.LEVEL_DESCRIPTIONS[level]
        any_hint_opened = any(
            st.session_state[f"opened_hint_{level}_{i}"] for i in range(NUM_HINTS)
        )
        info_revealed = (
            st.session_state[f"opened_hint_{level}_3"] or show_mitigation_toggle
        )
        table_data.append(
            [
                idx,
                config.LEVEL_EMOJIS[idx],
                st.session_state[f"prompt_try_count_{level}"],
                st.session_state[f"secret_guess_count_{level}"],
                build_hint_status(level),
                "✅" if solved else "❌",
                config.SECRETS[idx] if solved else "...",
                (
                    descriptions["name"]
                    if any_hint_opened or show_mitigation_toggle
                    else "..."
                ),
                descriptions["benefits"] if info_revealed else "...",
                descriptions["drawbacks"] if info_revealed else "...",
            ]
        )

    # show as pandas dataframe, rendered as HTML so <br> line breaks work
    st.markdown(
        pd.DataFrame(
            table_data,
            columns=[
                "lvl",
                "emoji",
                "Prompt tries",
                "Secret guesses",
                "Hint used",
                "Solved",
                "Secret",
                "Mitigation",
                "Benefits",
                "Drawbacks",
            ],
        )
        .style.hide(axis="index")
        .to_html(),
        unsafe_allow_html=True,
    )

# TODOS:
# - mark the user input with color in prompt
# TODO: https://docs.streamlit.io/develop/api-reference/caching-and-state/st.cache_resource