* azure update
Browse files
app.py
CHANGED
@@ -30,7 +30,7 @@ sydelabs_api_key = os.getenv("SYDELABS_API_KEY")
|
|
30 |
rebuff_api_key = os.getenv("REBUFF_API_KEY")
|
31 |
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
|
32 |
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
|
33 |
-
|
34 |
|
35 |
|
36 |
@lru_cache(maxsize=2)
|
@@ -146,33 +146,37 @@ def detect_rebuff(prompt: str) -> (bool, bool):
|
|
146 |
def detect_azure(prompt: str) -> (bool, bool):
|
147 |
try:
|
148 |
response = requests.post(
|
149 |
-
f"{azure_content_safety_endpoint}contentsafety/text:
|
150 |
-
json={"
|
151 |
headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
|
152 |
)
|
153 |
response_json = response.json()
|
154 |
logger.info(f"Prompt injection result from Azure: {response.json()}")
|
155 |
|
156 |
-
if "
|
157 |
return False, False
|
158 |
|
159 |
-
return True, response_json["
|
160 |
except requests.RequestException as err:
|
161 |
logger.error(f"Failed to call Azure API: {err}")
|
162 |
return False, False
|
163 |
|
164 |
|
165 |
-
def
|
166 |
-
response =
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
174 |
|
175 |
-
return True,
|
176 |
|
177 |
|
178 |
def detect_sydelabs(prompt: str) -> (bool, bool):
|
@@ -213,7 +217,7 @@ detection_providers = {
|
|
213 |
# "Rebuff": detect_rebuff,
|
214 |
"Azure Content Safety": detect_azure,
|
215 |
"SydeLabs": detect_sydelabs,
|
216 |
-
"AWS Comprehend":
|
217 |
}
|
218 |
|
219 |
|
|
|
30 |
rebuff_api_key = os.getenv("REBUFF_API_KEY")
|
31 |
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
|
32 |
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
|
33 |
+
bedrock_runtime_client = boto3.client('bedrock-runtime', region_name="us-east-1")
|
34 |
|
35 |
|
36 |
@lru_cache(maxsize=2)
|
|
|
146 |
def detect_azure(prompt: str) -> (bool, bool):
|
147 |
try:
|
148 |
response = requests.post(
|
149 |
+
f"{azure_content_safety_endpoint}contentsafety/text:shieldPrompt?api-version=2024-02-15-preview",
|
150 |
+
json={"userPrompt": prompt},
|
151 |
headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
|
152 |
)
|
153 |
response_json = response.json()
|
154 |
logger.info(f"Prompt injection result from Azure: {response.json()}")
|
155 |
|
156 |
+
if "userPromptAnalysis" not in response_json:
|
157 |
return False, False
|
158 |
|
159 |
+
return True, response_json["userPromptAnalysis"]["attackDetected"]
|
160 |
except requests.RequestException as err:
|
161 |
logger.error(f"Failed to call Azure API: {err}")
|
162 |
return False, False
|
163 |
|
164 |
|
165 |
+
def detect_aws_bedrock(prompt: str) -> (bool, bool):
|
166 |
+
response = bedrock_runtime_client.apply_guardrail(
|
167 |
+
guardrailIdentifier="arn:aws:bedrock:us-east-1:364432806369:guardrail/tx8t6psx14ho",
|
168 |
+
guardrailVersion="1",
|
169 |
+
source='INPUT',
|
170 |
+
content=[
|
171 |
+
{"text": {"text": prompt}}
|
172 |
+
])
|
173 |
+
|
174 |
+
logger.info(f"Prompt injection result from AWS Bedrock Guardrails: {response}")
|
175 |
+
# if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
|
176 |
+
# logger.error(f"Failed to call AWS Comprehend API: {response}")
|
177 |
+
# return False, False
|
178 |
|
179 |
+
return True, True
|
180 |
|
181 |
|
182 |
def detect_sydelabs(prompt: str) -> (bool, bool):
|
|
|
217 |
# "Rebuff": detect_rebuff,
|
218 |
"Azure Content Safety": detect_azure,
|
219 |
"SydeLabs": detect_sydelabs,
|
220 |
+
"AWS Comprehend": detect_aws_bedrock,
|
221 |
}
|
222 |
|
223 |
|