# g-palm-chat — src/llamaindex_palm.py
import os
import datetime
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Dict, Union
from pydantic import Extra
import wandb
from wandb.sdk.data_types.trace_tree import Trace
import pinecone
import google.generativeai as genai
from llama_index import (
ServiceContext,
PromptHelper,
VectorStoreIndex
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
CustomLLM,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.evaluation import SemanticSimilarityEvaluator
from llama_index.embeddings import SimilarityMode
import logging
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger('llm')
prompt_template = """
[System]
You are role-playing Gerard Lee. Gerard is a data enthusiast who is humble about his success.
Reply in no more than 5 complete sentences unless [User Query] asks you to elaborate. Use content from [Context] only, without prior knowledge, referring to [History] only to keep the conversation seamless.
[History]
{context_history}
[Context]
{context_from_index}
[User Query]
{user_query}
"""
class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
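    """llama-index embedding model backed by the PaLM Embeddings API (genai.generate_embeddings)."""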
def __init__(
self,
model_name: str = 'models/embedding-gecko-001',
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._model_name = model_name
@classmethod
def class_name(cls) -> str:
return 'PaLMEmbeddings'
def gen_embeddings(self, text: str) -> List[float]:
return genai.generate_embeddings(self._model_name, text)
def _get_query_embedding(self, query: str) -> List[float]:
embeddings = self.gen_embeddings(query)
return embeddings['embedding']
def _get_text_embedding(self, text: str) -> List[float]:
embeddings = self.gen_embeddings(text)
return embeddings['embedding']
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
embeddings = [
self.gen_embeddings(text)['embedding'] for text in texts
]
return embeddings
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
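    """llama-index CustomLLM backed by the PaLM text completion API, with safety filtering disabled."""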
def __init__(
self,
model_name: str = 'models/text-bison-001',
model_kwargs: dict = {},
context_window: int = 8196,
num_output: int = 1024,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._model_name = model_name
self._model_kwargs = model_kwargs
self._context_window = context_window
self._num_output = num_output
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=self._context_window,
num_output=self._num_output,
model_name=self._model_name
)
    def gen_texts(self, prompt: str) -> str:
        logger.debug(f"prompt: {prompt}")
response = genai.generate_text(
model=self._model_name,
prompt=prompt,
safety_settings=[
{
'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
},
],
**self._model_kwargs
)
logging.debug(f"response:\n{response}")
return response.candidates[0]['output']
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
text = self.gen_texts(prompt)
return CompletionResponse(text=text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
raise NotImplementedError()
class LlamaIndexPaLM:
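    """End-to-end RAG pipeline: PaLM embeddings + LLM, a Pinecone-backed
    VectorStoreIndex, similarity-based context selection, and W&B tracing.
    """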
def __init__(
self,
emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
model: LlamaIndexPaLMText = LlamaIndexPaLMText(),
# prompt_template: str = prompt_template
) -> None:
self.emb_model = emb_model
self.llm = model
self.prompt_template = prompt_template
# Google Generative AI
genai.configure(api_key=os.environ['PALM_API_KEY'])
# Pinecone
pinecone.init(
api_key=os.environ['PINECONE_API_KEY'],
environment=os.getenv('PINECONE_ENV')
)
# W&B
wandb.init(project=os.getenv('WANDB_PROJECT'))
# model metadata
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        # env vars arrive as strings, so cast; chunk size limit stays None when unset
        _chunk_size_limit = os.getenv('TEXT_CHUNK_SIZE_LIMIT')
        TEXT_CHUNK_SIZE_LIMIT = int(_chunk_size_limit) if _chunk_size_limit else None
self.node_parser = SimpleNodeParser.from_defaults(
text_splitter=TokenTextSplitter(
chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
)
)
self.prompt_helper = PromptHelper(
context_window=CONTEXT_WINDOW,
num_output=NUM_OUTPUT,
chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
)
self.service_context = ServiceContext.from_defaults(
llm=self.llm,
embed_model=self.emb_model,
node_parser=self.node_parser,
prompt_helper=self.prompt_helper,
)
self.emd_evaluator = SemanticSimilarityEvaluator(
service_context=self.service_context,
similarity_mode=SimilarityMode.DEFAULT,
            similarity_threshold=float(os.getenv('SIMILARITY_THRESHOLD', 0.7)),
)
def get_index_from_pinecone(
self,
index_name: str = os.getenv('PINECONE_INDEX'),
index_namespace: str = os.getenv('PINECONE_NAMESPACE')
) -> None:
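        """Connect to an existing Pinecone index and wrap it in a llama-index VectorStoreIndex."""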
# Pinecone VectorStore
pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(
            pinecone_index=pinecone_index,
            add_sparse_vector=True,
            namespace=index_namespace
        )
        self.pinecone_index = VectorStoreIndex.from_vector_store(
            self.vector_store, service_context=self.service_context
        )
self._index_name = index_name
self._index_namespace = index_namespace
return None
def retrieve_context(
self,
query: str
) -> Dict[str, Union[str, int]]:
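        """Synchronously query the index (top-3 similarity); returns the response text and start/end timestamps in epoch ms."""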
start_time = round(datetime.datetime.now().timestamp() * 1000)
response = self.pinecone_index.as_query_engine(similarity_top_k=3).query(query)
end_time = round(datetime.datetime.now().timestamp() * 1000)
return {"result": response.response, "start": start_time, "end": end_time}
async def aretrieve_context(
self,
query: str
) -> Dict[str, Union[str, int]]:
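        """Async variant of retrieve_context(), using the query engine's aquery()."""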
start_time = round(datetime.datetime.now().timestamp() * 1000)
response = await self.pinecone_index.as_query_engine(similarity_top_k=3, use_async=True).aquery(query)
end_time = round(datetime.datetime.now().timestamp() * 1000)
return {"result": response.response, "start": start_time, "end": end_time}
async def aretrieve_context_multi(
self,
query_list: List[str]
) -> List[Dict]:
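        """Run aretrieve_context() concurrently over a batch of queries."""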
result = await asyncio.gather(*(self.aretrieve_context(query) for query in query_list))
return result
async def aevaluate_context(
self,
query: str,
returned_context: str
) -> Dict[str, Any]:
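        """Score the semantic similarity between the query and the retrieved context."""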
result = await self.emd_evaluator.aevaluate(
response=returned_context,
reference=query,
)
return result
async def aevaluate_context_multi(
self,
query_list: List[str],
returned_context_list: List[str]
) -> List[Dict]:
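        """Evaluate multiple (query, context) pairs concurrently."""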
result = await asyncio.gather(*(self.aevaluate_context(query, returned_context) for query, returned_context in zip(query_list, returned_context_list)))
return result
def format_history_as_context(
self,
history: List[str],
) -> str:
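        """Join non-empty history entries into a newline-separated context block."""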
        formatted_history = "\n".join(filter(None, history))
        return formatted_history
def generate_text(
self,
query: str,
history: List[str],
) -> str:
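        """Answer a user query end-to-end.

        Retrieves context from the index (comparing a history-extended query when
        history exists), selects the better context by embedding similarity, fills
        the role-play prompt, calls PaLM, and logs the whole run as a W&B trace.
        """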
# get history
context_history = self.format_history_as_context(history=history)
# w&b trace start
start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
root_span = Trace(
name="MetaAgent",
kind="agent",
start_time_ms=start_time_ms,
metadata={"user": "🤗 Space"},
)
# get retrieval context(s) from llama-index vectorstore index
# w&b trace retrieval & select agent
agent_span = Trace(
name="LlamaIndexAgent",
kind="agent",
start_time_ms=start_time_ms,
)
try:
# No history, single context retrieval without evaluation
if not history:
# w&b trace retrieval context
result_query_only = self.retrieve_context(query)
# async version
                # result_query_only = asyncio.run(self.aretrieve_context(query))
context_from_index_selected = result_query_only["result"]
agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
retrieval_span = Trace(
name="QueryRetrieval",
kind="chain",
status_code="success",
metadata={
"framework": "Llama-Index",
"index_type": "VectorStoreIndex",
"vector_store": "Pinecone",
"vector_store_index": self._index_name,
"vector_store_namespace": self._index_namespace,
"model_name": self.llm._model_name,
"custom_kwargs": self.llm._model_kwargs,
},
start_time_ms=start_time_ms,
end_time_ms=agent_end_time_ms,
inputs={"query": query},
outputs={"response": context_from_index_selected},
)
agent_span.add_child(retrieval_span)
# Has history, multiple context retrieval with async, then evaluation to determine which context to choose
else:
extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"
# thread version
with ThreadPoolExecutor(2) as executor:
results = executor.map(self.retrieve_context, [query, extended_query])
                    result_query_only, result_extended_query = list(results)
# async version - not working
# result_query_only, result_extended_query = asyncio.run(
# self.aretrieve_context_multi([query, extended_query])
# )
# w&b trace retrieval context query only
retrieval_query_span = Trace(
name="QueryRetrieval",
kind="chain",
status_code="success",
metadata={
"framework": "Llama-Index",
"index_type": "VectorStoreIndex",
"vector_store": "Pinecone",
"vector_store_index": self._index_name,
"vector_store_namespace": self._index_namespace,
"model_name": self.llm._model_name,
"custom_kwargs": self.llm._model_kwargs,
"start_time": result_query_only["start"],
"end_time": result_query_only["end"],
},
start_time_ms=result_query_only["start"],
end_time_ms=result_query_only["end"],
inputs={"query": query},
outputs={"response": result_query_only["result"]},
)
agent_span.add_child(retrieval_query_span)
# w&b trace retrieval context extended query
retrieval_extended_query_span = Trace(
name="ExtendedQueryRetrieval",
kind="chain",
status_code="success",
metadata={
"framework": "Llama-Index",
"index_type": "VectorStoreIndex",
"vector_store": "Pinecone",
"vector_store_index": self._index_name,
"vector_store_namespace": self._index_namespace,
"model_name": self.llm._model_name,
"custom_kwargs": self.llm._model_kwargs,
"start_time": result_extended_query["start"],
"end_time": result_extended_query["end"],
},
start_time_ms=result_extended_query["start"],
end_time_ms=result_extended_query["end"],
inputs={"query": extended_query},
outputs={"response": result_extended_query["result"]},
)
agent_span.add_child(retrieval_extended_query_span)
# w&b trace select context
eval_start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
eval_context_query_only, eval_context_extended_query = asyncio.run(
self.aevaluate_context_multi([query, extended_query], [result_query_only["result"], result_extended_query["result"]])
)
if eval_context_query_only.score > eval_context_extended_query.score:
query_selected, context_from_index_selected = query, result_query_only["result"]
else:
query_selected, context_from_index_selected = extended_query, result_extended_query["result"]
agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
eval_span = Trace(
name="EmbeddingsEvaluator",
kind="tool",
status_code="success",
metadata={
"framework": "Llama-Index",
"evaluator": "SemanticSimilarityEvaluator",
"similarity_mode": "DEFAULT",
"similarity_threshold": 0.7,
"similarity_results": {
"eval_context_query_only": eval_context_query_only.score,
"eval_context_extended_query": eval_context_extended_query.score,
},
"model_name": self.emb_model._model_name,
},
start_time_ms=eval_start_time_ms,
end_time_ms=agent_end_time_ms,
inputs={"query": query_selected},
outputs={"response": context_from_index_selected},
)
agent_span.add_child(eval_span)
except Exception as e:
logger.error(f"Exception {e} occured when retriving context\n")
llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
result = "Something went wrong. Please try again later."
root_span.add_inputs_and_outputs(
inputs={"query": query}, outputs={"result": result, "exception": e}
)
root_span._span.status_code="fail"
root_span._span.end_time_ms = llm_end_time_ms
root_span.log(name="llm_app_trace")
return result
logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")
agent_span.add_inputs_and_outputs(
inputs={"query": query}, outputs={"result": context_from_index_selected}
)
agent_span._span.status_code="success"
agent_span._span.end_time_ms = agent_end_time_ms
root_span.add_child(agent_span)
# generate text with prompt template to roleplay myself
prompt_with_context = self.prompt_template.format(context_history=context_history, context_from_index=context_from_index_selected, user_query=query)
try:
response = genai.generate_text(
prompt=prompt_with_context,
safety_settings=[
{
'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
},
],
temperature=0.9,
)
result = response.result
success_flag = "success"
if result is None:
result = "Seems something went wrong. Please try again later."
logger.error(f"Result with 'None' received\n")
success_flag = "fail"
except Exception as e:
result = "Seems something went wrong. Please try again later."
logger.error(f"Exception {e} occured\n")
success_flag = "fail"
# w&b trace llm
llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
llm_span = Trace(
name="LLM",
kind="llm",
status_code=success_flag,
start_time_ms=agent_end_time_ms,
end_time_ms=llm_end_time_ms,
inputs={"input": prompt_with_context},
outputs={"result": result},
)
root_span.add_child(llm_span)
# w&b finalize trace
root_span.add_inputs_and_outputs(
inputs={"query": query}, outputs={"result": result}
)
root_span._span.end_time_ms = llm_end_time_ms
root_span.log(name="llm_app_trace")
return result
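
# Minimal usage sketch (kept as comments so importing this module stays
# side-effect free). It assumes PALM_API_KEY, PINECONE_API_KEY, PINECONE_ENV,
# WANDB_PROJECT and the PINECONE_INDEX / PINECONE_NAMESPACE variables are set;
# the query string is purely illustrative.
#
# if __name__ == '__main__':
#     palm = LlamaIndexPaLM()
#     palm.get_index_from_pinecone()
#     print(palm.generate_text("Tell me about your data projects.", history=[]))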