|
import sqlite3

from langgraph.checkpoint.sqlite import SqliteSaver
|
|
|
class extended_sqliteSaver(SqliteSaver):
    """``SqliteSaver`` that additionally provisions a ``chat_history`` table.

    The extra table has the columns ``thread_id``, ``message`` and
    ``response`` (all ``TEXT``). It lives in the same SQLite database as the
    checkpoints.
    """

    def __init__(self, conn_string, **kwargs):
        """Open (or adopt) the database and ensure ``chat_history`` exists.

        Args:
            conn_string: path/URI of the SQLite database, or an already-open
                ``sqlite3.Connection`` to reuse.
            **kwargs: forwarded unchanged to ``SqliteSaver.__init__``
                (accepted keywords depend on the installed langgraph version).
        """
        if isinstance(conn_string, sqlite3.Connection):
            conn = conn_string
        else:
            # check_same_thread=False lets the connection be used from threads
            # other than the creating one, matching from_conn_string below.
            conn = sqlite3.connect(conn_string, check_same_thread=False)
        # SqliteSaver takes a Connection (not a path); pass it by keyword as
        # the original from_conn_string did.
        super().__init__(conn=conn, **kwargs)
        # SqliteSaver has no create_table() helper, so issue the DDL directly.
        # ``with conn`` commits it.
        with conn:
            conn.execute(
                "CREATE TABLE IF NOT EXISTS chat_history "
                "(thread_id TEXT, message TEXT, response TEXT)"
            )

    @classmethod
    def from_conn_string(cls, conn_string: str) -> "SqliteSaver":
        """Return an instance of *this* class on a new connection to ``conn_string``.

        Using ``cls`` (instead of hard-coding ``SqliteSaver``) makes
        ``__init__`` run, so the ``chat_history`` table is created.
        """
        return cls(sqlite3.connect(conn_string, check_same_thread=False))
|
|
|
class Memory:
    """Resolve the conversation thread for a request and prune old threads.

    If ``payload`` carries a truthy ``"thread_id"``, that thread is continued.
    Otherwise a new numeric id is allocated: one past the largest numeric
    ``thread_id`` in the ``checkpoints`` table, or 1 when there is none.

    Attributes:
        checkpoints: the checkpoint saver; its ``conn`` is queried directly.
        thread_id: for an existing thread, the value from ``payload``
            (its type is unchanged); for a new thread, an ``int``.
        isNew: ``1`` when a new thread was allocated, ``0`` otherwise.
    """

    def __init__(self, payload, *, db_path="_chat_history.sqlite",
                 max_threads=20, checkpoints=None):
        """Pick the thread id for ``payload`` and trim the checkpoint store.

        Args:
            payload: mapping; only its ``"thread_id"`` key is read. ``None``
                is treated as an empty mapping.
            db_path: SQLite file to open when ``checkpoints`` is not given.
            max_threads: number of threads (highest numeric ids) whose
                checkpoints are kept; checkpoints of all other threads are
                deleted.
            checkpoints: an already-open saver exposing ``.conn`` (a
                ``sqlite3.Connection``). When ``None``, one is opened from
                ``db_path``.
        """
        if checkpoints is None:
            checkpoints = extended_sqliteSaver.from_conn_string(db_path)
        self.checkpoints = checkpoints

        existing = (payload or {}).get("thread_id")
        self.isNew = 0 if existing else 1

        # NOTE(review): pruning also applies to the thread being resumed, so
        # an old thread can lose its history right as it is continued.
        self._prune_old_threads(max_threads)

        if existing:
            # Continue the caller's thread; do not overwrite its id.
            self.thread_id = existing
        else:
            self.thread_id = self._next_thread_id()
            print("New thread created with id #", self.thread_id)

    def _prune_old_threads(self, max_threads):
        """Delete checkpoints of all but the ``max_threads`` highest numeric thread ids.

        This is best-effort. An ``OperationalError`` (e.g. the ``checkpoints``
        table does not exist yet) is ignored.
        NOTE(review): only the ``checkpoints`` table is pruned; confirm
        whether the installed langgraph stores per-thread rows elsewhere.
        """
        conn = self.checkpoints.conn
        try:
            # ``with conn`` commits on success, or rolls back on error, so the
            # DELETE is not left inside an open implicit transaction.
            with conn:
                conn.execute(
                    "DELETE FROM checkpoints WHERE thread_id NOT IN ("
                    " SELECT thread_id FROM checkpoints"
                    # GROUP BY -> one row per thread, so LIMIT counts threads
                    # rather than checkpoint rows.
                    " GROUP BY thread_id"
                    # thread_id is stored as TEXT; cast so "10" sorts above "9".
                    " ORDER BY CAST(thread_id AS INTEGER) DESC"
                    " LIMIT ?)",
                    (max_threads,),
                )
        except sqlite3.OperationalError:
            pass

    def _next_thread_id(self):
        """Return one past the largest numeric thread id, or 1 if there is none."""
        try:
            row = self.checkpoints.conn.execute(
                "SELECT MAX(CAST(thread_id AS INTEGER)) FROM checkpoints"
            ).fetchone()
        except sqlite3.OperationalError:
            # No checkpoints table yet: this is the first thread.
            return 1
        # MAX() over an empty table yields NULL -> None.
        return (row[0] or 0) + 1
|
|
|
|