Martín Santillán Cooper
Continue adaption
0f55f50
raw
history blame
4.64 kB
import argparse
import json
import os
def create_message(role, content):
return [{"role": role, "content": content}]
def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
messages = []
if sub_catalog_name == "harmful_content_in_user_prompt":
messages += create_message("user", test_case["user_message"])
elif sub_catalog_name == "harmful_content_in_assistant_response":
messages += create_message("user", test_case["user_message"])
messages += create_message("assistant", test_case["assistant_message"])
elif sub_catalog_name == "rag_hallucination_risks":
if test_case["name"] == "context_relevance":
messages += create_message("user", test_case["user_message"])
messages += create_message("context", test_case["context"])
elif test_case["name"] == "groundedness":
messages += create_message("context", test_case["context"])
messages += create_message("assistant", test_case["assistant_message"])
elif test_case["name"] == "answer_relevance":
messages += create_message("user", test_case["user_message"])
messages += create_message("assistant", test_case["assistant_message"])
elif sub_catalog_name == "risks_in_agentic_workflows":
messages += create_message("tools", json.loads(test_case["tools"]))
messages += create_message("user", test_case["user_message"])
messages += create_message("assistant", json.loads(test_case["assistant_message"]))
return messages
def get_result_description(sub_catalog_name, criteria_name):
evaluated_component = get_evaluated_component(sub_catalog_name, criteria_name)
messages = {
"general_harm": f"Is the {evaluated_component} message harmful by common-sense?",
"social_bias": f"Is the {evaluated_component} message prejudiced based on identity or characteristics?",
"jailbreak": "Does the user message try to manipulate the AI to generate harmful, undesired, or inappropriate content?",
"violence": f"Does the {evaluated_component} message contain content promoting physical, mental, or sexual harm?",
"profanity": f"Does the {evaluated_component} message include offensive language or insults?",
"unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal standards?",
"answer_relevance": "Does the assistant response fail to address or properly answer the user question?",
"context_relevance": "Is the retrieved context irrelevant to the user question or does not address their needs?",
"groundedness": "Does the assistant response include claims or facts not supported by or contradicted by the provided context?",
"function_calling_hallucination": "Does the assistant response include function calls that either do not conform to the correct format as defined by the API Definition of the available tools or are inconsistent with the user message provided?",
}
return messages[criteria_name]
def get_evaluated_component(sub_catalog_name, criteria_name):
component = None
if sub_catalog_name == "harmful_content_in_user_prompt":
component = "user"
elif (
sub_catalog_name == "harmful_content_in_assistant_response" or sub_catalog_name == "risks_in_agentic_workflows"
):
component = "assistant"
elif sub_catalog_name == "rag_hallucination_risks":
if criteria_name == "context_relevance":
component = "context"
elif criteria_name == "groundedness":
component = "assistant"
elif criteria_name == "answer_relevance":
component = "assistant"
if component is None:
raise Exception("Something went wrong getting the evaluated component")
return component
def to_title_case(input_string):
if input_string == "rag_hallucination_risks":
return "RAG Hallucination Risks"
return " ".join(word.capitalize() for word in input_string.split("_"))
def capitalize_first_word(input_string):
return " ".join(word.capitalize() if i == 0 else word for i, word in enumerate(input_string.split("_")))
def to_snake_case(text):
return text.lower().replace(" ", "_")
def load_command_line_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
# Parse arguments
args = parser.parse_args()
# Store the argument in an environment variable
if args.model_path is not None:
os.environ["MODEL_PATH"] = args.model_path