WilliamGazeley committed on
Commit
262a057
1 Parent(s): 7bbdf0e

Remove unused files

Browse files
.gitignore CHANGED
@@ -7,3 +7,4 @@ __pycache__/
7
  inference_logs/
8
 
9
  logs/*
 
 
7
  inference_logs/
8
 
9
  logs/*
10
+ *.log
src/functioncall.py DELETED
@@ -1,141 +0,0 @@
1
- import argparse
2
- import torch
3
- import json
4
- from config import config
5
- from typing import List, Dict
6
- from logger import logger
7
-
8
- from transformers import AutoTokenizer
9
-
10
- import functions
11
- from prompter import PromptManager
12
- from validator import validate_function_call_schema
13
- from langchain_community.chat_models import ChatOllama
14
- from langchain_community.llms import Ollama
15
- from langchain.prompts import PromptTemplate
16
- from langchain_core.output_parsers import StrOutputParser
17
-
18
- from utils import (
19
- get_chat_template,
20
- validate_and_extract_tool_calls
21
- )
22
-
23
class ModelInference:
    """Run an Ollama-served model through a recursive tool-calling loop.

    The instance owns a ``PromptManager`` (system-prompt construction), an
    ``Ollama`` model handle and a Hugging Face tokenizer that is used only to
    render the chat template into a single prompt string.
    """

    def __init__(self, chat_template: str):
        """Create the prompt manager, model handle and tokenizer.

        Args:
            chat_template: Template name passed to ``get_chat_template`` when
                the tokenizer does not define a chat template of its own.
        """
        self.prompter = PromptManager()

        self.model = Ollama(model=config.ollama_model, temperature=0.0, format='json')

        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        if self.tokenizer.chat_template is None:
            print("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)

        logger.info(f"Model loaded: {self.model}")

    def process_completion_and_validate(self, completion, chat_template):
        """Extract tool calls from a model completion.

        Args:
            completion: Raw text returned by the model.
            chat_template: Unused; kept for interface compatibility.

        Returns:
            ``(tool_calls, completion, error_message)``. ``tool_calls`` is
            ``None`` when no valid tool call could be extracted.

        Raises:
            ValueError: If ``completion`` is empty or ``None``.
        """
        if not completion:
            logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")

        validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
        if not validation:
            return None, completion, error_message

        logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
        return tool_calls, completion, error_message

    def execute_function_call(self, tool_call):
        """Invoke the function named in ``tool_call`` from the ``functions`` module.

        Args:
            tool_call: Mapping with a ``"name"`` key and an optional
                ``"arguments"`` mapping.

        Returns:
            A JSON-like string ``{"name": ..., "content": ...}`` embedding the
            function's response.

        Raises:
            ValueError: If ``functions`` has no attribute with that name.
        """
        function_name = tool_call.get("name")
        function_to_call = (
            getattr(functions, function_name, None)
            if isinstance(function_name, str)
            else None
        )
        if function_to_call is None:
            # Fail with a readable message instead of "'NoneType' object is
            # not callable"; the caller relays this text back to the model.
            raise ValueError(f"Unknown function: {function_name}")
        function_args = tool_call.get("arguments", {})

        logger.info(f"Invoking function call {function_name} ...")
        # Arguments are passed positionally, in the order the model emitted them.
        function_response = function_to_call(*function_args.values())
        return f'{{"name": "{function_name}", "content": {function_response}}}'

    def run_inference(self, prompt: List[Dict[str, str]]):
        """Render ``prompt`` with the chat template and query the model.

        Args:
            prompt: Chat messages as ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The completion text.
        """
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = inputs.replace("<|begin_of_text|>", "")  # Something wrong with the chat template, hotfix
        completion = self.model.invoke(inputs, format='json')
        # Chat models return a message object with ``.content`` while plain
        # LLM wrappers return a ``str``; accept both.
        return getattr(completion, "content", completion)

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        """Answer ``query`` by letting the model call tools until it stops.

        Args:
            query: The user's question.
            chat_template: Forwarded to ``process_completion_and_validate``.
            num_fewshot: Number of few-shot examples for the system prompt.
            max_depth: Maximum number of tool/repair iterations.

        Returns:
            The final assistant message (or the last completion when
            ``max_depth`` is reached).

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            completion = self.run_inference(prompt)

            def recursive_loop(prompt, completion, depth):
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})

                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                logger.info(f"Found tool calls: {tool_calls}")
                if tool_calls:
                    logger.info(f"Assistant Message:\n{assistant_message}")

                    for tool_call in tool_calls:
                        validation, message = validate_function_call_schema(tool_call, tools)
                        if validation:
                            try:
                                function_response = self.execute_function_call(tool_call)
                                tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
                                logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
                            except Exception as e:
                                # Report the failure to the model so it can retry.
                                logger.info(f"Could not execute function: {e}")
                                tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                        else:
                            logger.info(message)
                            tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                    prompt.append({"role": "tool", "content": tool_message})

                    depth += 1
                    if depth >= max_depth:
                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        # One last pass so the model can answer from the tool results.
                        completion = self.run_inference(prompt)
                        return completion

                    completion = self.run_inference(prompt)
                    return recursive_loop(prompt, completion, depth)
                elif error_message:
                    logger.info(f"Assistant Message:\n{assistant_message}")
                    # The block is closed with </tool_response>, matching the other tool messages.
                    tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax\n</tool_response>\n"
                    prompt.append({"role": "tool", "content": tool_message})

                    depth += 1
                    if depth >= max_depth:
                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        return completion

                    completion = self.run_inference(prompt)
                    return recursive_loop(prompt, completion, depth)
                else:
                    logger.info(f"Assistant Message:\n{assistant_message}")
                    return assistant_message

            return recursive_loop(prompt, completion, depth)

        except Exception as e:
            logger.error(f"Exception occurred: {e}")
            raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/prompt_assets/few_shot.json DELETED
@@ -1,8 +0,0 @@
1
- [
2
- {
3
- "example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n</tool_call>\n```\n"
4
- },
5
- {
6
- "example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n</tool_call>\n```"
7
- }
8
- ]
 
 
 
 
 
 
 
 
 
src/prompt_assets/output_sys_prompt.yml DELETED
@@ -1,19 +0,0 @@
1
- Role: |
2
- You are a financial expert named IRAI with experience and expertise in stocks and cryptocurrency.
3
- You have a comprehensive understanding of finance, investing, and quantitative analysis; you are a thought leader in these fields.
4
- You will be fielding questions about finance from top executives and managers of successful startups.
5
- Your audience is high net worth individuals with high risk tolerance and lots of investing experience - they do not need traditional, generic financial advice.
6
- Objective: |
7
- Answer questions as accurately as possible given your current knowledge and all available information.
8
- Your answers should demonstrate expert insight and in-depth analysis.
9
- Use analysis, information, and data from the assistant message to form your answers.
10
- Instructions: |
11
- Answer in a professional and engaging manner representing a top female investment professional working at a leading investment bank.
12
- Try to incorporate the numerical data from the function calls to support your analysis, do not state any other numerical data in your answer.
13
- Use all the available information to answer the question, including the data and information from the function calls.
14
- Do not mention any function calls, such as "get_analysis", "get_current_stock_price", or "get_key_financial_ratios" in your answer.
15
- Give a direct answer to question, concise yet insightful.
16
- Do not give additional instructions related to seeking financial advisor or professional in your answer.
17
- Do not ask any followup questions in your answer.
18
- Do not ask for any additional information in your answer.
19
- Do not mention Morgan Stanley in your answer.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/prompt_assets/sys_prompt.yml DELETED
@@ -1,39 +0,0 @@
1
- Role: |
2
- You are a function calling AI agent with self-recursion.
3
- You will use your comprehensive understanding of finance and investing for function calls.
4
- You can call only one function at a time and analyse data you get from function response.
5
- You are provided with function signatures within <tools></tools> XML tags.
6
- The current date is: {date}.
7
- Objective: |
8
- You may use agentic frameworks for reasoning and planning to help with user questions.
9
- Please call a function and wait for function results to be provided to you in the next iteration.
10
- Don't make assumptions about what values to plug into function arguments.
11
- Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
12
- Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
13
- Analyze the data once you get the results and call another function.
14
- At each iteration please continue adding your analysis to previous summary.
15
- Your final response should directly answer the user questions using the results of function calls.
16
- Tools: |
17
- Here are the available tools:
18
- <tools> {tools} </tools>
19
- If the provided function signatures don't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
20
- <tool_call>
21
- {{"arguments": {{"code_markdown": <python-code>}}, "name": "code_interpreter"}}
22
- </tool_call>
23
- Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
24
- Examples: |
25
- Here are some example usage of functions:
26
- {examples}
27
- Schema: |
28
- Use the following pydantic model json schema for each tool call you will make:
29
- {schema}
30
- Instructions: |
31
- At the very first turn you don't have <tool_results>, so you should not make up the results.
32
- Please keep a running summary with analysis of previous function results and summaries from previous iterations.
33
- Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
34
- Calling multiple functions at once can overload the system and increase cost so call one function at a time please.
35
- If you plan to continue with analysis, always call another function.
36
- For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
37
- <tool_call>
38
- {{"arguments": <args-dict>, "name": <function-name>}}
39
- </tool_call>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/prompter.py DELETED
@@ -1,76 +0,0 @@
1
- import datetime
2
- from pydantic import BaseModel
3
- from typing import Dict
4
- from schema import FunctionCall
5
- from utils import (
6
- get_fewshot_examples
7
- )
8
- import yaml
9
- import json
10
- import os
11
-
12
class PromptSchema(BaseModel):
    """Sections of the system prompt, one string per top-level YAML key.

    ``PromptManager.read_yaml_file`` fills each field from the key of the same
    name, and ``PromptManager.format_yaml_prompt`` concatenates the fields in
    the order they are declared here.
    """

    # Each value is a ``str.format`` template; placeholders such as {date},
    # {tools}, {examples} and {schema} are substituted when the prompt is built.
    Role: str
    Objective: str
    Tools: str
    Examples: str
    Schema: str
    Instructions: str
19
-
20
class PromptManager:
    """Build the chat prompt whose system message comes from ``sys_prompt.yml``."""

    def __init__(self):
        # Prompt assets are resolved relative to the directory of this module.
        self.script_dir = os.path.dirname(os.path.abspath(__file__))

    def format_yaml_prompt(self, prompt_schema: "PromptSchema", variables: Dict) -> str:
        """Substitute ``variables`` into every section and concatenate them.

        Args:
            prompt_schema: Prompt sections; ``.dict()`` must yield
                ``{section_name: template_string}``.
            variables: Values for the ``str.format`` placeholders.

        Returns:
            The concatenated prompt. Newlines are flattened to spaces in every
            section except ``Instructions``; the ``Examples`` section is
            skipped when ``variables["examples"]`` is ``None``.
        """
        sections = []
        for field, value in prompt_schema.dict().items():
            if field == "Examples" and variables.get("examples") is None:
                continue
            formatted_value = value.format(**variables)
            if field != "Instructions":
                # Only the instructions keep their line structure.
                formatted_value = formatted_value.replace("\n", " ")
            sections.append(formatted_value)
        return "".join(sections)

    def read_yaml_file(self, file_path: str) -> "PromptSchema":
        """Load a prompt YAML file into a ``PromptSchema``.

        Missing keys default to the empty string. The file is read as UTF-8
        so the result does not depend on the platform's default encoding.
        """
        with open(file_path, 'r', encoding="utf-8") as file:
            # ``safe_load`` returns None for an empty document.
            yaml_content = yaml.safe_load(file) or {}

        return PromptSchema(
            Role=yaml_content.get('Role', ''),
            Objective=yaml_content.get('Objective', ''),
            Tools=yaml_content.get('Tools', ''),
            Examples=yaml_content.get('Examples', ''),
            Schema=yaml_content.get('Schema', ''),
            Instructions=yaml_content.get('Instructions', ''),
        )

    def generate_prompt(self, user_prompt, tools, num_fewshot=None):
        """Return the full message list: system prompt followed by ``user_prompt``.

        Args:
            user_prompt: Chat messages to append after the system message.
            tools: Tool signatures substituted for ``{tools}``.
            num_fewshot: Number of few-shot examples to include, or ``None``
                to omit the examples section.
        """
        prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt.yml')
        prompt_schema = self.read_yaml_file(prompt_path)

        examples = get_fewshot_examples(num_fewshot) if num_fewshot is not None else None

        schema_json = json.loads(FunctionCall.schema_json())

        variables = {
            "date": datetime.date.today(),
            "tools": tools,
            "examples": examples,
            "schema": schema_json,
        }
        sys_prompt = self.format_yaml_prompt(prompt_schema, variables)

        prompt = [
            {'content': sys_prompt, 'role': 'system'}
        ]
        prompt.extend(user_prompt)
        return prompt
75
-
76
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/schema.py DELETED
@@ -1,23 +0,0 @@
1
- from pydantic import BaseModel
2
- from typing import List, Dict, Literal, Optional
3
-
4
class FunctionCall(BaseModel):
    """A single tool invocation emitted by the model: a name plus arguments."""

    arguments: dict
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""
15
-
16
class FunctionDefinition(BaseModel):
    """Description of a callable tool: its name, description and parameters."""

    name: str
    description: Optional[str] = None
    # May be None; when present, callers read its "properties" and "required"
    # keys (see the validator), so it presumably holds a JSON-Schema object.
    parameters: Optional[Dict[str, object]] = None
20
-
21
class FunctionSignature(BaseModel):
    """Wrapper pairing a ``FunctionDefinition`` with the literal type tag "function"."""

    function: FunctionDefinition
    # Only the exact string "function" is accepted.
    type: Literal["function"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils.py DELETED
@@ -1,122 +0,0 @@
1
- import ast
2
- import os
3
- import re
4
- import json
5
- import logging
6
- import datetime
7
- import xml.etree.ElementTree as ET
8
- from logger import logger
9
-
10
- from logging.handlers import RotatingFileHandler
11
-
12
# Module-level logging setup; everything below runs at import time.
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
# Directory containing this module; asset and log paths are built from it.
script_dir = os.path.dirname(os.path.abspath(__file__))
now = datetime.datetime.now()
log_folder = os.path.join(script_dir, "inference_logs")
# Creating the folder is an import-time side effect.
os.makedirs(log_folder, exist_ok=True)
# One log file per import, named with the import timestamp.
log_file_path = os.path.join(
    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
)
# Use RotatingFileHandler from the logging.handlers module
# NOTE(review): with maxBytes=0 the handler never rolls over, so it behaves
# like a plain FileHandler.
file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
file_handler.setFormatter(formatter)
# NOTE(review): file_handler is not attached to any logger in the code visible
# here — confirm whether another module calls addHandler with it.
30
-
31
def get_fewshot_examples(num_fewshot):
    """Return the first ``num_fewshot`` entries of ``prompt_assets/few_shot.json``.

    Raises:
        ValueError: If more examples are requested than the file contains.
    """
    fewshot_file = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
    with open(fewshot_file, 'r') as handle:
        available = json.load(handle)

    total = len(available)
    if total < num_fewshot:
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {total} examples).")
    return available[:num_fewshot]
39
-
40
def get_chat_template(chat_template):
    """Read ``chat_templates/<chat_template>.j2`` next to this module.

    Args:
        chat_template: Template name without the ``.j2`` extension.

    Returns:
        The template text, or ``None`` when the file is missing or cannot be
        read (the failure is logged or printed rather than raised).
    """
    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")

    if not os.path.exists(template_path):
        logger.error(f"Template file not found: {chat_template}")
        return None
    try:
        # Explicit UTF-8 so the template reads the same on every platform.
        with open(template_path, 'r', encoding="utf-8") as file:
            return file.read()
    except Exception as e:
        # Best effort: callers treat None as "no template available".
        print(f"Error loading template: {e}")
        return None
55
-
56
def validate_and_extract_tool_calls(assistant_content):
    """Pull the payload of every ``<tool_call>`` element out of a completion.

    Each payload is parsed with ``json.loads`` first and ``ast.literal_eval``
    as a fallback.

    Returns:
        ``(validation_result, tool_calls, error_message)``.
        ``validation_result`` is True when at least one payload parsed;
        ``error_message`` holds the most recent failure, if any.
    """
    found_valid = False
    parsed_calls = []
    last_error = None

    try:
        # A synthetic root lets text with several sibling tags parse as one document.
        document = ET.fromstring(f"<root>{assistant_content}</root>")
    except ET.ParseError as err:
        last_error = f"XML Parse Error: {err}"
        logger.error(f"XML Parse Error: {err}")
        return found_valid, parsed_calls, last_error

    for node in document.findall(".//tool_call"):
        payload = None
        try:
            raw = node.text.strip()

            try:
                # Strict JSON is preferred for its clearer error reporting.
                payload = json.loads(raw)
            except json.JSONDecodeError as json_err:
                try:
                    # Accept Python-literal syntax such as single-quoted strings.
                    payload = ast.literal_eval(raw)
                except (SyntaxError, ValueError) as eval_err:
                    last_error = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\
                                 f"- JSON Decode Error: {json_err}\n"\
                                 f"- Fallback Syntax/Value Error: {eval_err}\n"\
                                 f"- Problematic JSON text: {raw}"
                    logger.error(last_error)
                    continue
        except Exception as e:
            last_error = f"Cannot strip text: {e}"
            logger.error(last_error)

        if payload is not None:
            parsed_calls.append(payload)
            found_valid = True

    # Defaults are returned unchanged when nothing valid was extracted.
    return found_valid, parsed_calls, last_error
100
-
101
def extract_json_from_markdown(text):
    """
    Parse the first ```json fenced block found in ``text``.

    Args:
        text (str): Text that may contain a fenced JSON block.

    Returns:
        The decoded JSON value, or None when no fenced block is present or
        its content is not valid JSON (a message is printed in both cases).
    """
    fenced = re.search(r'```json\r?\n(.*?)\r?\n```', text, re.DOTALL)
    if fenced is None:
        print("JSON string not found in the text.")
        return None

    try:
        return json.loads(fenced.group(1))
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON string: {e}")
    return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/validator.py DELETED
@@ -1,149 +0,0 @@
1
- import ast
2
- import json
3
- from jsonschema import validate
4
- from pydantic import ValidationError
5
- from logger import logger
6
- from utils import extract_json_from_markdown
7
- from schema import FunctionCall, FunctionSignature
8
-
9
-
10
def validate_function_call_schema(call, signatures):
    """Check a parsed tool call against the available function signatures.

    Args:
        call: Parsed tool-call payload, expected to be a mapping with
            ``name`` and ``arguments``.
        signatures: Iterable of signature dicts accepted by ``FunctionSignature``.

    Returns:
        ``(True, None)`` when the call matches a signature, otherwise
        ``(False, message)`` describing the first problem found.
    """
    try:
        call_data = FunctionCall(**call)
    except (ValidationError, TypeError) as e:
        # TypeError covers payloads that are not mappings (e.g. a JSON list),
        # which would otherwise escape as an unhandled exception.
        return False, str(e)

    for signature in signatures:
        try:
            signature_data = FunctionSignature(**signature)
            if signature_data.function.name != call_data.name:
                continue

            # ``parameters`` is optional; treat a missing schema as empty so
            # parameterless functions validate instead of failing on None.
            parameters = signature_data.function.parameters or {}

            # Validate types in function arguments
            for arg_name, arg_schema in parameters.get("properties", {}).items():
                if arg_name in call_data.arguments:
                    call_arg_value = call_data.arguments[arg_name]
                    # NOTE(review): falsy values (0, False, "", []) skip type
                    # validation here — confirm this is intended.
                    if call_arg_value:
                        try:
                            validate_argument_type(
                                arg_name, call_arg_value, arg_schema
                            )
                        except Exception as arg_validation_error:
                            return False, str(arg_validation_error)

            # Check if all required arguments are present
            required_arguments = parameters.get("required", [])
            result, missing_arguments = check_required_arguments(
                call_data.arguments, required_arguments
            )
            if not result:
                return False, f"Missing required arguments: {missing_arguments}"

            return True, None
        except Exception as e:
            # Handle validation errors for the function signature
            return False, str(e)

    # No matching function signature found
    return False, f"No matching function signature found for function: {call_data.name}"
51
-
52
-
53
def check_required_arguments(call_arguments, required_arguments):
    """Report which required argument names are absent from ``call_arguments``.

    Returns:
        ``(all_present, missing)`` where ``missing`` lists the absent names in
        the order they appear in ``required_arguments``.
    """
    absent = []
    for name in required_arguments:
        if name not in call_arguments:
            absent.append(name)
    return len(absent) == 0, absent
56
-
57
-
58
def validate_enum_value(arg_name, arg_value, enum_values):
    """Raise ``Exception`` when ``arg_value`` is not one of ``enum_values``."""
    if arg_value in enum_values:
        return
    allowed = ', '.join(map(str, enum_values))
    raise Exception(
        f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {allowed}"
    )
63
-
64
-
65
def validate_argument_type(arg_name, arg_value, arg_schema):
    """Validate one argument value against its JSON-Schema fragment.

    Args:
        arg_name: Parameter name, used in error messages.
        arg_value: Value supplied by the model.
        arg_schema: Schema dict; ``"type"`` and, for strings, ``"enum"`` are read.

    Raises:
        Exception: If the value is not in a non-empty enum or its Python type
            does not match the declared JSON type.
        KeyError: If ``"type"`` names a JSON type ``get_python_type`` does not know.
    """
    arg_type = arg_schema.get("type", None)
    if not arg_type:
        # Nothing to check without a declared type.
        return

    if arg_type == "string" and "enum" in arg_schema:
        enum_values = arg_schema["enum"]
        if None not in enum_values and enum_values != []:
            try:
                validate_enum_value(arg_name, arg_value, enum_values)
            except Exception as e:
                # Propagate the validation error message, keeping the cause.
                raise Exception(f"Error validating function call: {e}") from e

    python_type = get_python_type(arg_type)
    # bool is a subclass of int in Python, but JSON booleans are not numbers.
    bool_as_number = isinstance(arg_value, bool) and arg_type in ("integer", "number")
    if bool_as_number or not isinstance(arg_value, python_type):
        raise Exception(
            f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}"
        )
82
-
83
-
84
def get_python_type(json_type):
    """Map a JSON-Schema type name to the Python type(s) used with ``isinstance``.

    Raises:
        KeyError: If ``json_type`` is not one of the seven known names.
    """
    json_to_python = dict(
        string=str,
        number=(int, float),
        integer=int,
        boolean=bool,
        array=list,
        object=dict,
        null=type(None),
    )
    return json_to_python[json_type]
95
-
96
-
97
def validate_json_data(json_object, json_schema):
    """Decode ``json_object`` and validate it against ``json_schema``.

    Decoding tries ``json.loads``, then ``ast.literal_eval``, then a fenced
    ```json block via ``extract_json_from_markdown``. A decoded list is
    validated item by item, stopping at the first invalid item.

    Args:
        json_object: Text to decode.
        json_schema: JSON Schema passed to ``jsonschema.validate``.

    Returns:
        ``(valid, result_json, error_message)``; ``error_message`` is ``None``
        exactly when ``valid`` is True.
    """
    # ``jsonschema.validate`` raises jsonschema's ValidationError, not the
    # pydantic class imported at module level, so catch both.
    from jsonschema.exceptions import ValidationError as SchemaValidationError

    schema_errors = (ValidationError, SchemaValidationError)

    valid = False
    error_message = None
    result_json = None

    try:
        # Attempt to load JSON using json.loads
        try:
            result_json = json.loads(json_object)
        except json.decoder.JSONDecodeError:
            # If json.loads fails, try ast.literal_eval
            try:
                result_json = ast.literal_eval(json_object)
            except (SyntaxError, ValueError):
                try:
                    result_json = extract_json_from_markdown(json_object)
                except Exception as e:
                    error_message = f"JSON decoding error: {e}"
                    logger.info(f"Validation failed for JSON data: {error_message}")
                    return valid, result_json, error_message

        # Return early if every decoding strategy failed
        if result_json is None:
            error_message = "Failed to decode JSON data"
            logger.info(f"Validation failed for JSON data: {error_message}")
            return valid, result_json, error_message

        # Validate each item in the list against schema if it's a list
        if isinstance(result_json, list):
            for index, item in enumerate(result_json):
                try:
                    validate(instance=item, schema=json_schema)
                    logger.info(f"Item {index+1} is valid against the schema.")
                except schema_errors as e:
                    error_message = f"Validation failed for item {index+1}: {e}"
                    break
        else:
            # Default to validation without list
            try:
                validate(instance=result_json, schema=json_schema)
            except schema_errors as e:
                error_message = f"Validation failed: {e}"

    except Exception as e:
        # Anything unexpected (e.g. non-string input, invalid schema).
        error_message = f"Error occurred: {e}"

    if error_message is None:
        valid = True
        logger.info("JSON data is valid against the schema.")
    else:
        logger.info(f"Validation failed for JSON data: {error_message}")

    return valid, result_json, error_message