asofter commited on
Commit
1a0ab69
·
1 Parent(s): 6b80b1f

* azure update

Browse files
Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -30,7 +30,7 @@ sydelabs_api_key = os.getenv("SYDELABS_API_KEY")
30
  rebuff_api_key = os.getenv("REBUFF_API_KEY")
31
  azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
32
  azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
33
- aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-east-1")
34
 
35
 
36
  @lru_cache(maxsize=2)
@@ -146,33 +146,37 @@ def detect_rebuff(prompt: str) -> (bool, bool):
146
  def detect_azure(prompt: str) -> (bool, bool):
147
  try:
148
  response = requests.post(
149
- f"{azure_content_safety_endpoint}contentsafety/text:detectJailbreak?api-version=2023-10-15-preview",
150
- json={"text": prompt},
151
  headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
152
  )
153
  response_json = response.json()
154
  logger.info(f"Prompt injection result from Azure: {response.json()}")
155
 
156
- if "jailbreakAnalysis" not in response_json:
157
  return False, False
158
 
159
- return True, response_json["jailbreakAnalysis"]["detected"]
160
  except requests.RequestException as err:
161
  logger.error(f"Failed to call Azure API: {err}")
162
  return False, False
163
 
164
 
165
- def detect_aws_comprehend(prompt: str) -> (bool, bool):
166
- response = aws_comprehend_client.classify_document(
167
- EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
168
- Text=prompt,
169
- )
170
- logger.info(f"Prompt injection result from AWS Comprehend: {response}")
171
- if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
172
- logger.error(f"Failed to call AWS Comprehend API: {response}")
173
- return False, False
 
 
 
 
174
 
175
- return True, response["Classes"][0] == "UNSAFE_PROMPT"
176
 
177
 
178
  def detect_sydelabs(prompt: str) -> (bool, bool):
@@ -213,7 +217,7 @@ detection_providers = {
213
  # "Rebuff": detect_rebuff,
214
  "Azure Content Safety": detect_azure,
215
  "SydeLabs": detect_sydelabs,
216
- "AWS Comprehend": detect_aws_comprehend,
217
  }
218
 
219
 
 
30
  rebuff_api_key = os.getenv("REBUFF_API_KEY")
31
  azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
32
  azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
33
+ bedrock_runtime_client = boto3.client('bedrock-runtime', region_name="us-east-1")
34
 
35
 
36
  @lru_cache(maxsize=2)
 
146
  def detect_azure(prompt: str) -> (bool, bool):
147
  try:
148
  response = requests.post(
149
+ f"{azure_content_safety_endpoint}contentsafety/text:shieldPrompt?api-version=2024-02-15-preview",
150
+ json={"userPrompt": prompt},
151
  headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
152
  )
153
  response_json = response.json()
154
  logger.info(f"Prompt injection result from Azure: {response.json()}")
155
 
156
+ if "userPromptAnalysis" not in response_json:
157
  return False, False
158
 
159
+ return True, response_json["userPromptAnalysis"]["attackDetected"]
160
  except requests.RequestException as err:
161
  logger.error(f"Failed to call Azure API: {err}")
162
  return False, False
163
 
164
 
165
+ def detect_aws_bedrock(prompt: str) -> (bool, bool):
166
+ response = bedrock_runtime_client.apply_guardrail(
167
+ guardrailIdentifier="arn:aws:bedrock:us-east-1:364432806369:guardrail/tx8t6psx14ho",
168
+ guardrailVersion="1",
169
+ source='INPUT',
170
+ content=[
171
+ {"text": {"text": prompt}}
172
+ ])
173
+
174
+ logger.info(f"Prompt injection result from AWS Bedrock Guardrails: {response}")
175
+ # if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
176
+ # logger.error(f"Failed to call AWS Comprehend API: {response}")
177
+ # return False, False
178
 
179
+ return True, True
180
 
181
 
182
  def detect_sydelabs(prompt: str) -> (bool, bool):
 
217
  # "Rebuff": detect_rebuff,
218
  "Azure Content Safety": detect_azure,
219
  "SydeLabs": detect_sydelabs,
220
+ "AWS Comprehend": detect_aws_bedrock,
221
  }
222
 
223