import argparse
import glob
import json
import logging
import multiprocessing as mp
import os
import time
import uuid
from datetime import timedelta
from functools import lru_cache
from typing import List, Union
import aegis
import gradio as gr
import requests
from huggingface_hub import HfApi
from optimum.onnxruntime import ORTModelForSequenceClassification
from rebuff import Rebuff
from transformers import AutoTokenizer, pipeline
# Root logging setup and this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub client (token from HF_TOKEN); execute() uses it to upload results.
hf_api = HfApi(token=os.getenv("HF_TOKEN"))

# Number of worker processes in the pool created by execute(); fixed at 2.
num_processes = 2

# Provider credentials and endpoints, all read from the environment.
# Each value is None when the corresponding variable is unset.
lakera_api_key = os.getenv("LAKERA_API_KEY")
automorphic_api_key = os.getenv("AUTOMORPHIC_API_KEY")
rebuff_api_key = os.getenv("REBUFF_API_KEY")
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
@lru_cache(maxsize=2)
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> pipeline:
    """Load an ONNX sequence-classification model as a CPU text-classification pipeline.

    Args:
        prompt_injection_ort_model: Model id or path handed to ``from_pretrained``
            for both the ONNX model and its tokenizer.
        subfolder: Subfolder inside the model repo that holds the files
            (empty string means the repo root).

    Returns:
        A ``text-classification`` pipeline on CPU that scores one input at a
        time and truncates inputs to 512 tokens.
    """
    shared_kwargs = {"subfolder": subfolder}

    ort_model = ORTModelForSequenceClassification.from_pretrained(
        prompt_injection_ort_model, export=False, **shared_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(prompt_injection_ort_model, **shared_kwargs)
    # Only feed these two tensors to the model.
    # NOTE(review): presumably the ONNX graph has no token_type_ids input — confirm.
    tokenizer.model_input_names = ["input_ids", "attention_mask"]

    logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")

    pipeline_options = dict(
        device="cpu",
        batch_size=1,
        truncation=True,
        max_length=512,
    )
    return pipeline("text-classification", model=ort_model, tokenizer=tokenizer, **pipeline_options)
def convert_elapsed_time(diff_time: Union[float, timedelta]) -> float:
    """Return an elapsed duration in seconds, rounded to two decimal places.

    Args:
        diff_time: Elapsed time, either as a number of seconds (int/float)
            or as a ``datetime.timedelta``.

    Returns:
        The duration in seconds, rounded to 2 decimals.
    """
    if isinstance(diff_time, timedelta):
        diff_time = diff_time.total_seconds()
    # Plain seconds need no timedelta round-trip; round them directly.
    return round(float(diff_time), 2)
# Load the three Hugging Face classifiers once, at import time.

# ONNX version of deepset/deberta-v3-base-injection.
deepset_classifier = init_prompt_injection_model("laiyer/deberta-v3-base-injection-onnx")

# Laiyer's model; its ONNX files are loaded from the "onnx" subfolder.
laiyer_classifier = init_prompt_injection_model(
    "laiyer/deberta-v3-base-prompt-injection",
    "onnx",
)

# ONNX version of fmops/distilbert-prompt-injection.
fmops_classifier = init_prompt_injection_model("laiyer/fmops-distilbert-prompt-injection-onnx")
def detect_hf(
    prompt: str, threshold: float = 0.5, classifier=laiyer_classifier, label: str = "INJECTION"
) -> (bool, bool):
    """Classify *prompt* with a Hugging Face text-classification pipeline.

    Args:
        prompt: Text to classify.
        threshold: The prompt is flagged when its injection score is strictly
            greater than this value.
        classifier: Pipeline to call; defaults to the Laiyer model.
        label: Name of the label the classifier uses for the injection class.
            If the top prediction carries a different label, the injection
            score is taken as ``1 - score`` (only meaningful for a binary
            classifier).

    Returns:
        ``(request_succeeded, is_injection)``; ``(False, False)`` if calling
        the classifier or reading its output raises.
    """
    try:
        pi_result = classifier(prompt)
        top_prediction = pi_result[0]
        if top_prediction["label"] == label:
            injection_score = top_prediction["score"]
        else:
            injection_score = 1 - top_prediction["score"]
    except Exception as err:
        logger.error(f"Failed to call HF model: {err}")
        return False, False

    logger.info(f"Prompt injection result from the HF model: {pi_result}")
    # Compare the unrounded score: rounding first would misclassify scores
    # just above the threshold (e.g. 0.503 rounds to 0.5, which is not > 0.5).
    return True, injection_score > threshold
def detect_hf_laiyer(prompt: str) -> (bool, bool):
    """Score *prompt* with the Laiyer DeBERTa classifier (default threshold and label)."""
    return detect_hf(prompt=prompt, classifier=laiyer_classifier)
def detect_hf_deepset(prompt: str) -> (bool, bool):
    """Score *prompt* with the deepset DeBERTa classifier (default threshold and label)."""
    return detect_hf(prompt=prompt, classifier=deepset_classifier)
def detect_hf_fmops(prompt: str) -> (bool, bool):
    """Score *prompt* with the FMOps DistilBERT classifier.

    This model's injection class is named ``LABEL_1`` rather than ``INJECTION``.
    """
    injection_label = "LABEL_1"
    return detect_hf(prompt=prompt, classifier=fmops_classifier, label=injection_label)
def detect_lakera(prompt: str, timeout: float = 30.0) -> (bool, bool):
    """Query the Lakera Guard prompt-injection endpoint.

    Args:
        prompt: Text to check.
        timeout: Seconds to wait for the HTTP response before giving up.
            Without it, a stalled connection would block the worker forever.

    Returns:
        ``(request_succeeded, is_injection)``. Network errors, non-JSON bodies
        and bodies without ``results[0].flagged`` (e.g. an error payload) all
        yield ``(False, False)``.
    """
    try:
        response = requests.post(
            "https://api.lakera.ai/v1/prompt_injection",
            json={"input": prompt},
            headers={"Authorization": f"Bearer {lakera_api_key}"},
            timeout=timeout,
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Lakera: {response_json}")
        return True, response_json["results"][0]["flagged"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Lakera API: {err}")
    except (ValueError, KeyError, IndexError, TypeError) as err:
        # Malformed or error-shaped response body; report as a failed request
        # instead of letting the exception abort the whole benchmark run.
        logger.error(f"Unexpected response from Lakera API: {err!r}")
    return False, False
def detect_automorphic(prompt: str) -> (bool, bool):
    """Check *prompt* with Automorphic Aegis' ingress detector.

    Returns:
        ``(request_succeeded, is_injection)``; ``(False, False)`` if creating
        the client, calling it, or reading its result raises.
    """
    try:
        # Client construction is inside the try so a bad or missing API key is
        # reported as a failed request rather than crashing the worker.
        ag = aegis.Aegis(automorphic_api_key)
        # NOTE(review): the second argument is passed as an empty string; its
        # meaning is defined by the aegis SDK — confirm.
        ingress_attack_detected = ag.ingress(prompt, "")
        logger.info(f"Prompt injection result from Automorphic: {ingress_attack_detected}")
        return True, ingress_attack_detected["detected"]
    except Exception as err:
        logger.error(f"Failed to call Automorphic API: {err}")
        return False, False  # Assume it's not attack
def detect_rebuff(prompt: str) -> (bool, bool):
    """Ask the hosted Rebuff service whether *prompt* is a prompt injection.

    Returns:
        ``(request_succeeded, is_injection)``; ``(False, False)`` if creating
        the client, running detection, or reading the verdict raises.
    """
    try:
        client = Rebuff(api_token=rebuff_api_key, api_url="https://www.rebuff.ai")
        detection = client.detect_injection(prompt)
        logger.info(f"Prompt injection result from Rebuff: {detection}")
        flagged = detection.injectionDetected
    except Exception as err:
        logger.error(f"Failed to call Rebuff API: {err}")
        return False, False
    return True, flagged
def detect_azure(prompt: str, timeout: float = 30.0) -> (bool, bool):
    """Query Azure AI Content Safety's jailbreak-detection endpoint.

    Args:
        prompt: Text to check.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        ``(request_succeeded, is_injection)``. Network errors, non-JSON bodies
        and bodies without ``jailbreakAnalysis.detected`` yield
        ``(False, False)``.
    """
    try:
        # Normalise the configured endpoint so it works with or without a
        # trailing slash (an unset endpoint still fails as a RequestException).
        base_url = str(azure_content_safety_endpoint or "").rstrip("/")
        response = requests.post(
            f"{base_url}/contentsafety/text:detectJailbreak?api-version=2023-10-15-preview",
            json={"text": prompt},
            headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
            timeout=timeout,
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Azure: {response_json}")
        if "jailbreakAnalysis" not in response_json:
            return False, False
        return True, response_json["jailbreakAnalysis"]["detected"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Azure API: {err}")
    except (ValueError, KeyError, TypeError) as err:
        # Malformed response body; report as a failed request instead of
        # letting the exception abort the whole benchmark run.
        logger.error(f"Unexpected response from Azure API: {err!r}")
    return False, False
# Registry of benchmarked providers: display name -> detector callable.
# Every detector takes the prompt and returns (request_succeeded, is_injection).
# Insertion order is the order in which providers are run and reported.
# detect_rebuff is defined but currently not registered (disabled).
detection_providers = {
    # Local Hugging Face models
    "Laiyer (HF model)": detect_hf_laiyer,
    "Deepset (HF model)": detect_hf_deepset,
    "FMOps (HF model)": detect_hf_fmops,
    # Remote APIs
    "Lakera Guard": detect_lakera,
    "Automorphic Aegis": detect_automorphic,
    "Azure Content Safety": detect_azure,
}
def is_detected(provider: str, prompt: str) -> (str, bool, bool, float):
    """Run a single registered provider on *prompt* and time the call.

    Args:
        provider: Key into ``detection_providers``.
        prompt: Text to check.

    Returns:
        ``(provider, request_succeeded, is_injection, latency_seconds)``.
        Unknown providers yield ``(provider, False, False, 0.0)``. A provider
        that raises is reported as a failed request together with the time
        spent before it failed.
    """
    if provider not in detection_providers:
        logger.warning(f"Provider {provider} is not supported")
        # Same 4-field shape as the normal path so result rows stay aligned
        # with the four table columns.
        return provider, False, False, 0.0

    start_time = time.monotonic()
    try:
        request_result, is_injection = detection_providers[provider](prompt)
    except Exception:
        # One misbehaving provider must not abort the whole parallel run.
        logger.exception(f"Provider {provider} raised an unexpected error")
        request_result, is_injection = False, False
    end_time = time.monotonic()
    return provider, request_result, is_injection, convert_elapsed_time(end_time - start_time)
def execute(prompt: str) -> List[Union[str, bool, float]]:
    """Run every registered provider on *prompt* in parallel and store the results.

    Providers are evaluated via ``is_detected`` in a process pool of
    ``num_processes`` workers. The prompt and all results are then uploaded
    as a JSON file to the ``laiyer/prompt-injection-benchmark`` dataset. If
    the upload fails, the error is logged and the results are still returned.

    Args:
        prompt: Text to benchmark.

    Returns:
        One ``(provider, request_succeeded, is_injection, latency_seconds)``
        row per provider, in ``detection_providers`` order.
    """
    tasks = [(provider, prompt) for provider in detection_providers]
    with mp.Pool(processes=num_processes) as pool:
        # starmap already returns a list, in the same order as `tasks`.
        results = pool.starmap(is_detected, tasks)

    # Persist the prompt together with every provider's result.
    fileobj = json.dumps(
        {"prompt": prompt, "results": results}, indent=2, ensure_ascii=False
    ).encode("utf-8")
    result_path = f"/prompts/train/{uuid.uuid4()}.json"
    try:
        hf_api.upload_file(
            path_or_fileobj=fileobj,
            path_in_repo=result_path,
            repo_id="laiyer/prompt-injection-benchmark",
            repo_type="dataset",
        )
    except Exception:
        # Storage is secondary to showing the user the benchmark results.
        logger.exception("Failed to store prompt results on the Hugging Face Hub")
    else:
        logger.info(f"Stored prompt: {prompt}")
    return results
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    args, left_argv = parser.parse_known_args()

    # Each examples/*.txt file holds one example prompt. Sort them for a
    # stable order (glob order is arbitrary) and close each file after reading.
    example_files = sorted(
        glob.glob(os.path.join(os.path.dirname(__file__), "examples", "*.txt"))
    )
    examples = []
    for example_file in example_files:
        with open(example_file, encoding="utf-8") as fh:
            examples.append(fh.read())

    gr.Interface(
        fn=execute,
        inputs=[
            gr.Textbox(label="Prompt"),
        ],
        outputs=[
            gr.Dataframe(
                headers=[
                    "Provider",
                    "Is processed successfully?",
                    "Is prompt injection?",
                    "Latency (seconds)",
                ],
                datatype=["str", "bool", "bool", "number"],
                label="Results",
            ),
        ],
        title="Prompt Injection Solutions Benchmark",
        # Adjacent literals concatenate into one string. The description
        # accepts HTML, so line breaks are written as <br /> instead of raw
        # newlines, which are not allowed inside a regular string literal.
        description=(
            "This interface aims to benchmark the known prompt injection detection providers. "
            "The results are stored in the private dataset for further analysis and improvements. "
            "This interface is for research purposes only."
            "<br /><br />"
            "HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs."
            "<br /><br />"
            "Join our Slack community to discuss LLM Security"
            "<br />"
            "Secure your LLM interactions with LLM Guard"
        ),
        # One value per input component; the interface has a single Prompt textbox.
        examples=[[example] for example in examples],
        # NOTE(review): caching presumably runs execute() once per example,
        # which calls every provider and uploads each result — confirm intended.
        cache_examples=True,
        allow_flagging="never",
        concurrency_limit=1,
    ).launch(server_name=args.url, server_port=args.port)