memory is now in place
Browse files- agent/_create.py +12 -5
- agent/memory.py +41 -3
- agent/prompt.py +8 -6
agent/_create.py
CHANGED
@@ -39,16 +39,23 @@ def agent(payload):
|
|
39 |
workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END} )
|
40 |
workflow.add_edge('action', 'agent')
|
41 |
|
42 |
-
from agent.memory import
|
43 |
-
|
|
|
|
|
44 |
|
45 |
|
46 |
from agent.prompt import prompt
|
47 |
-
prompt=prompt[
|
48 |
-
|
|
|
49 |
|
50 |
-
|
|
|
|
|
51 |
print(response[-1].content)
|
52 |
|
|
|
|
|
53 |
return response[-1].content
|
54 |
|
|
|
39 |
workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END} )
|
40 |
workflow.add_edge('action', 'agent')
|
41 |
|
42 |
+
from agent.memory import Memory
|
43 |
+
memory = Memory(payload)
|
44 |
+
|
45 |
+
app = workflow.compile(checkpointer=memory.checkpoints)
|
46 |
|
47 |
|
48 |
from agent.prompt import prompt
|
49 |
+
prompt=prompt[memory.isNew]
|
50 |
+
|
51 |
+
input = payload.get("input") or "hi! I'm bob"
|
52 |
|
53 |
+
prompt = prompt.format(input=input, thread_id=memory.thread_id)
|
54 |
+
|
55 |
+
response = app.invoke(prompt, {"configurable": {"thread_id": memory.thread_id}})
|
56 |
print(response[-1].content)
|
57 |
|
58 |
+
print(payload)
|
59 |
+
|
60 |
return response[-1].content
|
61 |
|
agent/memory.py
CHANGED
@@ -1,4 +1,42 @@
|
|
1 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
2 |
-
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
2 |
+
|
3 |
+
class extended_sqliteSaver(SqliteSaver):
|
4 |
+
def __init__(self, conn_string):
|
5 |
+
super().__init__(conn_string)
|
6 |
+
self.create_table("chat_history", {"thread_id": "TEXT", "message": "TEXT", "response": "TEXT"})
|
7 |
+
|
8 |
+
@classmethod
|
9 |
+
def from_conn_string(cls, conn_string: str) -> "SqliteSaver":
|
10 |
+
import sqlite3
|
11 |
+
return SqliteSaver(conn=sqlite3.connect(conn_string, check_same_thread=False))
|
12 |
+
|
13 |
+
#memory = extended_sqliteSaver.from_conn_string("_chat_history.sqlite")
|
14 |
+
|
15 |
+
class Memory:
|
16 |
+
|
17 |
+
def __init__(self, payload):
|
18 |
+
self.checkpoints = extended_sqliteSaver.from_conn_string("_chat_history.sqlite")
|
19 |
+
|
20 |
+
if(payload.get("thread_id")):
|
21 |
+
self.thread_id = payload.get("thread_id")
|
22 |
+
self.isNew = 0
|
23 |
+
else:
|
24 |
+
self.thread_id = "1"
|
25 |
+
self.isNew = 1
|
26 |
+
|
27 |
+
try:
|
28 |
+
|
29 |
+
cursor = self.checkpoints.conn.cursor()
|
30 |
+
|
31 |
+
## IF THERE ARE MORE THAN N THREADS, DELETE THE OLDEST ONES TO MAKE N THREADS
|
32 |
+
n=20
|
33 |
+
cursor.execute(f"DELETE FROM checkpoints WHERE thread_id NOT IN (SELECT thread_id FROM checkpoints ORDER BY thread_id DESC LIMIT {n})")
|
34 |
+
|
35 |
+
## GET THE HIGHEST THREAD VALUE AND ADD ONE TO CREATE THE NEW THREAD_ID
|
36 |
+
self.thread_id=cursor.execute("SELECT thread_id FROM checkpoints ORDER BY thread_id DESC LIMIT 1").fetchone()[0] or 0
|
37 |
+
self.thread_id=int(self.thread_id)+1
|
38 |
+
except:
|
39 |
+
self.thread_id=1
|
40 |
+
|
41 |
+
print("New thread created with id #", self.thread_id)
|
42 |
+
|
agent/prompt.py
CHANGED
@@ -2,7 +2,12 @@ from langchain_core.prompts import ChatPromptTemplate,SystemMessagePromptTemplat
|
|
2 |
from agent.datastructures import parser
|
3 |
|
4 |
prompt = {
|
5 |
-
0: # IF
|
|
|
|
|
|
|
|
|
|
|
6 |
ChatPromptTemplate.from_messages([
|
7 |
SystemMessagePromptTemplate.from_template("""
|
8 |
|
@@ -12,13 +17,10 @@ prompt = {
|
|
12 |
|
13 |
{response_format}
|
14 |
|
15 |
-
The thread_id of this conversation is 2.
|
16 |
-
|
17 |
"""
|
18 |
|
19 |
).format(response_format=parser.get_format_instructions()),
|
|
|
20 |
("human", "{input}")
|
21 |
-
])
|
22 |
-
1: # IF THREAD IS CONTINUING, WE CAN RELY ON THE ORIGINAL PROMPT
|
23 |
-
ChatPromptTemplate.from_messages([("human", "{input}")])
|
24 |
}
|
|
|
2 |
from agent.datastructures import parser
|
3 |
|
4 |
prompt = {
|
5 |
+
0: # IF THREAD IS CONTINUING, WE CAN RELY ON THE ORIGINAL PROMPT
|
6 |
+
ChatPromptTemplate.from_messages([
|
7 |
+
("system", "The thread_id of this conversation is {thread_id}."),
|
8 |
+
("human", "{input}")
|
9 |
+
]),
|
10 |
+
1: # IF THE THREAD IS NEW, THE CHATBOT NEEDS TO BE PUMP-PROMPTED
|
11 |
ChatPromptTemplate.from_messages([
|
12 |
SystemMessagePromptTemplate.from_template("""
|
13 |
|
|
|
17 |
|
18 |
{response_format}
|
19 |
|
|
|
|
|
20 |
"""
|
21 |
|
22 |
).format(response_format=parser.get_format_instructions()),
|
23 |
+
("system", "The thread_id of this conversation is {thread_id}."),
|
24 |
("human", "{input}")
|
25 |
+
])
|
|
|
|
|
26 |
}
|