markpeace commited on
Commit
dbf2f6d
1 Parent(s): 796ceef

memory is now in place

Browse files
Files changed (3) hide show
  1. agent/_create.py +12 -5
  2. agent/memory.py +41 -3
  3. agent/prompt.py +8 -6
agent/_create.py CHANGED
@@ -39,16 +39,23 @@ def agent(payload):
39
  workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END} )
40
  workflow.add_edge('action', 'agent')
41
 
42
- from agent.memory import memory,ThreadStatus,threadID
43
- app = workflow.compile(checkpointer=memory)
 
 
44
 
45
 
46
  from agent.prompt import prompt
47
- prompt=prompt[ThreadStatus]
48
- prompt = prompt.format(input="hi! I'm bob")
 
49
 
50
- response = app.invoke(prompt, {"configurable": {"thread_id": threadID}})
 
 
51
  print(response[-1].content)
52
 
 
 
53
  return response[-1].content
54
 
 
39
  workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END} )
40
  workflow.add_edge('action', 'agent')
41
 
42
+ from agent.memory import Memory
43
+ memory = Memory(payload)
44
+
45
+ app = workflow.compile(checkpointer=memory.checkpoints)
46
 
47
 
48
  from agent.prompt import prompt
49
+ prompt=prompt[memory.isNew]
50
+
51
+ input = payload.get("input") or "hi! I'm bob"
52
 
53
+ prompt = prompt.format(input=input, thread_id=memory.thread_id)
54
+
55
+ response = app.invoke(prompt, {"configurable": {"thread_id": memory.thread_id}})
56
  print(response[-1].content)
57
 
58
+ print(payload)
59
+
60
  return response[-1].content
61
 
agent/memory.py CHANGED
@@ -1,4 +1,42 @@
1
  from langgraph.checkpoint.sqlite import SqliteSaver
2
- memory = SqliteSaver.from_conn_string(":memory:")
3
- ThreadStatus=0
4
- threadID=2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langgraph.checkpoint.sqlite import SqliteSaver
2
+
3
+ class extended_sqliteSaver(SqliteSaver):
4
+ def __init__(self, conn_string):
5
+ super().__init__(conn_string)
6
+ self.create_table("chat_history", {"thread_id": "TEXT", "message": "TEXT", "response": "TEXT"})
7
+
8
+ @classmethod
9
+ def from_conn_string(cls, conn_string: str) -> "SqliteSaver":
10
+ import sqlite3
11
+ return SqliteSaver(conn=sqlite3.connect(conn_string, check_same_thread=False))
12
+
13
+ #memory = extended_sqliteSaver.from_conn_string("_chat_history.sqlite")
14
+
15
+ class Memory:
16
+
17
+ def __init__(self, payload):
18
+ self.checkpoints = extended_sqliteSaver.from_conn_string("_chat_history.sqlite")
19
+
20
+ if(payload.get("thread_id")):
21
+ self.thread_id = payload.get("thread_id")
22
+ self.isNew = 0
23
+ else:
24
+ self.thread_id = "1"
25
+ self.isNew = 1
26
+
27
+ try:
28
+
29
+ cursor = self.checkpoints.conn.cursor()
30
+
31
+ ## IF THERE ARE MORE THAN N THREADS, DELETE THE OLDEST ONES TO MAKE N THREADS
32
+ n=20
33
+ cursor.execute(f"DELETE FROM checkpoints WHERE thread_id NOT IN (SELECT thread_id FROM checkpoints ORDER BY thread_id DESC LIMIT {n})")
34
+
35
+ ## GET THE HIGHEST THREAD VALUE AND ADD ONE TO CREATE THE NEW THREAD_ID
36
+ self.thread_id=cursor.execute("SELECT thread_id FROM checkpoints ORDER BY thread_id DESC LIMIT 1").fetchone()[0] or 0
37
+ self.thread_id=int(self.thread_id)+1
38
+ except:
39
+ self.thread_id=1
40
+
41
+ print("New thread created with id #", self.thread_id)
42
+
agent/prompt.py CHANGED
@@ -2,7 +2,12 @@ from langchain_core.prompts import ChatPromptTemplate,SystemMessagePromptTemplat
2
  from agent.datastructures import parser
3
 
4
  prompt = {
5
- 0: # IF THE THREAD IS NEW, THE CHATBOT NEEDS TO BE PUMP-PROMPTED
 
 
 
 
 
6
  ChatPromptTemplate.from_messages([
7
  SystemMessagePromptTemplate.from_template("""
8
 
@@ -12,13 +17,10 @@ prompt = {
12
 
13
  {response_format}
14
 
15
- The thread_id of this conversation is 2.
16
-
17
  """
18
 
19
  ).format(response_format=parser.get_format_instructions()),
 
20
  ("human", "{input}")
21
- ]),
22
- 1: # IF THREAD IS CONTINUING, WE CAN RELY ON THE ORIGINAL PROMPT
23
- ChatPromptTemplate.from_messages([("human", "{input}")])
24
  }
 
2
  from agent.datastructures import parser
3
 
4
  prompt = {
5
+ 0: # IF THREAD IS CONTINUING, WE CAN RELY ON THE ORIGINAL PROMPT
6
+ ChatPromptTemplate.from_messages([
7
+ ("system", "The thread_id of this conversation is {thread_id}."),
8
+ ("human", "{input}")
9
+ ]),
10
+ 1: # IF THE THREAD IS NEW, THE CHATBOT NEEDS TO BE PUMP-PROMPTED
11
  ChatPromptTemplate.from_messages([
12
  SystemMessagePromptTemplate.from_template("""
13
 
 
17
 
18
  {response_format}
19
 
 
 
20
  """
21
 
22
  ).format(response_format=parser.get_format_instructions()),
23
+ ("system", "The thread_id of this conversation is {thread_id}."),
24
  ("human", "{input}")
25
+ ])
 
 
26
  }