"""Demo UI to show different levels of LLM security."""
import re
from typing import Any
import pandas as pd
from llm_guard.input_scanners import PromptInjection
import streamlit as st
import config
import utils
import llm
from card import card
hint_color = "rgba(225, 166, 28, 0.1)"
info_color = "rgba(54, 225, 28, 0.1)"
# init page
st.set_page_config(
page_title="Secret agent LLM challenge",
layout="wide",
initial_sidebar_state="expanded",
)
st.logo("images/ML6_logo.png")
st.title("πŸ•΅οΈ Secret agent LLM challenge")
st.info(
"You are a secret agent meeting your informant in a bar. Convince him to give you his secret! But be prepared, with every new level the informant will be more cautious.",
icon="πŸ“–",
)
# create a tab for each level
level_tabs = st.tabs([f"Level {i}" for i in range(len(config.LEVELS))])
def init_session_state(state_level: str, default_value: Any):
    """Initialize a session-state key with a default value if it is not set yet."""
    if state_level not in st.session_state:
        st.session_state[state_level] = default_value
for idx, level in enumerate(config.LEVELS):
secret = config.SECRETS[idx]
# init states
init_session_state(f"solved_{level}", False)
init_session_state(f"prompt_try_count_{level}", 0)
init_session_state(f"secret_guess_count_{level}", 0)
init_session_state(f"intermediate_output_holder_{level}", None)
init_session_state(f"show_benefits_drawbacks_{level}", False)
# init hint expander status
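    # one flag per hint (three hints plus the full explanation); these feed the record table below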
for i in range(4):
init_session_state(f"opened_hint_{level}_{i}", False)
with level_tabs[idx]:
header_col1, header_col2 = st.columns(2, gap="medium")
header_col1.subheader(f"{config.LEVEL_EMOJIS[idx]} Level {idx}")
header_col2.subheader(f"Need help ...")
col1, col2 = st.columns(2, gap="medium")
with col1:
with st.container(height=600, border=False):
with st.container(border=True):
txt = st.text_area(
"Provide your prompt here:",
key=f"txt_{level}",
label_visibility="visible",
height=200,
placeholder="Your prompt",
max_chars=config.MAX_INPUT_CHARS,
)
btn_submit_prompt = st.button(
"Send prompt", key=f"submit_prompt_{level}"
)
output = None
# Show response
if len(txt) > 0 and btn_submit_prompt:
st.session_state[f"prompt_try_count_{level}"] += 1
with st.container(border=True):
st.write("Response:")
# special checks for certain levels
if level == "llm_judge_input":
invalid, output_raw = llm.run_judge(
level, {"user_input": txt}
)
# save to state, so it can be shown in hint
st.session_state[
f"intermediate_output_holder_{level}"
] = (invalid, output_raw)
if invalid:
st.error(
"Malicious user input detected. Try again."
)
else:
st.write_stream(
llm.stream_request(level, secret, txt)
)
elif level == "basic_check_secret_in_output":
output = "".join(llm.stream_request(level, secret, txt))
invalid = secret.lower() in output.lower()
st.session_state[
f"intermediate_output_holder_{level}"
] = invalid
if invalid:
st.error(
"Secret detected in user input. Try again."
)
else:
st.write(output)
elif level == "advanced_check_secret_in_output":
output = "".join(llm.stream_request(level, secret, txt))
invalid = utils.is_subsequence(output, secret)
st.session_state[
f"intermediate_output_holder_{level}"
] = invalid
if invalid:
st.error(
"Secret detected in user input. Try again."
)
else:
st.write(output)
elif level == "llm_judge_output":
output = "".join(llm.stream_request(level, secret, txt))
invalid, output_raw = llm.run_judge(
level, {"secret": secret, "message": output}
)
st.session_state[
f"intermediate_output_holder_{level}"
] = (invalid, output_raw)
if invalid:
st.error(
"Secret detected in user input. Try again."
)
else:
st.write(output)
elif level == "chain_of_thought":
output = "".join(llm.stream_request(level, secret, txt))
                            # extract only the answer from the LLM output, leaving out the reasoning
                            match = re.search(r"<ANSWER>(.*?)</ANSWER>", output, re.DOTALL)
                            # fall back to the raw output if the model omitted the answer tags
                            new_output = match.group(1) if match else output
                            st.write(new_output)
st.session_state[
f"intermediate_output_holder_{level}"
] = output
elif level == "guard_framework":
                            # Mitigation: LLM Guard's PromptInjection scanner (a Hugging Face
                            # classifier) scores the user input for injection attempts.
is_valid, risk_score = utils.is_malicious(txt)
st.session_state[
f"intermediate_output_holder_{level}"
] = (is_valid, risk_score)
if not is_valid:
st.error(
"Malicious user input detected. Try again."
)
else:
st.write_stream(
llm.stream_request(level, secret, txt)
)
elif level == "preflight_prompt":
valid, output_raw = llm.run_judge(
level, {"user_input": txt}, expected_output="dog"
)
st.session_state[
f"intermediate_output_holder_{level}"
] = (valid, output_raw)
if valid:
st.write_stream(
llm.stream_request(level, secret, txt)
)
else:
st.error(
"Malicious user input detected. Try again."
)
else:
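                            # no mitigation: the prompt is sent straight to the model (baseline level)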
st.write_stream(llm.stream_request(level, secret, txt))
with st.container(border=True):
secret_guess = st.text_input(
"What is the secret?",
key=f"guess_{level}",
placeholder="Your guess",
)
btn_submit_guess = st.button(
"Submit guess", key=f"submit_guess_{level}"
)
if btn_submit_guess:
st.session_state[f"secret_guess_count_{level}"] += 1
if secret_guess.lower() == secret.lower():
st.success("You found the secret!")
st.session_state[f"solved_{level}"] = True
else:
st.error("Wrong guess. Try again.")
with col2:
with st.container(border=True, height=600):
st.info(
"There are three levels of hints and a full explanation available to you. But be careful, if you open them before solving the secret, it will show up in your record.",
icon="ℹ️",
)
hint_1_cont = card(color=hint_color)
hint1 = hint_1_cont.toggle(
"Show hint 1 - **Basic description of security strategy**",
key=f"hint1_checkbox_{level}",
)
if hint1:
                # once revealed, the hint counts as opened, unless the secret was already found
                st.session_state[f"opened_hint_{level}_0"] = (
                    st.session_state[f"opened_hint_{level}_0"]
                    or not st.session_state[f"solved_{level}"]
                )
hint_1_cont.write(config.LEVEL_DESCRIPTIONS[level]["hint1"])
hint_2_cont = card(color=hint_color)
hint2 = hint_2_cont.toggle(
"Show hint 2 - **Backend code execution**",
key=f"hint2_checkbox_{level}",
)
if hint2:
st.session_state[f"opened_hint_{level}_1"] = (
True
if st.session_state[f"opened_hint_{level}_1"]
else not st.session_state[f"solved_{level}"]
)
user_input_holder = txt if len(txt) > 0 else None
prompts = llm.get_full_prompt(
level, {"user_input": user_input_holder}
)
def show_base_prompt():
                    # render each prompt component with a readable label
for key, val in prompts.items():
desc = key.replace("_", " ").capitalize()
hint_2_cont.write(f"*{desc}:*")
hint_2_cont.code(val, language=None)
if level == "llm_judge_input":
special_prompt = llm.get_full_prompt(
llm.secondary_llm_call[level],
{"user_input": user_input_holder},
)
hint_2_cont.write(
"*Step 1:* A **LLM judge** reviews the user input and determines if it is malicious or not."
)
hint_2_cont.write("**LLM judge prompt:**")
for key, val in special_prompt.items():
hint_2_cont.code(val, language=None)
hint_2_cont.write("The response of the LLM judge:")
intermediate_output = st.session_state[
f"intermediate_output_holder_{level}"
]
if intermediate_output is None:
hint_2_cont.warning("Please submit a prompt first.")
else:
invalid, output_raw = intermediate_output
hint_2_cont.code(output_raw, language=None)
hint_2_cont.write(
f"The prompt was determined as **{'malicious' if invalid else 'not malicious'}** and therefor step 2 is executed."
)
hint_2_cont.write(
"*Step 2:* If the user input is not classified as malicious, the prompt containing the actual secret is executed and the response is shown."
)
hint_2_cont.write("**Actual prompt:**")
show_base_prompt()
elif level == "basic_check_secret_in_output":
hint_2_cont.write("*Step 1:* The following prompt is executed:")
show_base_prompt()
hint_2_cont.write(
"*Step 2:* In the response text of the LLM call, the secret is searched for with a simple python expression . If it contains the secret, the output will not be shown."
)
intermediate_output = st.session_state[
f"intermediate_output_holder_{level}"
]
hint_2_cont.write("The code execution looks like this:")
if intermediate_output is not None:
hint_2_cont.code(
f"secret.lower() in output.lower() = {intermediate_output}"
)
hint_2_cont.write(
f"And the output **{'was not' if intermediate_output else 'was'}** shown to the user."
)
else:
hint_2_cont.warning("Please submit a prompt first.")
elif level == "advanced_check_secret_in_output":
hint_2_cont.write("*Step 1:* The following prompt is executed:")
show_base_prompt()
hint_2_cont.write(
"*Step 2:* In the response text of the LLM call, the secret is searched for with a python function `is_subsequence` which looks for substrings. If it contains the secret, the output will not be shown."
)
with hint_2_cont:
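                            # Streamlit "magic": a bare reference renders the function's signature and docs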
utils.is_subsequence
hint_2_cont.write("The code execution looks like this:")
intermediate_output = st.session_state[
f"intermediate_output_holder_{level}"
]
if intermediate_output is not None:
hint_2_cont.code(
f"utils.is_subsequence(output, secret) = {intermediate_output}"
)
hint_2_cont.write(
f"And the output **{'was not' if intermediate_output else 'was'}** shown to the user."
)
else:
hint_2_cont.warning("Please submit a prompt first.")
elif level == "llm_judge_output":
hint_2_cont.write("*Step 1:* The following prompt is executed:")
show_base_prompt()
hint_2_cont.write(
"*Step 2:* The response of the LLM call is checked by a **LLM judge**. The judge checks if the secret is hidden in the response."
)
special_prompt = llm.get_full_prompt(
llm.secondary_llm_call[level],
{"message": output},
)
for key, val in special_prompt.items():
hint_2_cont.code(val, language=None)
hint_2_cont.write("The response of the LLM judge:")
intermediate_output = st.session_state[
f"intermediate_output_holder_{level}"
]
if intermediate_output is None:
hint_2_cont.warning("Please submit a prompt first.")
else:
invalid, output_raw = intermediate_output
hint_2_cont.code(output_raw, language=None)
hint_2_cont.write(
f"The LLM-judge **{'did' if invalid else 'did not'}** find the secret in the answer."
)
elif level == "chain_of_thought":
hint_2_cont.write(
"*Step 1:* The following prompt with Chain-of-thought reasoning is executed. But only the finale answer is displayed to the user:"
)
show_base_prompt()
hint_2_cont.write(
"The full model output, including the reasoning:"
)
intermediate_output = st.session_state[
f"intermediate_output_holder_{level}"
]
if intermediate_output is None:
hint_2_cont.warning("Please submit a prompt first.")
else:
hint_2_cont.code(intermediate_output, language=None)
elif level == "guard_framework":
hint_2_cont.write(
"*Step 1:* The user input is reviewed with the pre-build framework `LLM Guard` to check for prompt injections. It uses a [Huggingface model](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2) specialized in detecting prompt injections."
)
with hint_2_cont:
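                            # Streamlit "magic": renders the PromptInjection scanner's signature and docs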
PromptInjection
hint_2_cont.write("The output of the guard looks like this:")
intermediate_output = st.session_state[
f"intermediate_output_holder_{level}"
]
if intermediate_output is None:
hint_2_cont.warning("Please submit a prompt first.")
else:
is_valid, risk_score = intermediate_output
hint_2_cont.code(
f"""
Prompt is valid: {is_valid}
Prompt risk score: {risk_score}""",
language=None,
)
hint_2_cont.write(
f"The Huggingface model **{'did not' if is_valid else 'did'}** predict a prompt injection."
)
hint_2_cont.write(
"*Step 2:* If the user input is valid, the following prompt is executed and the response is shown to the user:"
)
show_base_prompt()
elif level == "preflight_prompt":
hint_2_cont.write(
"*Step 1:* The following pre-flight prompt is executed to see if the user input changes the expected output:"
)
special_prompt = llm.get_full_prompt(
llm.secondary_llm_call[level],
{"user_input": user_input_holder},
)
hint_2_cont.code(special_prompt["user_prompt"], language=None)
hint_2_cont.write("The output of the pre-flight prompt is:")
intermediate_output = st.session_state[
f"intermediate_output_holder_{level}"
]
if intermediate_output is None:
hint_2_cont.warning("Please submit a prompt first.")
else:
is_valid, output_raw = intermediate_output
hint_2_cont.code(output_raw, language=None)
hint_2_cont.write(
f"The output of the pre-flight prompt **{'was' if is_valid else 'was not'}** as expected."
)
hint_2_cont.write(
"*Step 2:* If the output of the pre-flight prompt is as expected, the following prompt is executed and the response is shown to the user:"
)
show_base_prompt()
else:
hint_2_cont.write(
"*Step 1:* The following prompt is executed and the full response is shown to the user:"
)
show_base_prompt()
hint_3_cont = card(color=hint_color)
hint3 = hint_3_cont.toggle(
"Show hint 3 - **Prompt solution example**",
key=f"hint3_checkbox_{level}",
)
if hint3:
st.session_state[f"opened_hint_{level}_2"] = (
True
if st.session_state[f"opened_hint_{level}_2"]
else not st.session_state[f"solved_{level}"]
)
hint_3_cont.code(
config.LEVEL_DESCRIPTIONS[level]["hint3"],
language=None,
)
hint_3_cont.info("*May not always work")
info_cont = card(color=info_color)
info_toggle = info_cont.toggle(
"Show info - **Explanation and real-life usage**",
key=f"info_checkbox_{level}",
)
if info_toggle:
st.session_state[f"opened_hint_{level}_3"] = (
True
if st.session_state[f"opened_hint_{level}_3"]
else not st.session_state[f"solved_{level}"]
)
info_cont.write("### " + config.LEVEL_DESCRIPTIONS[level]["name"])
info_cont.write("##### Explanation")
info_cont.write(config.LEVEL_DESCRIPTIONS[level]["explanation"])
info_cont.write("##### Real-life usage")
info_cont.write(config.LEVEL_DESCRIPTIONS[level]["real_life"])
df = pd.DataFrame(
{
"Benefits": [config.LEVEL_DESCRIPTIONS[level]["benefits"]],
"Drawbacks": [
config.LEVEL_DESCRIPTIONS[level]["drawbacks"]
],
},
)
info_cont.markdown(
df.style.hide(axis="index").to_html(), unsafe_allow_html=True
)
def build_hint_status(level: str):
    """Summarize which hints were opened for a level as an HTML snippet."""
    hint_status = ""
for i in range(4):
if st.session_state[f"opened_hint_{level}_{i}"]:
hint_status += f"❌ {i+1}<br>"
return hint_status
with st.expander("πŸ† Record", expanded=True):
show_mitigation_toggle = st.toggle(
"[SPOILER] Show all mitigation techniques with their benefits and drawbacks",
key=f"show_mitigation",
)
if show_mitigation_toggle:
st.warning("All mitigation techniques are shown.", icon="🚨")
# build table
table_data = []
for idx, level in enumerate(config.LEVELS):
if show_mitigation_toggle:
st.session_state[f"opened_hint_{level}_3"] = (
True
if st.session_state[f"opened_hint_{level}_3"]
else not st.session_state[f"solved_{level}"]
)
table_data.append(
[
idx,
config.LEVEL_EMOJIS[idx],
st.session_state[f"prompt_try_count_{level}"],
st.session_state[f"secret_guess_count_{level}"],
build_hint_status(level),
"βœ…" if st.session_state[f"solved_{level}"] else "❌",
config.SECRETS[idx] if st.session_state[f"solved_{level}"] else "...",
(
"<b>" + config.LEVEL_DESCRIPTIONS[level]["name"] + "</b>"
if st.session_state[f"opened_hint_{level}_0"]
or st.session_state[f"opened_hint_{level}_1"]
or st.session_state[f"opened_hint_{level}_2"]
or st.session_state[f"opened_hint_{level}_3"]
or show_mitigation_toggle
else "..."
),
(
config.LEVEL_DESCRIPTIONS[level]["benefits"]
if st.session_state[f"opened_hint_{level}_3"]
or show_mitigation_toggle
else "..."
),
(
config.LEVEL_DESCRIPTIONS[level]["drawbacks"]
if st.session_state[f"opened_hint_{level}_3"]
or show_mitigation_toggle
else "..."
),
]
)
# show as pandas dataframe
st.markdown(
pd.DataFrame(
table_data,
columns=[
"lvl",
"emoji",
"Prompt tries",
"Secret guesses",
"Hint used",
"Solved",
"Secret",
"Mitigation",
"Benefits",
"Drawbacks",
],
# index=config.LEVEL_EMOJIS[: len(config.LEVELS)],
)
.style.hide(axis="index")
.to_html(),
unsafe_allow_html=True,
)
# TODOs:
# - mark the user input with color in prompt
# TODO: https://docs.streamlit.io/develop/api-reference/caching-and-state/st.cache_resource