|
import argparse |
|
import glob |
|
import json |
|
import logging |
|
import multiprocessing as mp |
|
import os |
|
import time |
|
import uuid |
|
from datetime import timedelta |
|
from functools import lru_cache |
|
from typing import List, Union |
|
|
|
import aegis |
|
import gradio as gr |
|
import requests |
|
from huggingface_hub import HfApi |
|
from optimum.onnxruntime import ORTModelForSequenceClassification |
|
from rebuff import Rebuff |
|
from transformers import AutoTokenizer, pipeline |
|
|
|
# Emit INFO-level records so each provider's raw result shows up in the logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hub client used by ``execute`` to upload each benchmark run.
hf_api = HfApi(token=os.getenv("HF_TOKEN"))

# Worker count for the pool that fans the providers out in ``execute``.
num_processes = 2

# Credentials/endpoints for the hosted detection services.
# ``os.getenv`` yields None for any variable that is not set.
lakera_api_key = os.getenv("LAKERA_API_KEY")
automorphic_api_key = os.getenv("AUTOMORPHIC_API_KEY")
rebuff_api_key = os.getenv("REBUFF_API_KEY")
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
|
|
|
|
|
@lru_cache(maxsize=2)
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> pipeline:
    """Load an ONNX sequence-classification model and wrap it in a CPU pipeline.

    Args:
        prompt_injection_ort_model: Hugging Face Hub repository id to load the
            ONNX model and tokenizer from.
        subfolder: Sub-directory of that repository holding the files; the empty
            string means the repository root.

    Returns:
        A ``text-classification`` pipeline on CPU with batch size 1 that truncates
        inputs to 512 tokens.

    NOTE(review): the cache keeps at most two entries, but this module loads three
    distinct models at import time, so the cache never serves all of them.
    """
    onnx_model = ORTModelForSequenceClassification.from_pretrained(
        prompt_injection_ort_model,
        export=False,
        subfolder=subfolder,
    )

    tokenizer = AutoTokenizer.from_pretrained(prompt_injection_ort_model, subfolder=subfolder)
    # Restrict the tokenizer's outputs to ids + attention mask; presumably the ONNX
    # graphs do not accept other inputs such as token_type_ids — TODO confirm.
    tokenizer.model_input_names = ["input_ids", "attention_mask"]

    logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")

    classifier = pipeline(
        task="text-classification",
        model=onnx_model,
        tokenizer=tokenizer,
        device="cpu",
        batch_size=1,
        truncation=True,
        max_length=512,
    )
    return classifier
|
|
|
|
|
def convert_elapsed_time(diff_time) -> float:
    """Return a duration given in seconds, rounded to two decimal places.

    The value goes through ``timedelta``, which normalizes it to whole
    microseconds before the final rounding.
    """
    elapsed = timedelta(seconds=diff_time)
    seconds = elapsed.total_seconds()
    return round(seconds, 2)
|
|
|
|
|
# The three open-source classifiers are loaded once, at import time; the
# detect_hf_* wrappers each run one of these pipelines.
deepset_classifier = init_prompt_injection_model("laiyer/deberta-v3-base-injection-onnx")
laiyer_classifier = init_prompt_injection_model(
    "laiyer/deberta-v3-base-prompt-injection",
    "onnx",  # this repository keeps its model files under the "onnx" sub-folder
)
fmops_classifier = init_prompt_injection_model("laiyer/fmops-distilbert-prompt-injection-onnx")
|
|
|
|
|
def detect_hf(
    prompt: str, threshold: float = 0.5, classifier=laiyer_classifier, label: str = "INJECTION"
) -> (bool, bool):
    """Classify *prompt* with a local Hugging Face text-classification pipeline.

    Args:
        prompt: Text to classify.
        threshold: Injection probability strictly above which the prompt is flagged.
        classifier: Pipeline to run; defaults to the Laiyer DeBERTa model.
        label: Label the classifier uses for the injection class. When the top
            prediction carries a different label its score is complemented
            (``1 - score``), which assumes a binary classifier.

    Returns:
        ``(processed, is_injection)``; ``(False, False)`` if the classifier raised.
    """
    try:
        pi_result = classifier(prompt)
        logger.info(f"Prompt injection result from the HF model: {pi_result}")

        top_prediction = pi_result[0]
        # Compare the unrounded probability: rounding to two decimals first would
        # shift the decision boundary (e.g. 0.504 -> 0.50 would not be flagged).
        injection_score = (
            top_prediction["score"]
            if top_prediction["label"] == label
            else 1 - top_prediction["score"]
        )

        return True, injection_score > threshold
    except Exception as err:
        logger.error(f"Failed to call HF model: {err}")
        return False, False
|
|
|
|
|
def detect_hf_laiyer(prompt: str) -> (bool, bool):
    """Run the Laiyer DeBERTa-v3 ONNX classifier on *prompt* via ``detect_hf``."""
    verdict = detect_hf(prompt=prompt, classifier=laiyer_classifier)
    return verdict
|
|
|
|
|
def detect_hf_deepset(prompt: str) -> (bool, bool):
    """Run the Deepset-derived DeBERTa-v3 ONNX classifier on *prompt* via ``detect_hf``."""
    verdict = detect_hf(prompt=prompt, classifier=deepset_classifier)
    return verdict
|
|
|
|
|
def detect_hf_fmops(prompt: str) -> (bool, bool):
    """Run the FMOps DistilBERT ONNX classifier on *prompt* via ``detect_hf``.

    This model reports generic label names, so ``LABEL_1`` is treated as the
    injection class.
    """
    verdict = detect_hf(prompt=prompt, label="LABEL_1", classifier=fmops_classifier)
    return verdict
|
|
|
|
|
def detect_lakera(prompt: str, timeout: float = 30.0) -> (bool, bool):
    """Ask the Lakera Guard prompt-injection API about *prompt*.

    Args:
        prompt: Text to classify.
        timeout: Seconds to wait for the HTTP response. Without a timeout
            ``requests`` can block indefinitely and stall the whole benchmark.

    Returns:
        ``(processed, flagged)``. ``processed`` is False when the request fails or
        the response does not have the expected ``results[0].flagged`` shape;
        ``flagged`` is then False as well.
    """
    try:
        response = requests.post(
            "https://api.lakera.ai/v1/prompt_injection",
            json={"input": prompt},
            headers={"Authorization": f"Bearer {lakera_api_key}"},
            timeout=timeout,
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Lakera: {response_json}")

        return True, response_json["results"][0]["flagged"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Lakera API: {err}")
        return False, False
    except (ValueError, KeyError, IndexError, TypeError) as err:
        # Non-JSON body (older requests raise a plain ValueError) or an error
        # payload without "results"; report it instead of aborting the run.
        logger.error(f"Unexpected response from Lakera API: {err}")
        return False, False
|
|
|
|
|
def detect_automorphic(prompt: str) -> (bool, bool):
    """Check *prompt* with Automorphic Aegis' ingress scan.

    Returns:
        ``(processed, detected)``; ``(False, False)`` if creating the client or
        calling the API raises.
    """
    try:
        # The client is created inside the ``try`` so that a failing constructor is
        # reported as an unprocessed result, like any other API failure, rather
        # than propagating out and aborting the whole benchmark run.
        ag = aegis.Aegis(automorphic_api_key)
        ingress_attack_detected = ag.ingress(prompt, "")
        logger.info(f"Prompt injection result from Automorphic: {ingress_attack_detected}")

        return True, ingress_attack_detected["detected"]
    except Exception as err:
        logger.error(f"Failed to call Automorphic API: {err}")
        return False, False
|
|
|
|
|
def detect_rebuff(prompt: str) -> (bool, bool):
    """Check *prompt* with the hosted Rebuff service.

    This detector is not registered in ``detection_providers``, so the benchmark
    UI does not call it.

    Returns:
        ``(processed, injection_detected)``; ``(False, False)`` on any exception.
    """
    try:
        client = Rebuff(api_token=rebuff_api_key, api_url="https://www.rebuff.ai")
        detection = client.detect_injection(prompt)
        logger.info(f"Prompt injection result from Rebuff: {detection}")

        injection_found = detection.injectionDetected
        return True, injection_found
    except Exception as err:
        logger.error(f"Failed to call Rebuff API: {err}")
        return False, False
|
|
|
|
|
def detect_azure(prompt: str, timeout: float = 30.0) -> (bool, bool):
    """Ask Azure AI Content Safety's jailbreak-detection endpoint about *prompt*.

    Args:
        prompt: Text to classify.
        timeout: Seconds to wait for the HTTP response. Without a timeout
            ``requests`` can block indefinitely and stall the whole benchmark.

    Returns:
        ``(processed, detected)``. ``processed`` is False when the endpoint is not
        configured, the request fails, or the response lacks the expected
        ``jailbreakAnalysis.detected`` value; ``detected`` is then False as well.
    """
    if not azure_content_safety_endpoint:
        logger.error("Failed to call Azure API: AZURE_CONTENT_SAFETY_ENDPOINT is not set")
        return False, False

    # Normalize the separator so the endpoint works with or without a trailing slash.
    url = (
        f"{azure_content_safety_endpoint.rstrip('/')}/"
        "contentsafety/text:detectJailbreak?api-version=2023-10-15-preview"
    )

    try:
        response = requests.post(
            url,
            json={"text": prompt},
            headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
            timeout=timeout,
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Azure: {response_json}")

        # Responses without this key (presumably error payloads) count as unprocessed.
        if "jailbreakAnalysis" not in response_json:
            return False, False

        return True, response_json["jailbreakAnalysis"]["detected"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Azure API: {err}")
        return False, False
    except (ValueError, KeyError, TypeError) as err:
        # Non-JSON body (older requests raise a plain ValueError) or an unexpected
        # JSON shape; report it instead of aborting the benchmark run.
        logger.error(f"Unexpected response from Azure API: {err}")
        return False, False
|
|
|
|
|
# Display name -> detector callable returning ``(processed, is_injection)``.
# ``execute`` runs them in this insertion order, which is also the row order of
# the results table. ``detect_rebuff`` is defined but not registered here.
detection_providers = dict(
    (
        ("Laiyer (HF model)", detect_hf_laiyer),
        ("Deepset (HF model)", detect_hf_deepset),
        ("FMOps (HF model)", detect_hf_fmops),
        ("Lakera Guard", detect_lakera),
        ("Automorphic Aegis", detect_automorphic),
        ("Azure Content Safety", detect_azure),
    )
)
|
|
|
|
|
def is_detected(provider: str, prompt: str) -> (str, bool, bool, float):
    """Run one provider on *prompt* and time it.

    Args:
        provider: Key into ``detection_providers``.
        prompt: Text to classify.

    Returns:
        A results-table row ``(provider, processed, is_injection, latency_seconds)``,
        with latency rounded to two decimals. Unknown providers and providers that
        raise produce a row with ``processed=False`` so every row has four columns.
    """
    if provider not in detection_providers:
        logger.warning(f"Provider {provider} is not supported")
        # Same 4-column shape as a normal row; a 2-tuple would break the table.
        return provider, False, False, 0.0

    start_time = time.monotonic()
    try:
        request_result, is_injection = detection_providers[provider](prompt)
    except Exception:
        # Per-provider boundary: one misbehaving detector must not abort the whole run.
        logger.exception(f"Provider {provider} failed")
        request_result, is_injection = False, False
    end_time = time.monotonic()

    return provider, request_result, is_injection, convert_elapsed_time(end_time - start_time)
|
|
|
|
|
def execute(prompt: str) -> List[Union[str, bool, float]]:
    """Run every registered provider on *prompt*, store the results, and return them.

    The prompt and its result rows are uploaded as a JSON file under
    ``prompts/train/`` in the ``laiyer/prompt-injection-benchmark`` dataset.
    An upload failure is logged, and the computed rows are still returned to
    the UI.

    Returns:
        One ``(provider, processed, is_injection, latency_seconds)`` row per
        provider, in ``detection_providers`` order.
    """
    # Local import: ``import multiprocessing`` alone does not load ``multiprocessing.pool``.
    from multiprocessing.pool import ThreadPool

    # Threads instead of a per-request process pool. The providers are remote API
    # calls or ONNX inference (which releases the GIL), so threads still overlap them.
    # Forking the multi-threaded web server can deadlock, and spawn/forkserver
    # workers would re-import this module and reload every model on each request.
    with ThreadPool(processes=num_processes) as pool:
        results = pool.starmap(
            is_detected, [(provider, prompt) for provider in detection_providers]
        )

    fileobj = json.dumps(
        {"prompt": prompt, "results": results}, indent=2, ensure_ascii=False
    ).encode("utf-8")
    result_path = f"/prompts/train/{uuid.uuid4()}.json"

    try:
        hf_api.upload_file(
            path_or_fileobj=fileobj,
            path_in_repo=result_path,
            repo_id="laiyer/prompt-injection-benchmark",
            repo_type="dataset",
        )
        logger.info(f"Stored prompt: {prompt}")
    except Exception as err:
        # Storage is best-effort; don't throw away results the user is waiting for.
        logger.error(f"Failed to store prompt results: {err}")

    return results
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    args, _ = parser.parse_known_args()

    # Sorted so example order is stable across runs (glob order is arbitrary).
    example_files = sorted(
        glob.glob(os.path.join(os.path.dirname(__file__), "examples", "*.txt"))
    )
    examples = []
    for example_file in example_files:
        # Context manager closes each file; explicit encoding avoids locale dependence.
        with open(example_file, encoding="utf-8") as f:
            examples.append(f.read())

    gr.Interface(
        fn=execute,
        inputs=[
            gr.Textbox(label="Prompt"),
        ],
        outputs=[
            gr.Dataframe(
                headers=[
                    "Provider",
                    "Is processed successfully?",
                    "Is prompt injection?",
                    "Latency (seconds)",
                ],
                datatype=["str", "bool", "bool", "number"],
                label="Results",
            ),
        ],
        title="Prompt Injection Solutions Benchmark",
        description="This interface aims to benchmark the known prompt injection detection providers. "
        "The results are <strong>stored in the private dataset</strong> for further analysis and improvements. This interface is for research purposes only."
        "<br /><br />"
        "HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs.<br /><br />"
        "<a href=\"https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w\">Join our Slack community to discuss LLM Security</a><br />"
        "<a href=\"https://github.com/laiyer-ai/llm-guard\">Secure your LLM interactions with LLM Guard</a>",
        # One value per example, matching the single Textbox input.
        examples=[[example] for example in examples],
        cache_examples=True,
        allow_flagging="never",
        concurrency_limit=1,
    ).launch(server_name=args.url, server_port=args.port)
|
|