# SuperExpert / chat.py
# JarvisChan630's picture
# fix bugs
# 49bd427
import os
import asyncio
import re
import chainlit as cl
from typing import Dict, Any
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from typing import Union
from chainlit.input_widget import Select
from agents.jar3d_agent import (State,
Jar3d,
MetaExpert,
Router,
NoToolExpert,
ToolExpert,
set_chat_finished,
routing_function,
)
from agents.base_agent import BaseAgent
from utils.read_markdown import read_markdown_file
from config.load_configs import load_config
from dotenv import load_dotenv
# config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
# load_config(config_path)

# Load environment variables from .env file
load_dotenv()

# Name of the LLM backend handed to get_agent_kwargs(); None when LLM_SERVER is unset.
server = os.environ.get("LLM_SERVER")

# RECURSION_LIMIT is mandatory. Fail with an actionable message instead of the
# opaque "int() argument must be ... not 'NoneType'" raised by int(None).
_recursion_limit_raw = os.environ.get("RECURSION_LIMIT")
if _recursion_limit_raw is None:
    raise RuntimeError(
        "The RECURSION_LIMIT environment variable is not set; "
        "define it (for example in the .env file) as an integer."
    )
recursion_limit = int(_recursion_limit_raw)
def get_agent_kwargs(server: str = "claude", location: str = None, hybrid: bool = False) -> tuple:
    """Build the keyword-argument dicts used to construct the agents.

    Args:
        server: LLM backend name. One of "openai", "mistral", "claude",
            "ollama", "groq" or "vllm".
        location: Value stored under the "location" key of the tool kwargs.
            A falsy value (None or "") is replaced by "us".
        hybrid: Value stored under the "hybrid" key of the tool kwargs.

    Returns:
        A 3-tuple ``(agent_kwargs, agent_kwargs_tools, agent_kwargs_meta_expert)``
        of independent dicts. ``agent_kwargs_tools`` is ``agent_kwargs`` plus
        the "location" and "hybrid" keys; ``agent_kwargs_meta_expert`` is a
        plain copy of ``agent_kwargs``.

    Raises:
        ValueError: If ``server`` is not one of the supported backends
            (previously this surfaced as an ``UnboundLocalError``).
    """
    location = location or "us"

    backends = {
        "openai": {
            "model": "gpt-4o-mini",
            "server": "openai",
            "temperature": 0,
        },
        "mistral": {
            "model": "mistral-large-latest",
            "server": "mistral",
            "temperature": 0,
        },
        "claude": {
            "model": "claude-3-haiku-20240307",
            "server": "claude",
            "temperature": 0,
        },
        "ollama": {
            # The model name is read from the environment on every call.
            "model": os.environ.get("OLLAMA_MODEL"),
            "server": "ollama",
            "temperature": 0.1,
        },
        "groq": {
            "model": "llama3-groq-70b-8192-tool-use-preview",
            "server": "groq",
            "temperature": 0,
        },
        # you must change the model and model_endpoint to the correct values
        "vllm": {
            "model": "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
            "server": "vllm",
            "temperature": 0.2,
            "model_endpoint": "https://s1s4l1lhce486j-8000.proxy.runpod.net/",
        },
    }

    selected = backends.get(server)
    if selected is None:
        raise ValueError(
            f"Unsupported LLM server {server!r}; expected one of {sorted(backends)}"
        )

    agent_kwargs = dict(selected)
    agent_kwargs_meta_expert = agent_kwargs.copy()

    agent_kwargs_tools = agent_kwargs.copy()
    agent_kwargs_tools["location"] = location
    agent_kwargs_tools["hybrid"] = hybrid

    return agent_kwargs, agent_kwargs_tools, agent_kwargs_meta_expert
class Jar3dIntro(BaseAgent[State]):
    """Agent that produces the opening message shown at the start of a chat.

    It invokes the model with the fixed user input "/start" and the
    requirements-gathering prompt, then returns the assistant's reply with any
    leading ```python fenced block removed.
    """

    def __init__(self, model: str = None, server: str = None, temperature: float = 0,
                 model_endpoint: str = None, stop: str = None):
        """Forward the model settings to the base agent and create the LLM handle."""
        super().__init__(model, server, temperature, model_endpoint, stop)
        # The LLM is requested with json_model=False.
        self.llm = self.get_llm(json_model=False)

    def get_prompt(self, state) -> str:
        """Return the system prompt read from the requirements prompt markdown file."""
        return read_markdown_file('prompt_engineering/jar3d_requirements_prompt.md')

    def process_response(self, response: Any, user_input: str = None, state: State = None) -> Dict[str, Union[str, dict]]:
        """Wrap the model reply as a two-turn "requirements_gathering" history.

        The ``user_input`` argument is ignored: the user turn is always "/start".
        """
        user_input = "/start"
        user_turn = {"role": "user", "content": f"{user_input}"}
        assistant_turn = {"role": "assistant", "content": str(response)}
        return {"requirements_gathering": [user_turn, assistant_turn]}

    def get_conv_history(self, state: State) -> str:
        """No-op; returns None."""
        return None

    def get_user_input(self) -> str:
        """No-op; returns None."""
        return None

    def get_guided_json(self, state: State) -> Dict[str, Any]:
        """No-op; returns None."""
        return None

    def use_tool(self) -> Any:
        """No-op; returns None."""
        return None

    def run(self, state: State) -> State:
        """Invoke the agent with "/start" and return the cleaned intro text.

        Note: despite the return annotation, the value returned is the intro
        string, not a state object.
        """
        state = self.invoke(state=state, user_input="/start")
        latest_reply = state["requirements_gathering"][-1]["content"]
        # Remove fenced ```python blocks that begin at the start of a line.
        without_code = re.sub(r'^```python[\s\S]*?```\s*', '', latest_reply, flags=re.MULTILINE)
        return without_code.lstrip()
@cl.on_settings_update
async def update_settings(settings):
    """Apply changed chat settings to the current user session.

    Reads the "location" and "retrieval_mode" settings, stores the derived
    country code ("gl") and hybrid flag in the user session, regenerates the
    three agent kwargs dicts, rebuilds the workflow, and confirms the change to
    the user. The stored "jar3d_agent" is not replaced here.

    Args:
        settings: Mapping of setting id to selected value; must contain the
            keys "location" and "retrieval_mode".
    """
    country_codes = {
        "The United States": "us",
        "The United Kingdom": "gb",
        "The Netherlands": "nl",
        "Canada": "ca",
    }
    location = settings["location"]
    gl = country_codes.get(location, 'us')
    cl.user_session.set("gl", gl)
    if location not in country_codes:
        # Searches fall back to 'us' for an unknown or unset location, so the
        # confirmation below must name that country instead of e.g. "None".
        location = "The United States"

    retrieval_mode = settings["retrieval_mode"]
    hybrid = retrieval_mode == "Hybrid (Graph + Dense)"
    cl.user_session.set("hybrid", hybrid)

    agent_kwargs, agent_kwargs_tools, agent_kwargs_meta_expert = get_agent_kwargs(server, gl, hybrid)
    cl.user_session.set("agent_kwargs", agent_kwargs)
    cl.user_session.set("agent_kwargs_tools", agent_kwargs_tools)
    cl.user_session.set("agent_kwargs_meta_expert", agent_kwargs_meta_expert)

    # build_workflow() reads the kwargs just stored in the session.
    workflow = build_workflow()
    cl.user_session.set("workflow", workflow)

    await cl.Message(
        content=f"I'll be conducting any Internet searches from {location} using {retrieval_mode}",
        author="Jar3d👩‍💻",
    ).send()
@cl.on_chat_start
async def start():
    """Initialise a new chat session.

    Steps, in order:
      1. Truncate the previous-run response file under /app/agent_memory.
      2. Create the task list and send it to the UI.
      3. Store an empty workflow state in the user session.
      4. Send the chat settings widgets (location and retrieval mode).
      5. Build the agent kwargs and the workflow and store them in the session.
      6. Create the intro and Jar3d agents in a worker thread, then send the
         intro message.
    """
    agent_memory_dir = '/app/agent_memory'
    file_path = os.path.join(agent_memory_dir, 'jar3d_final_response_previous_run.txt')

    # Ensure the directory exists
    os.makedirs(agent_memory_dir, exist_ok=True)

    # Clear the file content; the context manager closes the handle even on error.
    with open(file_path, 'w'):
        pass

    task_list = cl.TaskList()
    task_list.status = "Ready"
    cl.user_session.set("task_list", task_list)

    # Send the TaskList to the UI
    await task_list.send()

    state: State = {
        "meta_prompt": [],
        "conversation_history": [],
        "requirements_gathering": [],
        "expert_plan": [],
        "expert_research": [],
        "expert_research_shopping": [],
        "expert_writing": [],
        "user_input": [],
        "previous_search_queries": [],
        "router_decision": None,
        "chat_limit": None,
        "chat_finished": False,
        "recursion_limit": None,
        "final_answer": None,
        "previous_type2_work": [],
        "progress_tracking": None
    }

    cl.user_session.set("state", state)

    await cl.ChatSettings(
        [
            Select(
                id="location",
                label="Select your location:",
                values=[
                    "The United States",
                    "The United Kingdom",
                    "The Netherlands",
                    "Canada",
                ]
            ),
            Select(
                id="retrieval_mode",
                label="Select retrieval mode:",
                values=[
                    "Hybrid (Graph + Dense)",
                    "Dense Only",
                ],
                initial_index=1,
                description="The retrieval mode determines how Jar3d and searches and indexes information from the internet. Hybrid mode performs a deeper search but will cost more."
            )
        ]
    ).send()

    # A missing session key yields None rather than an exception, so the
    # defaults are applied explicitly here.
    gl = cl.user_session.get("gl") or "us"
    hybrid = bool(cl.user_session.get("hybrid"))

    agent_kwargs, agent_kwargs_tools, agent_kwargs_meta_expert = get_agent_kwargs(server, gl, hybrid)
    cl.user_session.set("agent_kwargs", agent_kwargs)
    cl.user_session.set("agent_kwargs_tools", agent_kwargs_tools)
    cl.user_session.set("agent_kwargs_meta_expert", agent_kwargs_meta_expert)

    # build_workflow() reads the kwargs just stored in the session.
    workflow = build_workflow()
    cl.user_session.set("workflow", workflow)

    def initialise_jar3d():
        """Create the intro text and the Jar3d agent (blocking; run off the event loop)."""
        jar3d_intro = Jar3dIntro(**agent_kwargs)
        jar3d_intro_hi = jar3d_intro.run(state)
        jar3d_agent = Jar3d(**agent_kwargs)
        return jar3d_intro_hi, jar3d_agent

    loop = asyncio.get_running_loop()
    jar3d_intro_hi, jar3d_agent = await loop.run_in_executor(None, initialise_jar3d)
    cl.user_session.set("jar3d_agent", jar3d_agent)

    # Send an initial message to start the conversation
    await cl.Message(
        content=f"{jar3d_intro_hi}.\n\n I'll be conducting any Internet searches from The United States with Dense Retrieval.",
        author="Jar3d👩‍💻",
    ).send()
def build_workflow():
    """Assemble and compile the agent graph from kwargs stored in the user session.

    Reads "agent_kwargs", "agent_kwargs_tools" and "agent_kwargs_meta_expert"
    from the session, instantiates the four agents, wires them into a
    StateGraph (meta_expert -> router -> conditional routing, with both expert
    nodes feeding back into meta_expert and "end_chat" as the finish point) and
    compiles it with a fresh in-memory checkpointer.

    Returns:
        The compiled workflow.
    """
    session = cl.user_session
    base_kwargs = session.get("agent_kwargs")
    tool_kwargs = session.get("agent_kwargs_tools")
    meta_kwargs = session.get("agent_kwargs_meta_expert")

    # Initialize agent instances, keyed by the node name they back.
    agents = {
        "meta_expert": MetaExpert(**meta_kwargs),
        "router": Router(**base_kwargs),
        "no_tool_expert": NoToolExpert(**base_kwargs),
        "tool_expert": ToolExpert(**tool_kwargs),
    }

    def as_node(agent):
        """Wrap an agent so the graph calls ``agent.run(state=...)``."""
        return lambda state: agent.run(state=state)

    graph = StateGraph(State)
    for node_name, agent in agents.items():
        graph.add_node(node_name, as_node(agent))
    graph.add_node("end_chat", lambda state: set_chat_finished(state))

    graph.set_entry_point("meta_expert")
    graph.set_finish_point("end_chat")

    for source, target in (
        ("meta_expert", "router"),
        ("tool_expert", "meta_expert"),
        ("no_tool_expert", "meta_expert"),
    ):
        graph.add_edge(source, target)

    graph.add_conditional_edges(
        "router",
        lambda state: routing_function(state),
    )

    return graph.compile(MemorySaver())
def _run_workflow_sync(workflow, state, configs, progress_queue):
seen_progress_messages = set()
try:
for event in workflow.stream(state, configs):
# Access the node's output directly
node_output = next(iter(event.values()))
# Access 'progress_tracking' from the node's output
progress_message = node_output.get("progress_tracking", "")
if progress_message:
if progress_message not in seen_progress_messages:
print(f"Extracted progress_message: {progress_message}")
progress_queue.put_nowait(progress_message)
seen_progress_messages.add(progress_message)
else:
print(f"Duplicate progress_message ignored: {progress_message}")
progress_queue.put_nowait(None) # Signal that the workflow is complete
except Exception as e:
print(f"Exception in workflow execution: {e}")
progress_queue.put_nowait(None)
async def run_workflow(workflow, state):
state["recursion_limit"] = recursion_limit
state["user_input"] = "/start"
configs = {"recursion_limit": recursion_limit + 10, "configurable": {"thread_id": 42}}
progress_queue = asyncio.Queue()
task_list = cl.user_session.get("task_list")
# Set the TaskList status and send it to the UI
task_list.status = "Running..."
await task_list.send()
loop = asyncio.get_running_loop()
# Run the synchronous _run_workflow_sync in a background thread
loop.run_in_executor(
None, _run_workflow_sync, workflow, state, configs, progress_queue
)
# Process progress messages and update the TaskList
while True:
progress_message = await progress_queue.get()
if progress_message is None:
# Workflow is complete
break
# Create a new task with status RUNNING
task = cl.Task(title=progress_message, status=cl.TaskStatus.RUNNING)
await task_list.add_task(task)
await task_list.send()
# Simulate task completion (you can adjust this based on actual progress)
task.status = cl.TaskStatus.DONE
await task_list.send()
# Update TaskList status to Done and send the final update
task_list.status = "Done"
await task_list.send()
# Retrieve the final state
final_state = workflow.get_state(configs)
final_state = final_state.values
final_answer = final_state.get(
"final_answer",
"The agent failed to deliver a final response. Please check the logs for more information."
)
return final_answer
###
# Runs every time the user sends a message
@cl.on_message
async def main(message: cl.Message):
    """Handle one incoming user message.

    The message is passed to the session's Jar3d agent in a worker thread and
    the agent's reply is sent back to the UI. For any message other than
    "/end" the updated state is stored in the session. When the message is
    exactly "/end", the full workflow is run and its final answer is sent.

    Args:
        message: The message received from the user.
    """
    state: State = cl.user_session.get("state")
    agent: Jar3d = cl.user_session.get("jar3d_agent")
    workflow = cl.user_session.get("workflow")

    # run_chainlit is called in the default executor so it does not block the event loop.
    loop = asyncio.get_running_loop()
    state, response = await loop.run_in_executor(None, agent.run_chainlit, state, message)

    # Send the agent's reply back to the user
    await cl.Message(content=response, author="Jar3d👩‍💻").send()

    if message.content != "/end":
        cl.user_session.set("state", state)
        return

    # The user asked to end the requirements chat: run the full workflow.
    await cl.Message(
        content="This will take some time, probably a good time for a coffee break ☕...",
        author="System"
    ).send()

    final_answer = await run_workflow(workflow, state)
    if final_answer:
        await cl.Message(content=final_answer, author="Jar3d👩‍💻").send()
    else:
        await cl.Message(content="No final answer was produced.", author="Jar3d👩‍💻").send()
# if __name__ == "__main__":
# cl.run()