|
from operator import itemgetter |
|
from langchain_openai import ChatOpenAI |
|
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
from langchain.schema import StrOutputParser |
|
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableLambda |
|
from langchain.schema.runnable.config import RunnableConfig |
|
from langchain.memory import ConversationBufferMemory |
|
from resolution_logic import ResolutionLogic |
|
from literal_thread_manager import LiteralThreadManager |
|
from prompt_engineering.prompt_desing import system_prompt, system_prompt_b, system_prompt_questioning |
|
import chainlit as cl |
|
from chainlit.types import ThreadDict |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv()

# Chainlit's password authentication signs sessions with this secret; refuse
# to start without it.
jwt_secret_key = os.getenv('CHAINLIT_AUTH_SECRET')
if not jwt_secret_key:
    raise ValueError(
        "You must provide a JWT secret in the environment to use authentication.")

# load_dotenv() has already copied any .env values into os.environ, so the key
# only needs validating here. Fail fast with a clear message instead of the
# opaque TypeError that writing None into os.environ would raise.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError(
        "You must provide an OpenAI API key (OPENAI_API_KEY) in the environment.")

# Shared LiteralAI thread/user manager used by the handlers below.
manager = LiteralThreadManager(api_key=os.getenv("LITERAL_API_KEY"))
|
|
|
def setup_runnable():
    """
    Build the chat pipeline for the current session and store it in the user
    session under the "runnable" key.

    The pipeline injects the conversation history held in the session's
    "memory" into the prompt, sends the prompt to a streaming OpenAI chat
    model, and parses the model output into a plain string.
    """
    conversation_memory = cl.user_session.get("memory")
    llm = ChatOpenAI(streaming=True, model="gpt-3.5-turbo")
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt_questioning),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    # load_memory_variables returns a dict; only its "history" entry is fed
    # into the MessagesPlaceholder of the prompt.
    load_history = RunnableLambda(conversation_memory.load_memory_variables) | itemgetter("history")
    with_history = RunnablePassthrough.assign(history=load_history)

    pipeline = with_history | chat_prompt | llm | StrOutputParser()
    cl.user_session.set("runnable", pipeline)
|
|
|
@cl.password_auth_callback
def auth_callback(username: str, password: str):
    """
    Authenticate a user from the Chainlit login form.

    The user is fetched from the LiteralAI database, or created there if it
    does not exist yet. The username "admin" gets the "admin" role; every other
    username gets the "user" role.

    Args:
        username (str): The username provided by the user.
        password (str): The password provided by the user. Currently NOT
            verified -- see the security note in the body.

    Returns:
        cl.User | None: A User object when the LiteralAI lookup/creation
        succeeds; None (a failed login) when the username is blank or
        LiteralAI returns no user.
    """
    # SECURITY(review): `password` is never checked, so anyone can log in as
    # any user -- including "admin" -- with an arbitrary password. Verify it
    # against a credential store before relying on this in production.
    if not username or not username.strip():
        # Do not create a LiteralAI user with an empty identifier.
        return None

    auth_user = manager.literal_client.api.get_or_create_user(identifier=username)
    if not auth_user:
        return None

    role = "admin" if username == "admin" else "user"
    return cl.User(
        identifier=username,
        metadata={"role": role, "provider": "credentials"},
    )
|
|
|
def create_and_update_threads(first_res, current_user, partner_user):
    """
    Link the current user's thread with a newly created thread for the partner.

    A thread named after the user's conflict summary is created for the
    partner and seeded with an assistant message produced by
    ResolutionLogic.summarize_conflict_topic. The current user's latest thread
    is then updated to reference the partner thread, and its id is stored in
    the user session under "thread_id".

    Args:
        first_res (dict): The AskUserMessage response; the conflict summary is
            read from its "output" key.
        current_user (cl.User): The current user initiating the conversation
            (its ``id`` and ``identifier`` are used).
        partner_user: The partner's LiteralAI user record (its ``id`` and
            ``identifier`` are used).

    Raises:
        RuntimeError: If LiteralAI returns no thread to link with.
    """
    conflict_summary = first_res['output']
    api = manager.literal_client.api

    # NOTE(review): get_threads(first=1) is not filtered by user, so with
    # concurrent users this may pick up someone else's most recent thread.
    # Consider using the current session's thread id instead -- confirm.
    latest_threads = api.get_threads(first=1)
    if not latest_threads.data:
        raise RuntimeError("No thread found to link with the partner's thread.")
    current_thread_id = latest_threads.data[0].id

    # Produce the partner-facing message before creating anything, so that a
    # failing summary call does not leave an empty, orphaned partner thread.
    resolver = ResolutionLogic()
    message_to_other_partner = resolver.summarize_conflict_topic(
        partner_user.identifier, current_user.identifier, conflict_summary)

    partner_thread = api.create_thread(
        name=conflict_summary,
        participant_id=partner_user.id,
        metadata={
            "partner_id": current_user.id,
            "partner_thread_id": current_thread_id,
            "user_id": partner_user.id,
        },
    )
    api.create_step(thread_id=partner_thread.id, type="assistant_message",
                    output={'content': message_to_other_partner})

    current_thread = api.upsert_thread(
        id=current_thread_id,
        participant_id=current_user.id,
        metadata={"partner_id": partner_user.id, "partner_thread_id": partner_thread.id},
    )
    cl.user_session.set("thread_id", current_thread.id)

    # NOTE(review): the return value is discarded; the call is kept in case
    # LiteralThreadManager relies on its side effects -- confirm.
    manager.get_other_partner_thread_id(current_thread.id)
|
|
|
@cl.action_callback("2-1 Chat")
async def on_action(action):
    """
    Handle a click on the "2-1 Chat" action button.

    Asks the user for the partner's email and chat id, then removes the action
    button from the UI.

    Args:
        action (cl.Action): The action that was clicked.
    """
    await cl.Message(content="Write the email and the chat id:").send()
    # cl.Action exposes its data as attributes (e.g. `action.value`), not via
    # dict-style `.get()`; a `.get("value")` call here would raise
    # AttributeError and the button would never be removed.
    await action.remove()
|
|
|
async def _ask_for_existing_partner():
    """
    Ask for a partner username until one that exists in LiteralAI is given.

    Returns:
        tuple | None: ``(partner_username, partner_user)`` once an existing
        user is named, or None as soon as a prompt times out.
    """
    prompt = "Please write the username of the person to connect with."
    res = await cl.AskUserMessage(content=prompt).send()
    while res:
        partner_username = res['output']
        # Look the user up once and reuse the result instead of querying twice.
        partner_user = manager.literal_client.api.get_user(identifier=partner_username)
        if partner_user is not None:
            return partner_username, partner_user
        await cl.Message(content=f"Partner {partner_username} does not exist in db.").send()
        res = await cl.AskUserMessage(content=prompt).send()
    return None


@cl.on_chat_start
async def on_chat_start():
    """
    Handle the start of a chat session.

    The setup runs in this order:

    1. Initialize the conversation memory and the runnable pipeline.
    2. Ask the user to summarize their conflict.
    3. Ask for the conversation type.

    For a "2-1 Chat", the handler then:

    - asks for an existing partner username;
    - records the relationship on the current user in LiteralAI;
    - answers the conflict summary through on_message;
    - links the two users' threads.

    Any prompt that times out (returns None) is reported with an
    "Action timed out!" message and ends the setup.
    """
    cl.user_session.set("memory", ConversationBufferMemory(return_messages=True))
    setup_runnable()

    first_res = await cl.AskUserMessage(content="Welcome to the Relationship Coach chatbot. I can help you with your relationship questions. Please first summarize the type of conflict.").send()
    if not first_res:
        await cl.Message(content="Action timed out!").send()
        return

    add_person = await cl.AskActionMessage(
        content="Select the conversation type.",
        actions=[
            cl.Action(name="1-1 Chat", value="1-1 Chat", label="π€ 1-1"),
            cl.Action(name="2-1 Chat", value="2-1 Chat", label="π₯ 2-1"),
        ],
    ).send()
    if not add_person:
        await cl.Message(content="Action timed out!").send()
        return
    if add_person.get("value") != "2-1 Chat":
        # 1-1 chat: nothing further to set up.
        # NOTE(review): the conflict summary in first_res is not answered on
        # this path, unlike the 2-1 path -- confirm this is intended.
        return

    partner = await _ask_for_existing_partner()
    if partner is None:
        await cl.Message(content="Action timed out!").send()
        return
    partner_username, partner_user = partner

    current_user = cl.user_session.get("user")
    # Keep the user's existing metadata (notably an "admin" role) rather than
    # overwriting it with hard-coded values; only the relationship is added.
    metadata = {"role": "user", "provider": "credentials", **(current_user.metadata or {})}
    metadata["relationships"] = {"partner_username": partner_username}
    manager.literal_client.api.update_user(
        id=current_user.id, identifier=current_user.identifier, metadata=metadata)

    await cl.Message(content=f"Connected with {partner_username}!").send()
    await on_message(cl.Message(content=first_res['output']))
    create_and_update_threads(first_res, current_user, partner_user)
|
|
|
@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    """
    Restore a previously persisted chat session.

    Rebuilds the conversation memory from the thread's top-level steps, stores
    the memory and the thread id in the user session, and rebuilds the
    runnable pipeline. If ResolutionLogic returns an intervention for the
    thread, that intervention is posted as a message.

    Args:
        thread (ThreadDict): The persisted thread, including its "steps".
    """
    restored = ConversationBufferMemory(return_messages=True)
    history = restored.chat_memory
    for step in thread["steps"]:
        # Only top-level steps (no parent) are replayed into memory.
        if step["parentId"] is not None:
            continue
        # Every top-level step that is not a user message is replayed as an
        # AI turn.
        if step["type"] == "user_message":
            history.add_user_message(step["output"])
        else:
            history.add_ai_message(step["output"])

    cl.user_session.set("memory", restored)
    cl.user_session.set("thread_id", thread["id"])

    setup_runnable()

    intervention = ResolutionLogic().intervention(thread["id"])
    if intervention:
        await cl.Message(content=intervention).send()
|
|
|
@cl.on_message
async def on_message(message: cl.Message):
    """
    Answer an incoming user message and record the exchange in memory.

    If the session is bound to a thread and ResolutionLogic returns an
    intervention for it, that intervention is sent as the reply. Otherwise the
    reply is streamed chunk by chunk from the session's runnable pipeline.
    Afterwards the user message and the reply are appended to the memory.

    Args:
        message (cl.Message): The incoming message from the user.
    """
    memory = cl.user_session.get("memory")
    runnable = cl.user_session.get("runnable")

    reply = cl.Message(content="")

    resolver = ResolutionLogic()
    thread_id = cl.user_session.get("thread_id")
    intervention = resolver.intervention(thread_id) if thread_id else None

    if intervention:
        reply = cl.Message(content=intervention)
    else:
        run_config = RunnableConfig(callbacks=[cl.LangchainCallbackHandler()])
        async for token in runnable.astream({"question": message.content}, config=run_config):
            await reply.stream_token(token)

    await reply.send()

    history = memory.chat_memory
    history.add_user_message(message.content)
    history.add_ai_message(reply.content)
|
|
|
def main():
    """
    Start the Chainlit server for this file, equivalent to running
    ``chainlit run <this file>``.

    Chainlit itself invokes the ``@cl.on_chat_start`` handler (an async
    function that needs a live session context) for every new session, so it
    must not be called directly here.
    """
    # Local import: only needed when this file is executed as a script.
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)


if __name__ == "__main__":
    main()
|
|