File size: 4,663 Bytes
121a1b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Module setup for "Joy", a Streamlit chat front-end over the OpenAI
# Completions API, plus Hugging Face pipelines for summarizing the chat and
# scoring the patient's sentiment.
import openai
# Read the OpenAI API key from a local file instead of hard-coding it.
openai.api_key_path = './openai_api_key.txt'
import streamlit as st
from streamlit_chat import message
from transformers import pipeline
# Both pipelines are constructed at import time, which loads the named models.
# NOTE(review): Streamlit re-executes the script on every interaction, so these
# may be rebuilt on each rerun — consider caching them (e.g. st.cache_resource).
# Summarizer used for the "Clear and summarize" action and summarize().
summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")
# Sentiment classifier applied to the patient's messages.
sentiment_task = pipeline("sentiment-analysis", model='cardiffnlp/twitter-roberta-base-sentiment-latest', tokenizer='cardiffnlp/twitter-roberta-base-sentiment-latest')
from math import log  # NOTE(review): math.log appears unused in this file.
# Completions client; ask() calls completion.create(...).
# NOTE(review): openai.Completion is the pre-1.0 SDK interface — confirm the pinned openai version.
completion = openai.Completion()
# Instruction text that starts every conversation transcript.
# NOTE(review): contains typos ("compasionate", "advices"); kept verbatim
# because it is sent to the model as-is.
start_prompt = '[Instruction] Act as a friendly, compasionate, insightful, and empathetic AI therapist named Joy. Joy listens and offers advices. End the conversation when the patient wishes to.'
# Joy's greeting, shown before the patient has said anything.
start_message = 'I am Joy, your AI therapist. How are you feeling today?'
# Speaker tags used to build the transcript: each patient turn begins with
# restart_sequence, and the prompt ends with start_sequence so the model
# continues speaking as Joy.
start_sequence = "\nJoy:"
restart_sequence = "\n\nPatient:"
def ask(question: str, chat_log: "str | list[str]", model='text-davinci-003', temp=0.9) -> "tuple[str, str]":
    """Send the patient's message to the Completions API and return Joy's reply.

    Args:
        question: The patient's latest message.
        chat_log: The conversation so far, either as a single string or as a
            sequence of transcript fragments (the form kept in
            ``st.session_state['chat_log']``). Fragments are concatenated in
            order before being sent.
        model: Name of the completion model to query.
        temp: Sampling temperature; higher values give more varied replies.

    Returns:
        ``(answer, log_entry)``. ``answer`` is Joy's reply with surrounding
        whitespace stripped. ``log_entry`` is the transcript fragment for this
        exchange (patient turn followed by Joy's reply), meant to be appended
        to the chat log.
    """
    # Formatting a list straight into the f-string would embed its Python repr
    # (brackets, quotes, escaped newlines) in the prompt; concatenate instead.
    if not isinstance(chat_log, str):
        chat_log = ''.join(chat_log)
    prompt = f'{chat_log}{restart_sequence} {question}{start_sequence}'
    response = completion.create(
        prompt=prompt,
        model=model,
        stop=["Patient:", 'Joy:'],  # stop before the model writes another turn
        temperature=temp,  # the higher, the more creative
        frequency_penalty=0.9,  # discourages word repetition; larger -> stronger
        presence_penalty=1,  # discourages topic repetition; larger -> stronger
        top_p=1,
        best_of=1,
        max_tokens=170,
    )
    answer = response.choices[0].text.strip()
    log_entry = f'{restart_sequence}{question}{start_sequence}{answer}'
    return answer, log_entry
def clean_chat_log(chat_log):
    """Flatten a chat transcript into one line of text for the HF pipelines.

    Args:
        chat_log: Either a string or an iterable of transcript fragments.
            Fragments are joined with single spaces; a string is used as-is.

    Returns:
        The text starting at its first newline, with every newline replaced
        by a space. For the transcript built in main(), this drops the leading
        ``start_prompt`` instruction, which contains no newline. If the text
        has no newline at all, nothing is dropped.
    """
    # ' '.join on a plain string would put a space between every character.
    text = chat_log if isinstance(chat_log, str) else ' '.join(chat_log)
    # Without this guard, find() == -1 would turn the slice into text[-1:]
    # and keep only the last character.
    first_newline = text.find('\n')
    if first_newline != -1:
        text = text[first_newline:]
    return text.replace('\n', ' ')
def summarize(chat_log):
    """Return an abstractive summary of the conversation.

    The transcript is flattened with ``clean_chat_log`` and passed to the
    ``summarizer`` pipeline with sampling disabled and at most 150 tokens of
    output; only the text of the first result is returned.
    """
    flattened = clean_chat_log(chat_log)
    results = summarizer(flattened, max_length=150, do_sample=False)
    best = results[0]
    return best['summary_text']
def analyze_sentiment(chat_log):
    """Classify the overall sentiment of a conversation.

    Args:
        chat_log: Transcript in any form accepted by ``clean_chat_log``.

    Returns:
        The raw output of the ``sentiment_task`` pipeline for the flattened
        transcript (presumably a one-element list of ``{'label', 'score'}``
        dicts — confirm against the pinned transformers version).
    """
    text = clean_chat_log(chat_log)
    # The RoBERTa-based classifier has a fixed maximum input length; truncate
    # long conversations instead of letting the pipeline raise on them.
    # TODO: split the text into chunks, classify each chunk and average the
    # scores, so the end of a long conversation is not ignored.
    return sentiment_task(text, truncation=True)
def _new_conversation_state():
    """Return the session-state entries for a brand-new conversation."""
    # New list objects on every call: these lists are appended to in place
    # during a chat, so they must never be shared between conversations.
    return {
        'generated': [start_message],
        'past': [],
        'summary': [],
        'chat_log': [start_prompt + start_sequence + start_message],
    }


def _render_messages():
    """Draw the stored conversation with the newest messages first."""
    past = st.session_state['past']
    generated = st.session_state['generated']
    # generated[0] is the greeting and past[i] is answered by generated[i + 1],
    # so walking backwards yields reply, question, reply, ..., greeting.
    for i in reversed(range(len(generated))):
        if i < len(past):
            message(past[i], is_user=True, key=str(i) + '_user')
        message(generated[i], key=str(i))


def main():
    """Render the Joy chat page and process at most one user message per run.

    The page:
    - shows the temperature and model controls;
    - creates the conversation state on the session's first run;
    - once Joy has more than two messages (greeting included), offers a
      "Clear and summarize" button that shows a summary of the transcript and
      the sentiment of the patient's messages, then resets the conversation;
    - sends a submitted message to ask();
    - renders the conversation newest-first.
    """
    st.title("Chat with Joy - the AI therapist!")
    col1, col2 = st.columns(2)
    temp = col1.slider("Bot-Creativeness", 0.0, 1.0, 0.9, 0.1)
    model = col2.selectbox("Model", ["text-davinci-003", "text-curie-001", "curie:ft-personal-2023-02-03-17-06-53"])

    # First run of this session: create whichever state entries are missing.
    for key, value in _new_conversation_state().items():
        if key not in st.session_state:
            st.session_state[key] = value

    if len(st.session_state['generated']) > 2:
        if st.button("Clear and summarize", key='clear'):
            chat_text = clean_chat_log(st.session_state['chat_log'])
            # truncation=True: a long conversation can exceed the model's
            # maximum input length, which would otherwise make the pipeline raise.
            summary = summarizer(chat_text, max_length=100, min_length=30,
                                 do_sample=False, truncation=True)
            st.write(summary[0]['summary_text'])
            # The patient's messages come from a single-line text input, so a
            # plain space-join already yields clean text for the classifier.
            patient_text = ' '.join(st.session_state['past'])
            st.write(sentiment_task(patient_text, truncation=True))
            for key, value in _new_conversation_state().items():
                st.session_state[key] = value

    # A form that clears on submit is needed here. A bare text_input keeps its
    # value across reruns, so moving the slider or clicking "Clear" would send
    # the previous message to the model again.
    with st.form(key='chat_form', clear_on_submit=True):
        user_input = st.text_input("You:", key='input')
        submitted = st.form_submit_button("Send")
    if submitted and user_input:
        # Pass the transcript as one string; its fragments already carry
        # their own speaker tags and newlines.
        transcript = ''.join(st.session_state['chat_log'])
        output, log_entry = ask(user_input, transcript, model=model, temp=temp)
        st.session_state['chat_log'].append(log_entry)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)

    _render_messages()
# Build the chat page when this file is executed as a script.
if __name__ == "__main__":
    main()
# TODO: save the user's input and the model's output to the database, then analyze both.
# if len(st.session_state['generated']):
#     save the user's input and the model's output to the database
#     analyze the user's input and the model's output
|