# File size: 4,663 bytes
# 121a1b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import openai
# Point the openai SDK at a local file that holds the API key.
# NOTE(review): `api_key_path` and `openai.Completion` are attributes of the
# pre-1.0 openai SDK — confirm the pinned openai version.
openai.api_key_path = './openai_api_key.txt'
import streamlit as st
from streamlit_chat import message
from transformers import pipeline
# Hugging Face pipelines, built when this script executes: a summarization
# pipeline and a sentiment-analysis pipeline.
# NOTE(review): Streamlit re-runs the whole script on every interaction, so
# these models are probably rebuilt each time; consider caching them
# (e.g. st.cache_resource) — confirm against the Streamlit version in use.
summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")
sentiment_task = pipeline("sentiment-analysis", model='cardiffnlp/twitter-roberta-base-sentiment-latest', tokenizer='cardiffnlp/twitter-roberta-base-sentiment-latest')

# NOTE(review): `log` appears unused in the visible code; ask() binds its own
# local variable named `log`, which shadows this import there.
from math import log

# Shared completion client; ask() calls completion.create(...) on it.
completion = openai.Completion()


# Instruction text that opens every transcript (main() seeds chat_log with it).
# NOTE(review): the literal contains typos ("compasionate", "advices");
# fixing them changes the prompt text sent to the model.
start_prompt = '[Instruction] Act as a friendly, compasionate, insightful, and empathetic AI therapist named Joy. Joy listens and offers advices. End the conversation when the patient wishes to.'
# Joy's greeting; main() uses it as the first bot message.
start_message = 'I am Joy, your AI therapist. How are you feeling today?'
# Speaker tags that delimit turns in the transcript sent to the model.
start_sequence = "\nJoy:"
restart_sequence = "\n\nPatient:"
  
def ask(question: str, chat_log: "str | list[str]", model='text-davinci-003', temp=0.9) -> "tuple[str, str]":
    """Send one patient turn to the completion endpoint and return Joy's reply.

    Args:
        question: The patient's latest message.
        chat_log: The conversation so far, either as a single string or as a
            list of transcript fragments (the form kept in
            ``st.session_state['chat_log']``). A list is concatenated in
            order before it is used as the prompt prefix.
        model: Name of the completion model to query.
        temp: Sampling temperature; higher values give more varied replies.

    Returns:
        ``(answer, log_entry)``: ``answer`` is the stripped completion text;
        ``log_entry`` is this turn's transcript fragment (patient message,
        Joy tag, answer), formatted exactly like the prompt so it can be
        appended to ``chat_log``.
    """
    # Interpolating a list directly would put its repr (brackets, quotes,
    # escaped newlines) into the prompt, so join fragments into plain text.
    history = chat_log if isinstance(chat_log, str) else ''.join(chat_log)
    turn = f'{restart_sequence} {question}{start_sequence}'

    response = completion.create(
        prompt=history + turn,
        model=model,
        stop=["Patient:", 'Joy:'],
        temperature=temp,  # the higher, the more creative
        frequency_penalty=0.9,  # discourages word repetition; larger -> higher penalty
        presence_penalty=1,  # discourages topic repetition; larger -> higher penalty
        top_p=1,
        best_of=1,
        max_tokens=170,
    )

    answer = response.choices[0].text.strip()
    # Record the turn in the same format the model was prompted with.
    log_entry = f'{turn}{answer}'
    return answer, log_entry

def clean_chat_log(chat_log):
    """Flatten a chat transcript into one line of plain text.

    Everything before the first newline (the instruction prompt at the start
    of the transcript) is dropped, and the remaining newlines become spaces.

    Args:
        chat_log: A single string, or an iterable of transcript fragments
            that are joined with a single space.

    Returns:
        The flattened text with surrounding whitespace removed. When the text
        contains no newline there is no instruction line to drop, so the
        whole text is kept.
    """
    # A plain string must not go through ' '.join, which would insert a
    # space between every character.
    text = chat_log if isinstance(chat_log, str) else ' '.join(chat_log)
    # str.find() would return -1 without a newline and slicing from -1 would
    # keep only the last character; partition() makes that case explicit.
    _, newline, rest = text.partition('\n')
    body = rest if newline else text
    return body.replace('\n', ' ').strip()

def summarize(chat_log, max_length=150, **generate_kwargs):
    """Summarize a chat transcript with the module-level ``summarizer``.

    Args:
        chat_log: Transcript as accepted by ``clean_chat_log`` (a string or a
            list of fragments); its instruction line is removed first.
        max_length: Maximum length of the generated summary.
        **generate_kwargs: Extra options forwarded to the pipeline call
            (e.g. ``min_length``); they override the defaults below.

    Returns:
        The summary text produced by the pipeline.
    """
    text = clean_chat_log(chat_log)
    options = {
        'do_sample': False,
        # Cut transcripts that exceed the model's maximum input length
        # instead of handing an over-long sequence to the model.
        'truncation': True,
        **generate_kwargs,
    }
    result = summarizer(text, max_length=max_length, **options)
    return result[0]['summary_text']

def analyze_sentiment(chat_log):
    """Classify the sentiment of a chat transcript.

    Args:
        chat_log: Transcript as accepted by ``clean_chat_log`` (a string or a
            list of fragments); its instruction line is removed first.

    Returns:
        The raw output of the module-level ``sentiment_task`` pipeline for
        the flattened transcript.
    """
    # TODO: split the transcript into chunks, classify each chunk and return
    # an aggregated (e.g. averaged) sentiment instead of a single pass.
    text = clean_chat_log(chat_log)
    # Truncate to the model's maximum input length so long conversations do
    # not make the pipeline fail; until the chunking TODO is done, only the
    # leading part of an over-long transcript is scored.
    return sentiment_task(text, truncation=True)





def _fresh_chat_state():
    """Return new initial values for every conversation key in session state."""
    # Built on every call so each session/reset gets its own mutable lists.
    return {
        'generated': [start_message],
        'past': [],
        'summary': [],
        'chat_log': [start_prompt + start_sequence + start_message],
    }


def _summarize_and_reset():
    """Show a summary and the patient's sentiment, then restart the chat."""
    chat_log = clean_chat_log(st.session_state['chat_log'])
    # truncation=True keeps long conversations within the model input limit.
    summary = summarizer(chat_log, max_length=100, min_length=30, do_sample=False, truncation=True)
    st.write(summary)

    # Patient messages carry no instruction line, so only flatten them here;
    # clean_chat_log() would try to strip a first line that does not exist.
    user_text = ' '.join(st.session_state['past']).replace('\n', ' ')
    st.write(sentiment_task(user_text, truncation=True))

    for key, value in _fresh_chat_state().items():
        st.session_state[key] = value


def main():
    """Render the chat page and process at most one new patient message.

    Streamlit executes this on every rerun. Conversation state lives in
    ``st.session_state``: ``generated`` (Joy's messages, starting with the
    greeting), ``past`` (patient messages), ``chat_log`` (prompt transcript
    fragments) and ``summary``.
    """
    st.title("Chat with Joy - the AI therapist!")
    col1, col2 = st.columns(2)
    temp = col1.slider("Bot-Creativeness", 0.0, 1.0, 0.9, 0.1)
    model = col2.selectbox("Model", ["text-davinci-003", "text-curie-001", "curie:ft-personal-2023-02-03-17-06-53"])

    # Only fill missing keys so an ongoing conversation survives reruns.
    for key, value in _fresh_chat_state().items():
        if key not in st.session_state:
            st.session_state[key] = value

    # Offer the summary once Joy has replied at least twice.
    if len(st.session_state['generated']) > 2:
        if st.button("Clear and summarize", key='clear'):
            _summarize_and_reset()

    # A bare text_input keeps its value across reruns, so any other widget
    # interaction would re-send the last message. A form that clears on
    # submit delivers each message exactly once.
    with st.form('chat_form', clear_on_submit=True):
        user_input = st.text_input("You:", key='input')
        submitted = st.form_submit_button("Send")

    if submitted and user_input:
        # Pass the transcript as text, not as a list (whose repr would leak
        # into the prompt).
        history = ''.join(st.session_state['chat_log'])
        output, log_entry = ask(user_input, history, model=model, temp=temp)
        st.session_state['chat_log'].append(log_entry)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)

    # Newest first: for each index, patient message i precedes Joy's reply
    # i + 1, so showing past[i] above generated[i] keeps reverse order.
    for i in reversed(range(len(st.session_state['generated']))):
        if i < len(st.session_state['past']):
            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
        message(st.session_state["generated"][i], key=str(i))



# Script entry point: build the page via main().
if __name__ == "__main__":
    main()
    # TODO: persist each user input and model reply (e.g. to a database) and
    # analyze both once the conversation has content. Sketch:
    # if len(st.session_state['generated'])  :
        # save the user's input and the model's output to the database
        # analyze the user's input and the model's output