vmoras commited on
Commit
5d345b1
1 Parent(s): c134e34

Initial commit

Browse files
Files changed (4) hide show
  1. .gitignore +5 -0
  2. main.py +35 -0
  3. requirements.txt +6 -0
  4. utils.py +238 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ .idea/
3
+ .env
4
+
5
+ prompts/
main.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Gradio app for blind-testing the chat models declared in utils.MODELS."""
from dotenv import load_dotenv

# Load .env BEFORE importing utils: utils reads environment variables and
# builds API clients at import time.
load_dotenv()

import os
import utils
import gradio as gr


with gr.Blocks() as app:
    # Model selection, shown first.
    with gr.Row() as selection:
        model = gr.Dropdown(choices=list(utils.MODELS), label='Select Model')
        start_button = gr.Button(value='Start Test')
    # NOTE(review): kept outside the `selection` row — start_chat hides the
    # row and shows this button, which only works if it is not a child of it.
    restart_button = gr.Button(value='Restart Test', visible=False)
    # Chat interface, hidden until a model is chosen.
    with gr.Column(visible=False) as testing:
        name_model = gr.Markdown()
        chatbot = gr.Chatbot(label='Chatbot')
        message = gr.Text(label='Enter your message')

    # Init the chatbot
    start_button.click(
        utils.start_chat, model, [selection, restart_button, testing, name_model]
    )

    # Select again the model
    restart_button.click(
        utils.restart_chat, None, [selection, restart_button, testing, chatbot, message]
    )

    # Send the messages and get an answer
    message.submit(
        utils.get_answer, [chatbot, message, model], [chatbot, message]
    )

app.queue()
# NOTE(review): if SPACE_USERNAME / SPACE_PASSWORD are unset, auth becomes
# (None, None) — confirm the intended behavior on a misconfigured deploy.
app.launch(debug=True, auth=(os.environ.get('SPACE_USERNAME'), os.environ.get('SPACE_PASSWORD')))
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.19.0
2
+ python-dotenv==1.0.1
3
+ pinecone-client==2.2.4
4
+ openai==1.6.1
5
+ google-generativeai==0.3.2
6
+ huggingface_hub==0.20.2
utils.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pinecone
3
+ import gradio as gr
4
+ from openai import OpenAI
5
+ from typing import Callable
6
+ import google.generativeai as genai
7
+ from huggingface_hub import hf_hub_download
8
+
9
+
10
def download_prompt(name_prompt: str) -> str:
    """
    Fetch a prompt file from the HuggingFace Hub dataset and return its text.
    :param name_prompt: name of the file (without the .txt extension)
    :return: contents of the downloaded file
    """
    filename = f"{name_prompt}.txt"
    # Download into the local `prompts/` directory, then read it back.
    hf_hub_download(
        repo_id=os.environ.get('DATA'),
        repo_type='dataset',
        filename=filename,
        token=os.environ.get('HUB_TOKEN'),
        local_dir="prompts",
    )
    with open(f'prompts/{name_prompt}.txt', mode='r', encoding='utf-8') as infile:
        return infile.read()
23
+
24
+
25
def start_chat(model: str) -> tuple:
    """
    Hide the model-selection row and reveal the chat interface.
    Returns gradio update helpers (gr.update()); the original annotation used
    ``gr.helpers``, which is a module, not a type — fixed to a plain tuple.
    :param model: name of the model to use
    :return: updates for (selection, restart_button, testing, name_model):
             visible=False, visible=True, visible=True, value='# <model>'
    """
    hide = gr.update(visible=False)
    show = gr.update(visible=True)
    # The chat header shows the selected model as a markdown H1.
    title = gr.update(value=f"# {model}")
    return hide, show, show, title
36
+
37
+
38
def restart_chat() -> tuple:
    """
    Show the model selection again, hide the chat interface and clear the chat.
    Returns gradio update helpers (gr.update()); the original annotation used
    ``gr.helpers``, which is a module, not a type — fixed to a plain tuple.
    :return: updates for (selection, restart_button, testing, chatbot, message):
             visible=True, visible=False, visible=False, empty history, empty text
    """
    hide = gr.update(visible=False)
    show = gr.update(visible=True)
    # Empty list resets the chatbot history; empty string clears the textbox.
    return show, hide, hide, [], ""
47
+
48
+
49
def get_answer(chatbot: list[tuple[str, str]], message: str, model: str) -> tuple[list[tuple[str, str]], str]:
    """
    Run the full RAG pipeline for one user message and append the answer.
    :param chatbot: message history as (user, bot) pairs
    :param message: user input
    :param model: display name of the model (a key of COMPANIES)
    :return: updated history and an empty string (clears the input box)
    """
    # Pick the backend based on the real model behind the display name.
    is_gemini = COMPANIES[model]['real name'] == 'Gemini'
    call_model = _call_google if is_gemini else _call_openai

    # Rephrase the follow-up into a standalone question for retrieval.
    standalone_question = _get_standalone_question(chatbot, message, call_model)

    # Retrieve context from the vector store and build the final prompt.
    context = _get_context(standalone_question)
    prompt = PROMPT_GENERAL.replace('CONTEXT', context)

    # Ask the chatbot and record the new exchange in the history.
    answer = call_model(prompt, chatbot, message)
    chatbot.append((message, answer))

    return chatbot, ""
77
+
78
+
79
def _get_standalone_question(
        chat_history: list[tuple[str, str]], message: str, call_model: Callable[[str, list, str], str]
) -> str:
    """
    Rewrite the latest user message as a standalone question using the history.
    :param chat_history: message history as (user, bot) pairs
    :param message: user input
    :param call_model: backend function used to produce the rewrite
    :return: standalone phrase
    """
    # Transcript like "Human: ...\nAssistant: ...". The first user turn is
    # dropped — NOTE(review): presumably the conversation opens with a bot
    # greeting; confirm, since the first real user message never reaches the
    # rewrite prompt otherwise.
    lines = []
    for turn, (user, bot) in enumerate(chat_history):
        if turn != 0:
            lines.append(f'Human: {user}\n')
        lines.append(f'Assistant: {bot}\n')
    history = ''.join(lines)

    # Inject the transcript into the template and append the follow-up.
    prompt = PROMPT_STANDALONE.replace('HISTORY', history)
    question = f'Follow-up message: {message}'

    return call_model(prompt, [], question)
103
+
104
+
105
def _get_embedding(text: str) -> list[float]:
    """
    Embed *text* with OpenAI's text-embedding-ada-002 model.
    :param text: input text
    :return: embedding vector
    """
    result = OPENAI_CLIENT.embeddings.create(
        model='text-embedding-ada-002',
        input=text,
    )
    return result.data[0].embedding
115
+
116
+
117
def _get_context(question: str) -> str:
    """
    Get the 10 nearest vectors to the given input from the Pinecone index.
    :param question: standalone question
    :return: the matched chunks' text, each followed by a blank line
    """
    matches = INDEX.query(
        vector=_get_embedding(question),
        top_k=10,
        include_metadata=True,
        namespace=f'{CLIENT}-context',
    )['matches']

    # Each match stores its chunk text under metadata['Text'].
    return ''.join(m['metadata']['Text'] + '\n\n' for m in matches)
134
+
135
+
136
def _call_openai(prompt: str, chat_history: list[tuple[str, str]], question: str) -> str:
    """
    Call ChatGPT 4 (gpt-4-turbo-preview).
    :param prompt: system prompt with the context or the standalone-question instructions
    :param chat_history: history of the conversation as (user, bot) pairs
    :param question: user input
    :return: chatbot answer
    """
    # Rebuild the history in OpenAI's message format. The first user turn is
    # skipped — NOTE(review): presumably the chat opens with a bot greeting;
    # this mirrors _get_standalone_question and _call_google.
    messages = [{'role': 'system', 'content': prompt}]
    for turn, (user, bot) in enumerate(chat_history):
        if turn:
            messages.append({'role': 'user', 'content': user})
        messages.append({'role': 'assistant', 'content': bot})
    messages.append({'role': 'user', 'content': question})

    # Call ChatGPT 4
    completion = OPENAI_CLIENT.chat.completions.create(
        model='gpt-4-turbo-preview',
        temperature=0.5,
        messages=messages,
    )
    return completion.choices[0].message.content
161
+
162
+
163
def _call_google(prompt: str, chat_history: list[tuple[str, str]], question: str) -> str:
    """
    Call Gemini.
    :param prompt: prompt with the context or the standalone-question instructions
    :param chat_history: history of the conversation as (user, bot) pairs
    :param question: user input
    :return: chatbot answer
    """
    # Gemini has no system role: the prompt goes in as the opening user turn,
    # acknowledged by a canned model reply (runtime string, kept verbatim).
    history = [
        {'role': 'user', 'parts': [prompt]},
        {'role': 'model', 'parts': 'Excelente! Estoy super lista para ayudarte en lo que necesites'}
    ]
    # First user turn skipped — same greeting-first convention as _call_openai.
    for turn, (user, bot) in enumerate(chat_history):
        if turn:
            history.append({'role': 'user', 'parts': user})
        history.append({'role': 'model', 'parts': bot})

    # Call Gemini
    conversation = GEMINI.start_chat(history=history)
    conversation.send_message(question)
    return conversation.last.text
187
+
188
+
189
# ----------------------------------------- Setup constants and models ------------------------------------------------
# API clients and vector store. NOTE(review): assumes all env vars are set;
# os.getenv silently returns None otherwise — confirm deployment config.
OPENAI_CLIENT = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
pinecone.init(api_key=os.getenv('PINECONE_API_KEY'), environment=os.getenv("PINECONE_ENVIRONMENT"))
INDEX = pinecone.Index(os.getenv('PINECONE_INDEX'))
CLIENT = os.getenv('CLIENT')


# Setup Gemini: sampling parameters and per-category safety thresholds.
generation_config = {
    "temperature": 0.9,
    "top_p": 1,
    "top_k": 1,
    "max_output_tokens": 2048,
}
safety_settings = [
    {"category": category, "threshold": threshold}
    for category, threshold in [
        ("HARM_CATEGORY_HARASSMENT", "BLOCK_MEDIUM_AND_ABOVE"),
        ("HARM_CATEGORY_HATE_SPEECH", "BLOCK_MEDIUM_AND_ABOVE"),
        ("HARM_CATEGORY_SEXUALLY_EXPLICIT", "BLOCK_ONLY_HIGH"),
        ("HARM_CATEGORY_DANGEROUS_CONTENT", "BLOCK_MEDIUM_AND_ABOVE"),
    ]
]
GEMINI = genai.GenerativeModel(
    model_name="gemini-1.0-pro", generation_config=generation_config, safety_settings=safety_settings
)


# Download and open prompts from HuggingFace Hub
os.makedirs('prompts', exist_ok=True)
PROMPT_STANDALONE = download_prompt('standalone')
PROMPT_GENERAL = download_prompt('general')


# Display names shown in the UI mapped to the real model behind each one.
COMPANIES = {
    'Model G': {'company': 'Google', 'real name': 'Gemini'},
    'Model C': {'company': 'OpenAI', 'real name': 'ChatGPT 4'},
}
MODELS = list(COMPANIES.keys())