Spaces: WilliamGazeley (Sleeping)

Commit 262a057 • 1 Parent(s): 7bbdf0e
committed by WilliamGazeley

Remove unused files
Browse files

- .gitignore +1 -0
- src/functioncall.py +0 -141
- src/prompt_assets/few_shot.json +0 -8
- src/prompt_assets/output_sys_prompt.yml +0 -19
- src/prompt_assets/sys_prompt.yml +0 -39
- src/prompter.py +0 -76
- src/schema.py +0 -23
- src/utils.py +0 -122
- src/validator.py +0 -149
.gitignore
CHANGED
@@ -7,3 +7,4 @@ __pycache__/
 inference_logs/
 
 logs/*
+*.log
src/functioncall.py
DELETED
@@ -1,141 +0,0 @@
-import argparse
-import torch
-import json
-from config import config
-from typing import List, Dict
-from logger import logger
-
-from transformers import AutoTokenizer
-
-import functions
-from prompter import PromptManager
-from validator import validate_function_call_schema
-from langchain_community.chat_models import ChatOllama
-from langchain_community.llms import Ollama
-from langchain.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-
-from utils import (
-    get_chat_template,
-    validate_and_extract_tool_calls
-)
-
-class ModelInference:
-    def __init__(self, chat_template: str):
-        self.prompter = PromptManager()
-
-        self.model = Ollama(model=config.ollama_model, temperature=0.0, format='json')
-        template = PromptTemplate(template="""<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {"type": "function", "function": {"name": "get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\\n\\n Args:\\n symbol (str): The stock symbol.\\n\\n Returns:\\n dict: A dictionary containing fundamental data.\\n Keys:\\n - \'symbol\': The stock symbol.\\n - \'company_name\': The long name of the company.\\n - \'sector\': The sector to which the company belongs.\\n - \'industry\': The industry to which the company belongs.\\n - \'market_cap\': The market capitalization of the company.\\n - \'pe_ratio\': The forward price-to-earnings ratio.\\n - \'pb_ratio\': The price-to-book ratio.\\n - \'dividend_yield\': The dividend yield.\\n - \'eps\': The trailing earnings per share.\\n - \'beta\': The beta value of the stock.\\n - \'52_week_high\': The 52-week high price of the stock.\\n - \'52_week_low\': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} </tools> Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n<tool_call>\n{"arguments": <args-dict>, "name": <function-name>}\n</tool_call><|im_end|>\n""", input_variables=["question"])
-        chain = template | self.model | StrOutputParser()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
-        self.tokenizer.pad_token = self.tokenizer.eos_token
-        self.tokenizer.padding_side = "left"
-
-        if self.tokenizer.chat_template is None:
-            print("No chat template defined, getting chat_template...")
-            self.tokenizer.chat_template = get_chat_template(chat_template)
-
-        logger.info(f"Model loaded: {self.model}")
-
-    def process_completion_and_validate(self, completion, chat_template):
-        if completion:
-            # completion = f"<tool_call>\n{completion}\n</tool_call>"]
-            validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
-
-            if validation:
-                logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
-                return tool_calls, completion, error_message
-            else:
-                tool_calls = None
-                return tool_calls, completion, error_message
-        else:
-            logger.warning("Assistant message is None")
-            raise ValueError("Assistant message is None")
-
-    def execute_function_call(self, tool_call):
-        # config.status.update(label=":mag: Gathering information..")
-        function_name = tool_call.get("name")
-        function_to_call = getattr(functions, function_name, None)
-        function_args = tool_call.get("arguments", {})
-
-        logger.info(f"Invoking function call {function_name} ...")
-        function_response = function_to_call(*function_args.values())
-        results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
-        return results_dict
-
-    def run_inference(self, prompt: List[Dict[str, str]]):
-        inputs = self.tokenizer.apply_chat_template(
-            prompt,
-            add_generation_prompt=True,
-            tokenize=False,
-        )
-        inputs = inputs.replace("<|begin_of_text|>", "") # Something wrong with the chat template, hotfix
-        completion = self.model.invoke(inputs, format='json')
-        return completion.content
-
-    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
-        try:
-            depth = 0
-            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
-            chat = [{"role": "user", "content": user_message}]
-            tools = functions.get_openai_tools()
-            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
-            # config.status.update(label=":brain: Thinking..")
-            completion = self.run_inference(prompt)
-
-            def recursive_loop(prompt, completion, depth):
-                nonlocal max_depth
-                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
-                prompt.append({"role": "assistant", "content": assistant_message})
-
-                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
-                logger.info(f"Found tool calls: {tool_calls}")
-                if tool_calls:
-                    logger.info(f"Assistant Message:\n{assistant_message}")
-
-                    for tool_call in tool_calls:
-                        validation, message = validate_function_call_schema(tool_call, tools)
-                        if validation:
-                            try:
-                                function_response = self.execute_function_call(tool_call)
-                                tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
-                                logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
-                            except Exception as e:
-                                logger.info(f"Could not execute function: {e}")
-                                tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
-                        else:
-                            logger.info(message)
-                            tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
-                    prompt.append({"role": "tool", "content": tool_message})
-
-                    depth += 1
-                    if depth >= max_depth:
-                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
-                        completion = self.run_inference(prompt)
-                        return completion
-
-                    # config.status.update(label=":brain: Analysing information..")
-                    completion = self.run_inference(prompt)
-                    return recursive_loop(prompt, completion, depth)
-                elif error_message:
-                    logger.info(f"Assistant Message:\n{assistant_message}")
-                    tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
-                    prompt.append({"role": "tool", "content": tool_message})
-
-                    depth += 1
-                    if depth >= max_depth:
-                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
-                        return completion
-
-                    completion = self.run_inference(prompt)
-                    return recursive_loop(prompt, completion, depth)
-                else:
-                    logger.info(f"Assistant Message:\n{assistant_message}")
-                    return assistant_message
-
-            return recursive_loop(prompt, completion, depth)
-
-        except Exception as e:
-            logger.error(f"Exception occurred: {e}")
-            raise e
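
For context, a minimal usage sketch of the removed ModelInference class (not part of this commit): it assumes the sibling config, functions, and logger modules, a running Ollama server, and that "chatml" names a template under src/chat_templates/; the query string is illustrative.

# Hypothetical driver for the deleted module (assumptions noted above).
from functioncall import ModelInference

inference = ModelInference(chat_template="chatml")   # assumed template name
answer = inference.generate_function_call(
    query="Get the stock fundamentals for TSLA",     # example query, not from the repo
    chat_template="chatml",
    num_fewshot=None,                                # None skips the few-shot examples section
    max_depth=5,                                     # cap on agent iterations
)
print(answer)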
src/prompt_assets/few_shot.json
DELETED
@@ -1,8 +0,0 @@
-[
-    {
-        "example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n</tool_call>\n```\n"
-    },
-    {
-        "example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n</tool_call>\n```"
-    }
-]
src/prompt_assets/output_sys_prompt.yml
DELETED
@@ -1,19 +0,0 @@
-Role: |
-  You are a financial expert named IRAI with experience and expertise in stocks and cryptocurrency.
-  You have a comprehensive understanding of finance, investing, and quantiative analysis; you are a thought leader in these fields.
-  You will be fielding questions about finance from top executives and managers of successful startups.
-  Your audience is high net worth individuals with high risk tolerance and lots of investing experience - they do not need traditional, generic financial advice.
-Objective: |
-  Answer questions as best as accurately as possible given your current knowledge and all available information.
-  Your answers should demonstrate expert insight and indepth analysis.
-  Use analysis, information, and data from the assistant message to form your answers.
-Instructions: |
-  Answer in a professional and engaging manner representing a top female investment professional working at a leading investment bank.
-  Try to incorprate the numerical data from the function calls to support your analysis, do not state any other numerical data in your answer.
-  Use all the available information to answer the question, including the data and information from the function calls.
-  Do not mention any function calls, such as "get_analysis", "get_current_stock_price", or "get_key_financial_ratios" in your answer.
-  Give a direct answer to question, concise yet insightful.
-  Do not give additional instructions related to seeking financial advisor or professional in your answer.
-  Do not ask any followup questions in your answer.
-  Do not ask for any additional information in your answer.
-  Do not mention Morgan Stanley in your answer.
src/prompt_assets/sys_prompt.yml
DELETED
@@ -1,39 +0,0 @@
-Role: |
-  You are a function calling AI agent with self-recursion.
-  You will use your comprehensive understanding of finance and investing for function calls.
-  You can call only one function at a time and analyse data you get from function response.
-  You are provided with function signatures within <tools></tools> XML tags.
-  The current date is: {date}.
-Objective: |
-  You may use agentic frameworks for reasoning and planning to help with user questions.
-  Please call a function and wait for function results to be provided to you in the next iteration.
-  Don't make assumptions about what values to plug into function arguments.
-  Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
-  Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
-  Analyze the data once you get the results and call another function.
-  At each iteration please continue adding the your analysis to previous summary.
-  Your final response should directly answer the user questions using the results of function calls.
-Tools: |
-  Here are the available tools:
-  <tools> {tools} </tools>
-  If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
-  <tool_call>
-  {{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
-  </tool_call>
-  Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
-Examples: |
-  Here are some example usage of functions:
-  {examples}
-Schema: |
-  Use the following pydantic model json schema for each tool call you will make:
-  {schema}
-Instructions: |
-  At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
-  Please keep a running summary with analysis of previous function results and summaries from previous iterations.
-  Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
-  Calling multiple functions at once can overload the system and increase cost so call one function at a time please.
-  If you plan to continue with analysis, always call another function.
-  For each function call return a valid json object (using doulbe quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
-  <tool_call>
-  {{"arguments": <args-dict>, "name": <function-name>}}
-  </tool_call>
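
For illustration (not part of this commit), the snippet below builds a completion that satisfies the <tool_call> format requested by this prompt, using the get_stock_fundamentals tool defined elsewhere in the repo; the TSLA argument is just an example.

import json

# Assemble the tool-call turn the system prompt asks the model to emit.
payload = {"arguments": {"symbol": "TSLA"}, "name": "get_stock_fundamentals"}
completion = f"<tool_call>\n{json.dumps(payload)}\n</tool_call>"

# The prompt requires the JSON body to be parseable with json.loads().
body = completion.splitlines()[1]   # the JSON line between the XML tags
assert json.loads(body) == payload
print(completion)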
src/prompter.py
DELETED
@@ -1,76 +0,0 @@
-import datetime
-from pydantic import BaseModel
-from typing import Dict
-from schema import FunctionCall
-from utils import (
-    get_fewshot_examples
-)
-import yaml
-import json
-import os
-
-class PromptSchema(BaseModel):
-    Role: str
-    Objective: str
-    Tools: str
-    Examples: str
-    Schema: str
-    Instructions: str
-
-class PromptManager:
-    def __init__(self):
-        self.script_dir = os.path.dirname(os.path.abspath(__file__))
-
-    def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str:
-        formatted_prompt = ""
-        for field, value in prompt_schema.dict().items():
-            if field == "Examples" and variables.get("examples") is None:
-                continue
-            formatted_value = value.format(**variables)
-            if field == "Instructions":
-                formatted_prompt += f"{formatted_value}"
-            else:
-                formatted_value = formatted_value.replace("\n", " ")
-                formatted_prompt += f"{formatted_value}"
-        return formatted_prompt
-
-    def read_yaml_file(self, file_path: str) -> PromptSchema:
-        with open(file_path, 'r') as file:
-            yaml_content = yaml.safe_load(file)
-
-        prompt_schema = PromptSchema(
-            Role=yaml_content.get('Role', ''),
-            Objective=yaml_content.get('Objective', ''),
-            Tools=yaml_content.get('Tools', ''),
-            Examples=yaml_content.get('Examples', ''),
-            Schema=yaml_content.get('Schema', ''),
-            Instructions=yaml_content.get('Instructions', ''),
-        )
-        return prompt_schema
-
-    def generate_prompt(self, user_prompt, tools, num_fewshot=None):
-        prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt.yml')
-        prompt_schema = self.read_yaml_file(prompt_path)
-
-        if num_fewshot is not None:
-            examples = get_fewshot_examples(num_fewshot)
-        else:
-            examples = None
-
-        schema_json = json.loads(FunctionCall.schema_json())
-
-        variables = {
-            "date": datetime.date.today(),
-            "tools": tools,
-            "examples": examples,
-            "schema": schema_json
-        }
-        sys_prompt = self.format_yaml_prompt(prompt_schema, variables)
-
-        prompt = [
-            {'content': sys_prompt, 'role': 'system'}
-        ]
-        prompt.extend(user_prompt)
-        return prompt
-
-
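
A short sketch (not part of this commit) of how the removed PromptManager was driven from functioncall.py; the tools list below is a hand-written stand-in for functions.get_openai_tools(), and the user turn is illustrative.

from prompter import PromptManager

tools = [{
    "type": "function",
    "function": {
        "name": "get_stock_fundamentals",
        "parameters": {
            "type": "object",
            "properties": {"symbol": {"type": "string"}},
            "required": ["symbol"],
        },
    },
}]
chat = [{"role": "user", "content": "How is TSLA doing?"}]  # example user turn

manager = PromptManager()
# Returns a system message built from sys_prompt.yml followed by the user turns.
prompt = manager.generate_prompt(chat, tools, num_fewshot=None)
print(prompt[0]["role"])  # "system"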
src/schema.py
DELETED
@@ -1,23 +0,0 @@
-from pydantic import BaseModel
-from typing import List, Dict, Literal, Optional
-
-class FunctionCall(BaseModel):
-    arguments: dict
-    """
-    The arguments to call the function with, as generated by the model in JSON
-    format. Note that the model does not always generate valid JSON, and may
-    hallucinate parameters not defined by your function schema. Validate the
-    arguments in your code before calling your function.
-    """
-
-    name: str
-    """The name of the function to call."""
-
-class FunctionDefinition(BaseModel):
-    name: str
-    description: Optional[str] = None
-    parameters: Optional[Dict[str, object]] = None
-
-class FunctionSignature(BaseModel):
-    function: FunctionDefinition
-    type: Literal["function"]
src/utils.py
DELETED
@@ -1,122 +0,0 @@
-import ast
-import os
-import re
-import json
-import logging
-import datetime
-import xml.etree.ElementTree as ET
-from logger import logger
-
-from logging.handlers import RotatingFileHandler
-
-logging.basicConfig(
-    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
-    datefmt="%Y-%m-%d:%H:%M:%S",
-    level=logging.INFO,
-)
-script_dir = os.path.dirname(os.path.abspath(__file__))
-now = datetime.datetime.now()
-log_folder = os.path.join(script_dir, "inference_logs")
-os.makedirs(log_folder, exist_ok=True)
-log_file_path = os.path.join(
-    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
-)
-# Use RotatingFileHandler from the logging.handlers module
-file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
-file_handler.setLevel(logging.INFO)
-
-formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
-file_handler.setFormatter(formatter)
-
-def get_fewshot_examples(num_fewshot):
-    """return a list of few shot examples"""
-    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
-    with open(example_path, 'r') as file:
-        examples = json.load(file) # Use json.load with the file object, not the file path
-    if num_fewshot > len(examples):
-        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
-    return examples[:num_fewshot]
-
-def get_chat_template(chat_template):
-    """read chat template from jinja file"""
-    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
-
-    if not os.path.exists(template_path):
-        print
-        logger.error(f"Template file not found: {chat_template}")
-        return None
-    try:
-        with open(template_path, 'r') as file:
-            template = file.read()
-        return template
-    except Exception as e:
-        print(f"Error loading template: {e}")
-        return None
-
-def validate_and_extract_tool_calls(assistant_content):
-    validation_result = False
-    tool_calls = []
-    error_message = None
-
-    try:
-        # wrap content in root element
-        xml_root_element = f"<root>{assistant_content}</root>"
-        root = ET.fromstring(xml_root_element)
-
-        # extract JSON data
-        for element in root.findall(".//tool_call"):
-            json_data = None
-            try:
-                json_text = element.text.strip()
-
-                try:
-                    # Prioritize json.loads for better error handling
-                    json_data = json.loads(json_text)
-                except json.JSONDecodeError as json_err:
-                    try:
-                        # Fallback to ast.literal_eval if json.loads fails
-                        json_data = ast.literal_eval(json_text)
-                    except (SyntaxError, ValueError) as eval_err:
-                        error_message = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\
-                                        f"- JSON Decode Error: {json_err}\n"\
-                                        f"- Fallback Syntax/Value Error: {eval_err}\n"\
-                                        f"- Problematic JSON text: {json_text}"
-                        logger.error(error_message)
-                        continue
-            except Exception as e:
-                error_message = f"Cannot strip text: {e}"
-                logger.error(error_message)
-
-            if json_data is not None:
-                tool_calls.append(json_data)
-                validation_result = True
-
-    except ET.ParseError as err:
-        error_message = f"XML Parse Error: {err}"
-        logger.error(f"XML Parse Error: {err}")
-
-    # Return default values if no valid data is extracted
-    return validation_result, tool_calls, error_message
-
-def extract_json_from_markdown(text):
-    """
-    Extracts the JSON string from the given text using a regular expression pattern.
-
-    Args:
-        text (str): The input text containing the JSON string.
-
-    Returns:
-        dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
-    """
-    json_pattern = r'```json\r?\n(.*?)\r?\n```'
-    match = re.search(json_pattern, text, re.DOTALL)
-    if match:
-        json_string = match.group(1)
-        try:
-            data = json.loads(json_string)
-            return data
-        except json.JSONDecodeError as e:
-            print(f"Error decoding JSON string: {e}")
-    else:
-        print("JSON string not found in the text.")
-    return None
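
A sketch (not part of this commit) of the removed parser applied to a sample completion; the tool-call content is illustrative.

from utils import validate_and_extract_tool_calls

completion = (
    "<tool_call>\n"
    '{"arguments": {"symbol": "TSLA"}, "name": "get_stock_fundamentals"}\n'
    "</tool_call>"
)
ok, tool_calls, error = validate_and_extract_tool_calls(completion)
print(ok)          # True: the <tool_call> block parsed as JSON
print(tool_calls)  # [{'arguments': {'symbol': 'TSLA'}, 'name': 'get_stock_fundamentals'}]
print(error)       # None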
src/validator.py
DELETED
@@ -1,149 +0,0 @@
-import ast
-import json
-from jsonschema import validate
-from pydantic import ValidationError
-from logger import logger
-from utils import extract_json_from_markdown
-from schema import FunctionCall, FunctionSignature
-
-
-def validate_function_call_schema(call, signatures):
-    try:
-        call_data = FunctionCall(**call)
-    except ValidationError as e:
-        return False, str(e)
-
-    for signature in signatures:
-        try:
-            signature_data = FunctionSignature(**signature)
-            if signature_data.function.name == call_data.name:
-                # Validate types in function arguments
-                for arg_name, arg_schema in signature_data.function.parameters.get(
-                    "properties", {}
-                ).items():
-                    if arg_name in call_data.arguments:
-                        call_arg_value = call_data.arguments[arg_name]
-                        if call_arg_value:
-                            try:
-                                validate_argument_type(
-                                    arg_name, call_arg_value, arg_schema
-                                )
-                            except Exception as arg_validation_error:
-                                return False, str(arg_validation_error)
-
-                # Check if all required arguments are present
-                required_arguments = signature_data.function.parameters.get(
-                    "required", []
-                )
-                result, missing_arguments = check_required_arguments(
-                    call_data.arguments, required_arguments
-                )
-                if not result:
-                    return False, f"Missing required arguments: {missing_arguments}"
-
-                return True, None
-        except Exception as e:
-            # Handle validation errors for the function signature
-            return False, str(e)
-
-    # No matching function signature found
-    return False, f"No matching function signature found for function: {call_data.name}"
-
-
-def check_required_arguments(call_arguments, required_arguments):
-    missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
-    return not bool(missing_arguments), missing_arguments
-
-
-def validate_enum_value(arg_name, arg_value, enum_values):
-    if arg_value not in enum_values:
-        raise Exception(
-            f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
-        )
-
-
-def validate_argument_type(arg_name, arg_value, arg_schema):
-    arg_type = arg_schema.get("type", None)
-    if arg_type:
-        if arg_type == "string" and "enum" in arg_schema:
-            enum_values = arg_schema["enum"]
-            if None not in enum_values and enum_values != []:
-                try:
-                    validate_enum_value(arg_name, arg_value, enum_values)
-                except Exception as e:
-                    # Propagate the validation error message
-                    raise Exception(f"Error validating function call: {e}")
-
-        python_type = get_python_type(arg_type)
-        if not isinstance(arg_value, python_type):
-            raise Exception(
-                f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}"
-            )
-
-
-def get_python_type(json_type):
-    type_mapping = {
-        "string": str,
-        "number": (int, float),
-        "integer": int,
-        "boolean": bool,
-        "array": list,
-        "object": dict,
-        "null": type(None),
-    }
-    return type_mapping[json_type]
-
-
-def validate_json_data(json_object, json_schema):
-    valid = False
-    error_message = None
-    result_json = None
-
-    try:
-        # Attempt to load JSON using json.loads
-        try:
-            result_json = json.loads(json_object)
-        except json.decoder.JSONDecodeError:
-            # If json.loads fails, try ast.literal_eval
-            try:
-                result_json = ast.literal_eval(json_object)
-            except (SyntaxError, ValueError) as e:
-                try:
-                    result_json = extract_json_from_markdown(json_object)
-                except Exception as e:
-                    error_message = f"JSON decoding error: {e}"
-                    logger.info(f"Validation failed for JSON data: {error_message}")
-                    return valid, result_json, error_message
-
-        # Return early if both json.loads and ast.literal_eval fail
-        if result_json is None:
-            error_message = "Failed to decode JSON data"
-            logger.info(f"Validation failed for JSON data: {error_message}")
-            return valid, result_json, error_message
-
-        # Validate each item in the list against schema if it's a list
-        if isinstance(result_json, list):
-            for index, item in enumerate(result_json):
-                try:
-                    validate(instance=item, schema=json_schema)
-                    logger.info(f"Item {index+1} is valid against the schema.")
-                except ValidationError as e:
-                    error_message = f"Validation failed for item {index+1}: {e}"
-                    break
-        else:
-            # Default to validation without list
-            try:
-                validate(instance=result_json, schema=json_schema)
-            except ValidationError as e:
-                error_message = f"Validation failed: {e}"
-
-    except Exception as e:
-        error_message = f"Error occurred: {e}"
-
-    if error_message is None:
-        valid = True
-        logger.info("JSON data is valid against the schema.")
-    else:
-        logger.info(f"Validation failed for JSON data: {error_message}")
-
-    return valid, result_json, error_message
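
A sketch (not part of this commit) of the removed validator on a well-formed call; the signature mirrors the get_stock_fundamentals tool used elsewhere in the repo.

from validator import validate_function_call_schema

signatures = [{
    "type": "function",
    "function": {
        "name": "get_stock_fundamentals",
        "parameters": {
            "type": "object",
            "properties": {"symbol": {"type": "string"}},
            "required": ["symbol"],
        },
    },
}]
call = {"arguments": {"symbol": "TSLA"}, "name": "get_stock_fundamentals"}

ok, error = validate_function_call_schema(call, signatures)
print(ok, error)  # True None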