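# Conversation engine: stores user and bot messages as OpenAI embeddings in
# Pinecone and uses retrieved history plus a per-game prompt to generate
# chat-completion replies.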
from time import time, sleep
import datetime
# import dotenv
import os
import openai
import json
import pinecone
from uuid import uuid4
from helper import open_file, save_file
import re
from langchain.memory import VectorStoreRetrieverMemory

## Read the environment variables
# dotenv.load_dotenv('.env')
openai.api_key = os.getenv('OPENAI_API_KEY')
embedding_model = os.getenv('EMBEDDING_ENGINE')
convo_length = int(os.getenv('CONVO_LENGTH_TO_FETCH'))
llm_model = os.getenv('LLM_MODEL')

pinecone_api_key = os.getenv('PINECONE_API_KEY')
pinecone_env = os.getenv('PINECONE_REGION')
pinecone_index = os.getenv('PINECONE_INDEX')
pinecone.init(
    api_key=pinecone_api_key,  
    environment=pinecone_env 
)
vector_db = pinecone.Index(pinecone_index)

def timestamp_to_datetime(unix_time):
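    """Convert a Unix timestamp to a human-readable date/time string."""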
    return datetime.datetime.fromtimestamp(unix_time).strftime("%A, %B %d, %Y at %I:%M%p %Z")


def perform_embedding(content):
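    """Strip non-ASCII characters, embed the text with the configured OpenAI
    embedding model, and return the embedding vector."""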
    content = content.encode(encoding='ASCII',errors='ignore').decode()
    response = openai.Embedding.create(model=embedding_model, input=content)
    vector = response['data'][0]['embedding']
    return vector

def load_conversation(results):
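    """Turn Pinecone query matches into a single chronologically ordered block
    of message text."""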
    result = list()
    for m in results['matches']:
        result.append({'time1': m['metadata']['timestring'], 'text': m['metadata']['text']})
    ordered = sorted(result, key=lambda d: d['time1'], reverse=False)
    messages = [i['text'] for i in ordered]
    message_block = '\n'.join(messages).strip()
    return message_block
 

def call_gpt(prompt):
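    """Send the prompt as a single user message to the chat completion API,
    normalize whitespace in the reply, log the prompt/response pair to
    gpt3_logs/, and return the response object. Retries on transient errors."""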
    max_retry = 5
    retry = 0
    prompt = prompt.encode(encoding='ASCII',errors='ignore').decode()
    while True:
        try:
            response = openai.ChatCompletion.create(
                model=llm_model,
                temperature=0.9,
                messages=[
                    {"role": "user", "content": prompt}
                ]
            )
            
            text = response.choices[0].message.content
            text = re.sub('[\r\n]+', '\n', text)
            text = re.sub('[\t ]+', ' ', text)
            filename = '%s_gpt3.txt' % time()
            if not os.path.exists('gpt3_logs'):
                os.makedirs('gpt3_logs')
            save_file('gpt3_logs/%s' % filename, prompt + '\n\n==========\n\n' + text)
            response.choices[0].message.content = text
            return response
        except Exception as oops:
            retry += 1
            if retry >= max_retry:
                # Callers expect a ChatCompletion response object, so surface the
                # failure explicitly instead of returning an error string.
                raise RuntimeError('GPT error after %s retries: %s' % (max_retry, oops))
            print('Error communicating with OpenAI:', oops)
            sleep(1)


def start_game(game_id, user_id, user_input):
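    """Handle one conversation turn: embed the user input, retrieve related
    history for this user/game from Pinecone, build a prompt from the stored
    game prompt and that history, generate a reply, and upsert both the user
    message and the reply into the vector database."""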
    payload = list()

    # Get user input, save it, vectorize it and save to pinecone
    timestamp = time()
    timestring = timestamp_to_datetime(timestamp)   
    unique_id = str(uuid4())
    vector = perform_embedding(user_input)
    metadata = {'speaker': 'USER', 'user_id': user_id, 'game_id': game_id, 'timestring': timestring, 'text': user_input}
    payload.append((unique_id, vector, metadata))
       
    
    # Search for relevant messages and return a response
    results = vector_db.query(
        vector=vector,
        top_k=convo_length,
        include_metadata=True,
        filter={
            "$and": [{"user_id": {"$eq": user_id}}, {"game_id": {"$eq": game_id}}]
        }
    )
    conversation = load_conversation(results)
                                           

    # Populate prompt
    prompt_text = open_file(f"prompt_{game_id}_{user_id}.txt")
    prompt = (
        open_file('prompt_response.txt')
        .replace('<<PROMPT_VALUE>>', prompt_text)
        .replace('<<CONVERSATION>>', conversation)
        .replace('<<USER_MSG>>', user_input)
        .replace('<<USER_VAL>>', user_id)
    )

    # Generate response, vectorize
    llm_output_msg = call_gpt(prompt)
    llm_output = llm_output_msg.choices[0].message.content
    timestamp_op = time()
    timestring_op = timestamp_to_datetime(timestamp_op)
    vector_op = perform_embedding(llm_output)
    unique_id_op = str(uuid4())
    metadata_op = {'speaker': 'BOT', 'user_id': user_id, 'game_id': game_id, 'timestring': timestring_op, 'text': llm_output}
    payload.append((unique_id_op, vector_op, metadata_op))

    # Upsert into the vector database
    vector_db.upsert(payload)
       
    return llm_output
  

def populate_prompt(game_id, splits):
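    """Fetch the prompt fragments stored in Pinecone under the ids
    "<game_id>-0" .. "<game_id>-<splits - 1>" and join them into one prompt."""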
    prompt_text = list()
    idlist = []
    for j in range(int(splits)):
        idlist.append(game_id + "-" + str(j))

    results = vector_db.fetch(ids=idlist)
    for ids in idlist:
        prompt_text.append(results['vectors'][ids]["metadata"]["text"])

    whole_prompt = ' '.join(prompt_text).strip()
    return whole_prompt
    

def initialize_game(game_id, user_id, user_input):
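    """Build the initial prompt for a game, run it through the LLM, and save the
    result to prompt_<game_id>_<user_id>.txt for later turns.

    Note: get_game_details() is not defined or imported in this file; it is
    assumed to be provided elsewhere and to return a dict with a "splits" key.
    """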
    game_details = get_game_details(game_id)
    whole_prompt = populate_prompt(game_id, game_details["splits"])
    whole_prompt = whole_prompt.replace("<<USER_INPUT_MSG>>", user_input)
       
    llm_prompt_op = call_gpt(whole_prompt)
    #print(llm_prompt_op.choices[0]["message"]["content"])
    fname="prompt_" + game_id + "_" + user_id + ".txt"
    save_file(fname, llm_prompt_op.choices[0]["message"]["content"])
    return llm_prompt_op.choices[0]["message"]["content"]

if __name__ == '__main__':
    print("main")