"""Streamlit chat front-end for "Joy", an AI therapist backed by the OpenAI Completion API.

The conversation is kept in ``st.session_state`` and replayed to the model as one text
prompt on every turn.
"""
from typing import List, Optional, Tuple, Union

import openai
import streamlit as st
from streamlit_chat import message

openai.api_key_path = './openai_api_key.txt'

# openai.Completion / api_key_path are the pre-1.0 openai SDK interface.
completion = openai.Completion()

start_prompt = '[Instruction] Act as a friendly, compasionate, insightful, and empathetic AI therapist named Joy. Joy listens, asks for details and offers detailed advices once a while. End the conversation when you wishes to.'
start_message = 'I am Joy, your AI therapist. How are you feeling today?'
start_sequence = "\nJoy:"
restart_sequence = "\n\nYou:"

# TODO:
# add a button for starting a new conversation
# let the user choose between models (curie, davinci, curie-finetuned, davinci-finetuned)
# let the user choose between different temperatures, frequency_penalty, presence_penalty
# save the user's input and the model's output to the database
# analyze the user's input and the model's output
# sentiment/mood analysis / topic analysis of the user's input
# embed the user's input and look for therapy catalogue that is similar to the user's input
# push the therapy catalogue to the user


def ask(
    question: str,
    chat_log: Union[str, List[str]],
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
) -> Tuple[str, str]:
    """Send the next user turn to the model and return Joy's reply.

    Args:
        question: The user's new message.
        chat_log: The conversation so far, either as one string or as the list of
            fragments stored in ``st.session_state['chat_log']`` (joined in order).
        model_name: Completion model id; defaults to the module-level ``model``
            chosen in the "Model" selectbox.
        temperature: Sampling temperature; defaults to the module-level ``temp``
            from the "Creativity" slider.

    Returns:
        ``(answer, log)``: ``answer`` is the stripped model reply and ``log`` is this
        turn's text (user line + reply), to be appended to the chat log.
    """
    if not isinstance(chat_log, str):
        # Formatting the list directly would embed its repr (brackets, quotes and
        # escaped "\n") into the prompt instead of the actual conversation text.
        chat_log = ''.join(chat_log)
    if model_name is None:
        model_name = model
    if temperature is None:
        temperature = temp

    # Build the turn once so the history replayed later matches what the model saw.
    turn = f'{restart_sequence} {question}{start_sequence}'
    response = completion.create(
        prompt=chat_log + turn,
        model=model_name,
        stop=["You:", 'Joy:'],  # stop before the model writes another speaker's line
        temperature=temperature,  # the higher, the more creative
        frequency_penalty=0.3,  # discourages word repetition; larger -> higher penalty
        presence_penalty=0.6,  # discourages topic repetition; larger -> higher penalty
        top_p=1,
        best_of=1,
        max_tokens=170,
    )
    answer = response.choices[0].text.strip()
    return answer, turn + answer


st.title("Chat with Joy - the AI therapist!")
temp = st.slider("Creativity", 0.0, 1.0, 0.7, 0.1)
model = st.selectbox("Model", ["text-davinci-003", "text-curie-001", "curie:ft-personal-2023-02-03-17-06-53"])

# Streamlit re-executes this script on every interaction, so conversation state lives in
# session_state and is initialised only once per session.
if 'generated' not in st.session_state:
    st.session_state['generated'] = [start_message]  # Joy's messages; [0] is the greeting
if 'past' not in st.session_state:
    st.session_state['past'] = []  # the user's messages
if 'chat_log' not in st.session_state:
    st.session_state['chat_log'] = [start_prompt + start_sequence + start_message]

# A form reports a submission only on the rerun triggered by its submit button, so
# moving the slider or switching models no longer re-sends the previous message.
with st.form('chat_form', clear_on_submit=True):
    user_input = st.text_input("You:", key='input')
    submitted = st.form_submit_button('Send')

if submitted and user_input:
    output, turn_log = ask(user_input, st.session_state['chat_log'], model, temp)
    st.session_state['chat_log'].append(turn_log)
    st.session_state['past'].append(user_input)
    st.session_state['generated'].append(output)
    # NOTE(review): this prints the whole (sensitive) conversation to the server's stdout.
    print(st.session_state['chat_log'])

generated = st.session_state['generated']
past = st.session_state['past']
# generated has one extra leading entry (the greeting), so generated[i] is the reply to
# past[i - 1]. Walking indices newest-first and emitting past[i] before generated[i]
# renders the chat in reverse chronological order (newest message at the top).
for i in reversed(range(len(generated))):
    if i < len(past):
        message(past[i], is_user=True, key=str(i) + '_user')
    message(generated[i], key=str(i))

# TODO: save the user's input and the model's output to the database and analyze the user's input and the model's output