File size: 4,036 Bytes
a521442 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
# Docs: https://huggingface.co/meta-llama/LlamaGuard-7b
from dotenv import load_dotenv, find_dotenv
import os
import requests
# Hugging Face model identifier for Llama Guard.
# NOTE(review): not referenced elsewhere in the visible part of this file —
# confirm whether it is still needed.
model_id = "meta-llama/LlamaGuard-7b"
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'.
# load_dotenv() has to run before the os.environ lookup below so that values
# defined in the .env file are present in the process environment.
load_dotenv(find_dotenv())
# Subscript access (rather than .get) raises KeyError at import time when the
# token is not set.
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
#updated on March 24th
#reading unsafe categories file
# Load the unsafe-content category definitions that are embedded in every
# moderation prompt. Loading is best-effort: on any failure a message is
# printed and the categories fall back to an empty string so the module can
# still be imported.
unsafe_categories = ""
try:
    # Explicit encoding so the file is decoded the same way on every platform
    # instead of depending on the locale's default encoding.
    with open('./prompts/llama_guard-unsafe_categories.txt', 'r', encoding='utf-8') as file:
        unsafe_categories = file.read()
except FileNotFoundError:
    print("File not found. Please check the file path.")
except PermissionError:
    print("Permission denied. You don't have the rights to read the file.")
except Exception as e:  # Any other failure (e.g. decoding errors) is reported, not raised
    print(f"An error occurred while reading the file: {e}")
#added on March 24th
def moderation_prompt_for_chat(chat):
    """Build the Llama Guard moderation prompt for a single user message.

    The module-level ``unsafe_categories`` text is placed between the
    category markers and ``chat`` is inserted as the 'User' turn of the
    conversation section.

    Args:
        chat: The user message to be assessed.

    Returns:
        The complete ``[INST] ... [/INST]`` prompt string.
    """
    return f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{unsafe_categories}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
User : {chat}
<END CONVERSATION>
Provide your safety assessment for 'User' in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories.[/INST]"""
def query(payload, timeout=60):
    """POST ``payload`` to the hosted inference endpoint and decode the reply.

    Args:
        payload: JSON-serialisable request body sent to the endpoint.
        timeout: Seconds to wait for the endpoint before giving up. Without a
            timeout ``requests`` may block indefinitely on an unresponsive
            server.

    Returns:
        A ``(result, error_message)`` tuple. On success ``result`` is the
        decoded JSON response and ``error_message`` is ``None``. On any
        failure ``result`` is ``None`` and ``error_message`` describes the
        problem; the message is also printed.
    """
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # Turn HTTP error statuses into HTTPError
        return response.json(), None
    except requests.exceptions.HTTPError as http_err:
        error_message = f"HTTP error occurred: {http_err}"
    except requests.exceptions.ConnectionError:
        error_message = "Could not connect to the API endpoint."
    except requests.exceptions.Timeout:
        error_message = "The request to the API endpoint timed out."
    except Exception as err:
        # Boundary handler: callers rely on the (None, message) contract
        # rather than on exceptions, so everything else is reported here
        # (including a response body that is not valid JSON).
        error_message = f"An error occurred: {err}"
    print(error_message)
    return None, error_message
def moderate_chat(chat):
    """Send a user message to the moderation endpoint.

    The message is wrapped by ``moderation_prompt_for_chat`` and posted via
    ``query`` together with fixed generation parameters.

    Args:
        chat: The user message to moderate.

    Returns:
        The ``(output, error_msg)`` pair produced by ``query``.
    """
    generation_settings = {
        "top_k": 1,
        "top_p": 0.2,
        "temperature": 0.1,
        "max_new_tokens": 512,
    }
    request_body = {
        "inputs": moderation_prompt_for_chat(chat),
        "parameters": generation_settings,
    }
    result, failure = query(request_body)
    return result, failure
#added on March 24th
def load_category_names_from_string(file_content):
    """Load category codes and names from a string into a dictionary.

    Every line that (after trimming surrounding whitespace) starts with
    ``"O"`` and contains a colon is treated as ``<code>: <name>``. All other
    lines are ignored.

    Args:
        file_content: Text of the unsafe-categories file. ``None`` or an
            empty string yields an empty mapping.

    Returns:
        A dict mapping each category code (e.g. ``"O1"``) to its name.
    """
    category_names = {}
    if not file_content:
        return category_names
    for raw_line in file_content.splitlines():
        # Trim first so indented lines and stray '\r' characters still match.
        line = raw_line.strip()
        if not line.startswith("O"):
            continue
        # Split on the first colon only, so a name that itself contains a
        # colon is kept intact instead of the whole line being dropped.
        code, separator, name = line.partition(':')
        if separator:
            category_names[code.strip()] = name.strip()
    return category_names
def get_category_name(input_str):
    """Return the category name(s) for the code(s) on the second line of a model reply.

    The reply is expected to carry the verdict on its first line and, when
    unsafe, a comma-separated list of category codes on its second line.
    Each code is looked up in the categories parsed from the module-level
    ``unsafe_categories`` text.

    Args:
        input_str: Raw assessment text, e.g. ``"unsafe\\nO3"``.

    Returns:
        The category name for a single code, or the names joined with
        ``", "`` for several codes. ``"Unknown Category"`` is used for any
        code that is not found, and is returned as a whole when the reply has
        no second line or that line holds no codes.
    """
    category_names = load_category_names_from_string(unsafe_categories)

    # Trim surrounding whitespace so a leading blank line does not shift the
    # verdict onto the line where the codes are expected.
    lines = input_str.strip().split('\n') if input_str else []
    if len(lines) < 2:
        # Single-line replies (e.g. "safe") carry no category code.
        return "Unknown Category"

    codes = [code.strip() for code in lines[1].split(',') if code.strip()]
    if not codes:
        return "Unknown Category"

    return ", ".join(category_names.get(code, "Unknown Category") for code in codes)
|