* add AWS comprehend
Browse files- README.md +2 -1
- app.py +38 -3
- requirements.txt +7 -6
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: π
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
pinned: true
|
9 |
license: apache-2.0
|
10 |
---
|
@@ -35,3 +35,4 @@ gradio app.py
|
|
35 |
- [Rebuff](https://rebuff.ai/)
|
36 |
- [Azure Content Safety AI](https://learn.microsoft.com/en-us/azure/ai-services/content-safety/studio-quickstart)
|
37 |
- [AWS Bedrock Guardrails](https://aws.amazon.com/bedrock/guardrails/) (coming soon)
|
|
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.19.1
|
8 |
pinned: true
|
9 |
license: apache-2.0
|
10 |
---
|
|
|
35 |
- [Rebuff](https://rebuff.ai/)
|
36 |
- [Azure Content Safety AI](https://learn.microsoft.com/en-us/azure/ai-services/content-safety/studio-quickstart)
|
37 |
- [AWS Bedrock Guardrails](https://aws.amazon.com/bedrock/guardrails/) (coming soon)
|
38 |
+
- [AWS Comprehend](https://docs.aws.amazon.com/comprehend/latest/dg/trust-safety.html)
|
app.py
CHANGED
@@ -11,6 +11,7 @@ from functools import lru_cache
|
|
11 |
from typing import List, Union
|
12 |
|
13 |
import aegis
|
|
|
14 |
import gradio as gr
|
15 |
import requests
|
16 |
from huggingface_hub import HfApi
|
@@ -29,6 +30,7 @@ automorphic_api_key = os.getenv("AUTOMORPHIC_API_KEY")
|
|
29 |
rebuff_api_key = os.getenv("REBUFF_API_KEY")
|
30 |
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
|
31 |
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
|
|
|
32 |
|
33 |
|
34 |
@lru_cache(maxsize=2)
|
@@ -61,7 +63,9 @@ def convert_elapsed_time(diff_time) -> float:
|
|
61 |
deepset_classifier = init_prompt_injection_model(
|
62 |
"ProtectAI/deberta-v3-base-injection-onnx"
|
63 |
) # ONNX version of deepset/deberta-v3-base-injection
|
64 |
-
protectai_classifier = init_prompt_injection_model(
|
|
|
|
|
65 |
fmops_classifier = init_prompt_injection_model(
|
66 |
"ProtectAI/fmops-distilbert-prompt-injection-onnx"
|
67 |
) # ONNX version of fmops/distilbert-prompt-injection
|
@@ -155,6 +159,36 @@ def detect_azure(prompt: str) -> (bool, bool):
|
|
155 |
return False, False
|
156 |
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
detection_providers = {
|
159 |
"ProtectAI (HF model)": detect_hf_protectai,
|
160 |
"Deepset (HF model)": detect_hf_deepset,
|
@@ -163,6 +197,7 @@ detection_providers = {
|
|
163 |
"Automorphic Aegis": detect_automorphic,
|
164 |
# "Rebuff": detect_rebuff,
|
165 |
"Azure Content Safety": detect_azure,
|
|
|
166 |
}
|
167 |
|
168 |
|
@@ -235,8 +270,8 @@ if __name__ == "__main__":
|
|
235 |
"The results are <strong>stored in the private dataset</strong> for further analysis and improvements. This interface is for research purposes only."
|
236 |
"<br /><br />"
|
237 |
"HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs.<br /><br />"
|
238 |
-
|
239 |
-
|
240 |
examples=[
|
241 |
[
|
242 |
example,
|
|
|
11 |
from typing import List, Union
|
12 |
|
13 |
import aegis
|
14 |
+
import boto3
|
15 |
import gradio as gr
|
16 |
import requests
|
17 |
from huggingface_hub import HfApi
|
|
|
30 |
rebuff_api_key = os.getenv("REBUFF_API_KEY")
|
31 |
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
|
32 |
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
|
33 |
+
aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-east-1")
|
34 |
|
35 |
|
36 |
@lru_cache(maxsize=2)
|
|
|
63 |
deepset_classifier = init_prompt_injection_model(
|
64 |
"ProtectAI/deberta-v3-base-injection-onnx"
|
65 |
) # ONNX version of deepset/deberta-v3-base-injection
|
66 |
+
protectai_classifier = init_prompt_injection_model(
|
67 |
+
"ProtectAI/deberta-v3-base-prompt-injection", "onnx"
|
68 |
+
)
|
69 |
fmops_classifier = init_prompt_injection_model(
|
70 |
"ProtectAI/fmops-distilbert-prompt-injection-onnx"
|
71 |
) # ONNX version of fmops/distilbert-prompt-injection
|
|
|
159 |
return False, False
|
160 |
|
161 |
|
162 |
+
def detect_aws_comprehend(prompt: str) -> (bool, bool):
|
163 |
+
response = aws_comprehend_client.classify_document(
|
164 |
+
EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
|
165 |
+
Text=prompt,
|
166 |
+
)
|
167 |
+
response = {
|
168 |
+
"Classes": [
|
169 |
+
{"Name": "SAFE_PROMPT", "Score": 0.9010000228881836},
|
170 |
+
{"Name": "UNSAFE_PROMPT", "Score": 0.0989999994635582},
|
171 |
+
],
|
172 |
+
"ResponseMetadata": {
|
173 |
+
"RequestId": "e8900fe1-3346-45c0-bad3-007b2840865a",
|
174 |
+
"HTTPStatusCode": 200,
|
175 |
+
"HTTPHeaders": {
|
176 |
+
"x-amzn-requestid": "e8900fe1-3346-45c0-bad3-007b2840865a",
|
177 |
+
"content-type": "application/x-amz-json-1.1",
|
178 |
+
"content-length": "115",
|
179 |
+
"date": "Mon, 19 Feb 2024 08:34:43 GMT",
|
180 |
+
},
|
181 |
+
"RetryAttempts": 0,
|
182 |
+
},
|
183 |
+
}
|
184 |
+
logger.info(f"Prompt injection result from AWS Comprehend: {response}")
|
185 |
+
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
|
186 |
+
logger.error(f"Failed to call AWS Comprehend API: {response}")
|
187 |
+
return False, False
|
188 |
+
|
189 |
+
return True, response["Classes"][0] == "UNSAFE_PROMPT"
|
190 |
+
|
191 |
+
|
192 |
detection_providers = {
|
193 |
"ProtectAI (HF model)": detect_hf_protectai,
|
194 |
"Deepset (HF model)": detect_hf_deepset,
|
|
|
197 |
"Automorphic Aegis": detect_automorphic,
|
198 |
# "Rebuff": detect_rebuff,
|
199 |
"Azure Content Safety": detect_azure,
|
200 |
+
"AWS Comprehend": detect_aws_comprehend,
|
201 |
}
|
202 |
|
203 |
|
|
|
270 |
"The results are <strong>stored in the private dataset</strong> for further analysis and improvements. This interface is for research purposes only."
|
271 |
"<br /><br />"
|
272 |
"HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs.<br /><br />"
|
273 |
+
'<a href="https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w">Join our Slack community to discuss LLM Security</a><br />'
|
274 |
+
'<a href="https://github.com/protectai/llm-guard">Secure your LLM interactions with LLM Guard</a>',
|
275 |
examples=[
|
276 |
[
|
277 |
example,
|
requirements.txt
CHANGED
@@ -1,8 +1,9 @@
|
|
|
|
1 |
git+https://github.com/automorphic-ai/aegis.git
|
2 |
-
gradio==4.
|
3 |
-
huggingface_hub==0.
|
4 |
-
onnxruntime==1.
|
5 |
-
optimum[onnxruntime]==1.
|
6 |
-
rebuff==0.
|
7 |
requests==2.31.0
|
8 |
-
transformers==4.
|
|
|
1 |
+
boto3==1.34.44
|
2 |
git+https://github.com/automorphic-ai/aegis.git
|
3 |
+
gradio==4.19.1
|
4 |
+
huggingface_hub==0.20.3
|
5 |
+
onnxruntime==1.17.0
|
6 |
+
optimum[onnxruntime]==1.17.1
|
7 |
+
rebuff==0.1.1
|
8 |
requests==2.31.0
|
9 |
+
transformers==4.37.2
|