hoshingakag committed on
Commit ecbd714
1 Parent(s): e97bd62
Files changed (2):
  1. app.py +69 -170
  2. src/llamaindex_palm.py +302 -8
app.py CHANGED
@@ -1,28 +1,18 @@
- import os
- import time
- import datetime

  import gradio as gr

- import google.generativeai as genai
- from src.llamaindex_palm import LlamaIndexPaLM
-
- import wandb
- from wandb.sdk.data_types.trace_tree import Trace
-
  import logging
- logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
- logger = logging.getLogger('llm')
-
- # Llama-Index LLM
- llm = LlamaIndexPaLM()
- llm.set_index_from_pinecone()

- # Credentials
- genai.configure(api_key=os.getenv('PALM_API_KEY'))

- # W&B
- wandb.init(project=os.getenv('WANDB_PROJECT'))

  # Gradio
  chat_history = []
@@ -32,7 +22,7 @@ def clear_chat() -> None:
      chat_history = []
      return None

- def get_chat_history(chat_history) -> str:
      ind = 0
      formatted_chat_history = ""
      for message in chat_history:
@@ -40,162 +30,71 @@ def get_chat_history(chat_history) -> str:
          ind += 1
      return formatted_chat_history

- def generate_chat(prompt: str, llamaindex_llm: LlamaIndexPaLM):
      global chat_history
-     # get chat history
-     context_chat_history = "\n".join(list(filter(None, chat_history)))

      logger.info("Generating Message...")
      logger.info(f"User Message:\n{prompt}\n")
-     chat_history.append(prompt)

-     # w&b trace start
-     start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-
-     root_span = Trace(
-         name="LLMChain",
-         kind="chain",
-         start_time_ms=start_time_ms,
-         metadata={"user": "Gradio"},
-     )
-
-     # get context
-     context_from_index = llamaindex_llm.generate_response(prompt)
-     logger.info(f"Context from Llama-Index:\n{context_from_index}\n")
-
-     # w&b trace agent
-     agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-     agent_span = Trace(
-         name="Agent",
-         kind="agent",
-         status_code="success",
-         metadata={
-             "framework": "Llama-Index",
-             "index_type": "VectorStoreIndex",
-             "vector_store": "Pinecone",
-             "vector_store_index": llamaindex_llm._index_name,
-             "vector_store_namespace": llamaindex_llm._index_namespace,
-             "model_name": llamaindex_llm.llm._model_name,
-             # "temperture": 0.7,
-             # "top_k": 40,
-             # "top_p": 0.95,
-             "custom_kwargs": llamaindex_llm.llm._model_kwargs,
-         },
-         start_time_ms=start_time_ms,
-         end_time_ms=agent_end_time_ms,
-         inputs={"query": prompt},
-         outputs={"response": context_from_index},
-     )
-     root_span.add_child(agent_span)
-
-     prompt_with_context = f"""
-     [System]
-     You are in a role play of Gerard Lee and you need to pretend to be him to answer questions from people who interested in Gerard's background.
-     Respond the User Query below in no more than 5 complete sentences, unless specifically asked by the user to elaborate on something. Use only the History and Context to inform your answers.
-
-     [History]
-     {context_chat_history}
-
-     [Context]
-     {context_from_index}

-     [User Query]
-     {prompt}
-     """

-     try:
-         response = genai.generate_text(
-             prompt=prompt_with_context,
-             safety_settings=[
-                 {
-                     'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
-                     'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
-                 },
-             ],
-             temperature=0.9,
-         )
-         result = response.result
-         success_flag = "success"
-         if result is None:
-             result = "Seems something went wrong. Please try again later."
-             logger.error(f"Result with 'None' received\n")
-             success_flag = "fail"
      except Exception as e:
-         result = "Seems something went wrong. Please try again later."
-         logger.error(f"Exception {e} occured\n")
-         success_flag = "fail"
-
-     chat_history.append(result)
-     logger.info(f"Bot Message:\n{result}\n")
-
-     # w&b trace llm
-     llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-     llm_span = Trace(
-         name="LLM",
-         kind="llm",
-         status_code=success_flag,
-         start_time_ms=agent_end_time_ms,
-         end_time_ms=llm_end_time_ms,
-         inputs={"input": prompt_with_context},
-         outputs={"result": result},
-     )
-     root_span.add_child(llm_span)
-
-     # w&b finalize trace
-     root_span.add_inputs_and_outputs(
-         inputs={"query": prompt}, outputs={"result": result}
-     )
-     root_span._span.end_time_ms = llm_end_time_ms
-     root_span.log(name="llm_app_trace")
-
-     return result
-
- with gr.Blocks() as app:
-     chatbot = gr.Chatbot(
-         bubble_full_width=False,
-         container=True,
-         show_share_button=False,
-         avatar_images=[None, './asset/akag-g-only.png']
-     )
-     msg = gr.Textbox(
-         show_label=False,
-         label="Type your message...",
-         placeholder="Hi Gerard, can you introduce yourself?",
-         container=False,
-     )
-     with gr.Row():
-         clear = gr.Button("Clear", scale=1)
-         send = gr.Button(
-             value="",
-             variant="primary",
-             icon="./asset/send-message.png",
-             scale=1
-         )
-
-     def user(user_message, history):
-         return "", history + [[user_message, None]]
-
-     def bot(history):
-         bot_message = generate_chat(history[-1][0], llm)
-         history[-1][1] = ""
-         for character in bot_message:
-             history[-1][1] += character
-             time.sleep(0.01)
-             yield history
-
-     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-         bot, chatbot, chatbot
-     )
-     send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-         bot, chatbot, chatbot
-     )
-     clear.click(clear_chat, None, chatbot, queue=False)
-
-     gr.HTML("""
-         <p><center><i>Disclaimer: This is a RAG app for demostration purpose. LLM hallucination might occur.</i></center></p>
-         <p><center>Hosted on 🤗 Spaces. Powered by Google PaLM 🌴</center></p>
-     """)
-
- app.queue()
- app.launch()
+ from src.llamaindex_palm import LlamaIndexPaLM, LlamaIndexPaLMText

  import gradio as gr

+ from typing import List
+ import time

  import logging

+ # import dotenv
+ # dotenv.load_dotenv(".env")

+ # Llama-Index LLM
+ llm_backend = LlamaIndexPaLMText(model_kwargs={'temperature': 0.8})
+ llm = LlamaIndexPaLM(model=llm_backend)
+ llm.get_index_from_pinecone()

  # Gradio
  chat_history = []

      chat_history = []
      return None

+ def get_chat_history(chat_history: List[str]) -> str:
      ind = 0
      formatted_chat_history = ""
      for message in chat_history:

          ind += 1
      return formatted_chat_history

+ def generate_text(prompt: str, llamaindex_llm: LlamaIndexPaLM):
      global chat_history

      logger.info("Generating Message...")
      logger.info(f"User Message:\n{prompt}\n")

+     result = llamaindex_llm.generate_text(prompt, chat_history)
+     chat_history.append(prompt)
+     chat_history.append(result)

+     logger.info(f"Replied Message:\n{result}\n")
+     return result

+ if __name__ == "__main__":
+     logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
+     logger = logging.getLogger('app')

+     try:
+         with gr.Blocks() as app:
+             chatbot = gr.Chatbot(
+                 bubble_full_width=False,
+                 container=True,
+                 show_share_button=False,
+                 avatar_images=[None, './asset/akag-g-only.png']
+             )
+             msg = gr.Textbox(
+                 show_label=False,
+                 label="Type your message...",
+                 placeholder="Hi Gerard, can you introduce yourself?",
+                 container=False,
+             )
+             with gr.Row():
+                 clear = gr.Button("Clear", scale=1)
+                 send = gr.Button(
+                     value="",
+                     variant="primary",
+                     icon="./asset/send-message.png",
+                     scale=1
+                 )
+
+             def user(user_message, history):
+                 return "", history + [[user_message, None]]
+
+             def bot(history):
+                 bot_message = generate_text(history[-1][0], llm)
+                 history[-1][1] = ""
+                 for character in bot_message:
+                     history[-1][1] += character
+                     time.sleep(0.01)
+                     yield history
+
+             msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+                 bot, chatbot, chatbot
+             )
+             send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+                 bot, chatbot, chatbot
+             )
+             clear.click(clear_chat, None, chatbot, queue=False)
+
+             gr.HTML("""
+                 <p><center><i>Disclaimer: This is a RAG app for demonstration purposes. LLM hallucination might occur.</i></center></p>
+                 <p><center>Hosted on 🤗 Spaces. Powered by Google PaLM 🌴</center></p>
+             """)
+
+         app.queue()
+         app.launch()
      except Exception as e:
+         logger.exception(e)
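
With tracing and prompt assembly moved into `LlamaIndexPaLM`, the Gradio layer above is now a thin wrapper: `generate_text()` logs the turn, delegates to `llamaindex_llm.generate_text(prompt, chat_history)`, and appends both sides of the exchange to `chat_history`. A minimal sketch of the same call path without the UI (hypothetical driver, not part of this commit; assumes the `PALM_API_KEY`, `PINECONE_*`, and `WANDB_PROJECT` environment variables the class reads are set):

```python
# Hypothetical console driver exercising the same entry point as the Gradio bot() handler.
from src.llamaindex_palm import LlamaIndexPaLM, LlamaIndexPaLMText

llm = LlamaIndexPaLM(model=LlamaIndexPaLMText(model_kwargs={'temperature': 0.8}))
llm.get_index_from_pinecone()

chat_history = []
while True:
    prompt = input("You: ").strip()
    if not prompt:
        break
    # retrieval, evaluation, PaLM call, and W&B trace all happen inside this call
    result = llm.generate_text(prompt, chat_history)
    chat_history.extend([prompt, result])
    print(f"Gerard: {result}")
```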
src/llamaindex_palm.py CHANGED
@@ -1,9 +1,14 @@
  import os
- import logging

- from typing import Any, List
  from pydantic import Extra

  import pinecone
  import google.generativeai as genai

@@ -25,6 +30,25 @@ from llama_index.llms import (
  )
  from llama_index.llms.base import llm_completion_callback

  class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
      def __init__(
          self,
@@ -114,11 +138,13 @@ class LlamaIndexPaLM():
      def __init__(
          self,
          emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
-         model: LlamaIndexPaLMText = LlamaIndexPaLMText()
      ) -> None:
          self.emb_model = emb_model
          self.llm = model
-
          # Google Generative AI
          genai.configure(api_key=os.environ['PALM_API_KEY'])

@@ -128,6 +154,9 @@ class LlamaIndexPaLM():
              environment=os.getenv('PINECONE_ENV')
          )

          # model metadata
          CONTEXT_WINDOW = os.getenv('CONTEXT_WINDOW', 8196)
          NUM_OUTPUT = os.getenv('NUM_OUTPUT', 1024)
@@ -156,7 +185,13 @@ class LlamaIndexPaLM():
              prompt_helper=self.prompt_helper,
          )

-     def set_index_from_pinecone(
          self,
          index_name: str = os.getenv('PINECONE_INDEX'),
          index_namespace: str = os.getenv('PINECONE_NAMESPACE')
@@ -168,10 +203,269 @@ class LlamaIndexPaLM():
          self._index_name = index_name
          self._index_namespace = index_namespace
          return None

-     def generate_response(
          self,
          query: str
      ) -> str:
-         response = self.pinecone_index.as_query_engine(similarity_top_k=3,).query(query)
-         return response.response
  import os
+ import datetime
+ import asyncio
+ from concurrent.futures import ThreadPoolExecutor

+ from typing import Any, List, Dict, Union
  from pydantic import Extra

+ import wandb
+ from wandb.sdk.data_types.trace_tree import Trace
+
  import pinecone
  import google.generativeai as genai

  )
  from llama_index.llms.base import llm_completion_callback

+ from llama_index.evaluation import SemanticSimilarityEvaluator
+ from llama_index.embeddings import SimilarityMode
+
+ import logging
+ logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
+ logger = logging.getLogger('llm')
+
+ prompt_template = """
+ [System]
+ You are in a role play of Gerard Lee.
+ Reply in no more than 7 complete sentences using content from [Context] only. Refer to [History] for seamless conversation.
+
+ [History]
+ {context_history}
+
+ [Context]
+ {context_from_index}
+
+ [User Query]
+ {user_query}
+ """
+
  class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
      def __init__(
          self,

      def __init__(
          self,
          emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
+         model: LlamaIndexPaLMText = LlamaIndexPaLMText(),
+         # prompt_template: str = prompt_template
      ) -> None:
          self.emb_model = emb_model
          self.llm = model
+         self.prompt_template = prompt_template
+
          # Google Generative AI
          genai.configure(api_key=os.environ['PALM_API_KEY'])

              environment=os.getenv('PINECONE_ENV')
          )

+         # W&B
+         wandb.init(project=os.getenv('WANDB_PROJECT'))
+
          # model metadata
          CONTEXT_WINDOW = os.getenv('CONTEXT_WINDOW', 8196)
          NUM_OUTPUT = os.getenv('NUM_OUTPUT', 1024)

              prompt_helper=self.prompt_helper,
          )

+         self.emd_evaluator = SemanticSimilarityEvaluator(
+             service_context=self.service_context,
+             similarity_mode=SimilarityMode.DEFAULT,
+             similarity_threshold=os.getenv('SIMILARITY_THRESHOLD', 0.7),
+         )
+
+     def get_index_from_pinecone(
          self,
          index_name: str = os.getenv('PINECONE_INDEX'),
          index_namespace: str = os.getenv('PINECONE_NAMESPACE')

          self._index_name = index_name
          self._index_namespace = index_namespace
          return None
+
+     def retrieve_context(
+         self,
+         query: str
+     ) -> Dict[str, Union[str, int]]:
+         start_time = round(datetime.datetime.now().timestamp() * 1000)
+         response = self.pinecone_index.as_query_engine(similarity_top_k=3).query(query)
+         end_time = round(datetime.datetime.now().timestamp() * 1000)
+         return {"result": response.response, "start": start_time, "end": end_time}

+     async def aretrieve_context(
          self,
          query: str
+     ) -> Dict[str, Union[str, int]]:
+         start_time = round(datetime.datetime.now().timestamp() * 1000)
+         response = await self.pinecone_index.as_query_engine(similarity_top_k=3, use_async=True).aquery(query)
+         end_time = round(datetime.datetime.now().timestamp() * 1000)
+         return {"result": response.response, "start": start_time, "end": end_time}
+
+     async def aretrieve_context_multi(
+         self,
+         query_list: List[str]
+     ) -> List[Dict]:
+         result = await asyncio.gather(*(self.aretrieve_context(query) for query in query_list))
+         return result
+
+     async def aevaluate_context(
+         self,
+         query: str,
+         returned_context: str
+     ) -> Dict[str, Any]:
+         result = await self.emd_evaluator.aevaluate(
+             response=returned_context,
+             reference=query,
+         )
+         return result
+
+     async def aevaluate_context_multi(
+         self,
+         query_list: List[str],
+         returned_context_list: List[str]
+     ) -> List[Dict]:
+         result = await asyncio.gather(*(self.aevaluate_context(query, returned_context) for query, returned_context in zip(query_list, returned_context_list)))
+         return result
+
+     def format_history_as_context(
+         self,
+         history: List[str],
+     ) -> str:
+         format_chat_history = "\n".join(list(filter(None, history)))
+         return format_chat_history
+
+     def generate_text(
+         self,
+         query: str,
+         history: List[str],
      ) -> str:
+         # get history
+         context_history = self.format_history_as_context(history=history)
+
+         # w&b trace start
+         start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
+         root_span = Trace(
+             name="MetaAgent",
+             kind="agent",
+             start_time_ms=start_time_ms,
+             metadata={"user": "🤗 Space"},
+         )
+
+         # get retrieval context(s) from llama-index vectorstore index
+         # w&b trace retrieval & select agent
+         agent_span = Trace(
+             name="LlamaIndexAgent",
+             kind="agent",
+             start_time_ms=start_time_ms,
+         )
+         try:
+             # No history: single context retrieval without evaluation
+             if not history:
+                 # w&b trace retrieval context
+                 result_query_only = self.retrieve_context(query)
+                 # async version
+                 # result_query_only = asyncio.run(self.retrieve_context(query))
+                 context_from_index_selected = result_query_only["result"]
+                 agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
+                 retrieval_span = Trace(
+                     name="QueryRetrieval",
+                     kind="chain",
+                     status_code="success",
+                     metadata={
+                         "framework": "Llama-Index",
+                         "index_type": "VectorStoreIndex",
+                         "vector_store": "Pinecone",
+                         "vector_store_index": self._index_name,
+                         "vector_store_namespace": self._index_namespace,
+                         "model_name": self.llm._model_name,
+                         "custom_kwargs": self.llm._model_kwargs,
+                     },
+                     start_time_ms=start_time_ms,
+                     end_time_ms=agent_end_time_ms,
+                     inputs={"query": query},
+                     outputs={"response": context_from_index_selected},
+                 )
+                 agent_span.add_child(retrieval_span)
+             # Has history: retrieve context for both queries concurrently, then evaluate to determine which context to use
+             else:
+                 extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"
+
+                 # thread version
+                 with ThreadPoolExecutor(2) as executor:
+                     results = executor.map(self.retrieve_context, [query, extended_query])
+                     result_query_only, result_extended_query = [rec for rec in results]
+
+                 # async version - not working
+                 # result_query_only, result_extended_query = asyncio.run(
+                 #     self.aretrieve_context_multi([query, extended_query])
+                 # )
+
+                 # w&b trace retrieval context query only
+                 retrieval_query_span = Trace(
+                     name="QueryRetrieval",
+                     kind="chain",
+                     status_code="success",
+                     metadata={
+                         "framework": "Llama-Index",
+                         "index_type": "VectorStoreIndex",
+                         "vector_store": "Pinecone",
+                         "vector_store_index": self._index_name,
+                         "vector_store_namespace": self._index_namespace,
+                         "model_name": self.llm._model_name,
+                         "custom_kwargs": self.llm._model_kwargs,
+                         "start_time": result_query_only["start"],
+                         "end_time": result_query_only["end"],
+                     },
+                     start_time_ms=result_query_only["start"],
+                     end_time_ms=result_query_only["end"],
+                     inputs={"query": query},
+                     outputs={"response": result_query_only["result"]},
+                 )
+                 agent_span.add_child(retrieval_query_span)
+
+                 # w&b trace retrieval context extended query
+                 retrieval_extended_query_span = Trace(
+                     name="ExtendedQueryRetrieval",
+                     kind="chain",
+                     status_code="success",
+                     metadata={
+                         "framework": "Llama-Index",
+                         "index_type": "VectorStoreIndex",
+                         "vector_store": "Pinecone",
+                         "vector_store_index": self._index_name,
+                         "vector_store_namespace": self._index_namespace,
+                         "model_name": self.llm._model_name,
+                         "custom_kwargs": self.llm._model_kwargs,
+                         "start_time": result_extended_query["start"],
+                         "end_time": result_extended_query["end"],
+                     },
+                     start_time_ms=result_extended_query["start"],
+                     end_time_ms=result_extended_query["end"],
+                     inputs={"query": extended_query},
+                     outputs={"response": result_extended_query["result"]},
+                 )
+                 agent_span.add_child(retrieval_extended_query_span)
+
+                 # w&b trace select context
+                 eval_start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
+                 eval_context_query_only, eval_context_extended_query = asyncio.run(
+                     self.aevaluate_context_multi([query, extended_query], [result_query_only["result"], result_extended_query["result"]])
+                 )
+
+                 if eval_context_query_only.score > eval_context_extended_query.score:
+                     query_selected, context_from_index_selected = query, result_query_only["result"]
+                 else:
+                     query_selected, context_from_index_selected = extended_query, result_extended_query["result"]
+
+                 agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
+                 eval_span = Trace(
+                     name="EmbeddingsEvaluator",
+                     kind="tool",
+                     status_code="success",
+                     metadata={
+                         "framework": "Llama-Index",
+                         "evaluator": "SemanticSimilarityEvaluator",
+                         "similarity_mode": "DEFAULT",
+                         "similarity_threshold": 0.7,
+                         "similarity_results": {
+                             "eval_context_query_only": eval_context_query_only,
+                             "eval_context_extended_query": eval_context_extended_query,
+                         },
+                         "model_name": self.emb_model._model_name,
+                     },
+                     start_time_ms=eval_start_time_ms,
+                     end_time_ms=agent_end_time_ms,
+                     inputs={"query": query_selected},
+                     outputs={"response": context_from_index_selected},
+                 )
+                 agent_span.add_child(eval_span)
+
+         except Exception as e:
+             logger.error(f"Exception {e} occurred when retrieving context\n")
+
+             llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
+             result = "Something went wrong. Please try again later."
+             root_span.add_inputs_and_outputs(
+                 inputs={"query": query}, outputs={"result": result, "exception": e}
+             )
+             root_span._span.status_code = "fail"
+             root_span._span.end_time_ms = llm_end_time_ms
+             root_span.log(name="llm_app_trace")
+             return result
+
+         logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")
+
+         agent_span.add_inputs_and_outputs(
+             inputs={"query": query}, outputs={"result": context_from_index_selected}
+         )
+         agent_span._span.status_code = "success"
+         agent_span._span.end_time_ms = agent_end_time_ms
+         root_span.add_child(agent_span)
+
+         # generate text with prompt template to roleplay myself
+         prompt_with_context = self.prompt_template.format(context_history=context_history, context_from_index=context_from_index_selected, user_query=query)
+         try:
+             response = genai.generate_text(
+                 prompt=prompt_with_context,
+                 safety_settings=[
+                     {
+                         'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+                         'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
+                     },
+                 ],
+                 temperature=0.9,
+             )
+             result = response.result
+             success_flag = "success"
+             if result is None:
+                 result = "Seems something went wrong. Please try again later."
+                 logger.error(f"Result with 'None' received\n")
+                 success_flag = "fail"
+
+         except Exception as e:
+             result = "Seems something went wrong. Please try again later."
+             logger.error(f"Exception {e} occurred\n")
+             success_flag = "fail"
+
+         # w&b trace llm
+         llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
+         llm_span = Trace(
+             name="LLM",
+             kind="llm",
+             status_code=success_flag,
+             start_time_ms=agent_end_time_ms,
+             end_time_ms=llm_end_time_ms,
+             inputs={"input": prompt_with_context},
+             outputs={"result": result},
+         )
+         root_span.add_child(llm_span)
+
+         # w&b finalize trace
+         root_span.add_inputs_and_outputs(
+             inputs={"query": query}, outputs={"result": result}
+         )
+         root_span._span.end_time_ms = llm_end_time_ms
+         root_span.log(name="llm_app_trace")
+
+         return result
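
The heart of the new `generate_text()` is context selection: with no history it retrieves once, and with history it retrieves for both the raw query and a history-extended query on a two-worker thread pool (the asyncio variant is left commented out as not working), then scores each candidate context against its originating query with `SemanticSimilarityEvaluator` and keeps the higher-scoring one. A distilled sketch of that decision, stripped of the W&B tracing (hypothetical standalone helper named `select_context`, not a method in this commit; it calls only `retrieve_context()` and `aevaluate_context_multi()` as defined above):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List

def select_context(llm, query: str, history: List[str]) -> str:
    # no history: single retrieval, no evaluation step
    if not history:
        return llm.retrieve_context(query)["result"]

    # two candidate queries: the raw query, and one extended with the last turn
    extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"

    # fan out both retrievals on threads, mirroring the commit
    with ThreadPoolExecutor(2) as executor:
        result_plain, result_extended = list(
            executor.map(llm.retrieve_context, [query, extended_query])
        )

    # embedding similarity of each retrieved context to its query
    eval_plain, eval_extended = asyncio.run(
        llm.aevaluate_context_multi(
            [query, extended_query],
            [result_plain["result"], result_extended["result"]],
        )
    )

    # keep whichever context is semantically closer to its query
    if eval_plain.score > eval_extended.score:
        return result_plain["result"]
    return result_extended["result"]
```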