# src/llamaindex_palm.py
import os
import datetime
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Union
from pydantic import Extra
import wandb
from wandb.sdk.data_types.trace_tree import Trace
import pinecone
import google.generativeai as genai
from llama_index import (
ServiceContext,
PromptHelper,
VectorStoreIndex
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
CustomLLM,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.evaluation import SemanticSimilarityEvaluator
from llama_index.embeddings import SimilarityMode
import logging
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger('llm')
prompt_template = """
[System]
You are in a role play of Gerard Lee. Gerard is a data enthusiast and humble about his success.
Reply in no more than 5 complete sentences unless [User Query] asks you to elaborate. Use content from [Context] only, without prior knowledge, except for referring to [History] to keep the conversation seamless.
[History]
{context_history}
[Context]
{context_from_index}
[User Query]
{user_query}
"""
class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
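    """PaLM embedding model wrapped as a LlamaIndex ``BaseEmbedding``.

    Thin wrapper around ``genai.generate_embeddings``; assumes
    ``genai.configure(api_key=...)`` has already been called.
    """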
def __init__(
self,
model_name: str = 'models/embedding-gecko-001',
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._model_name = model_name
@classmethod
def class_name(cls) -> str:
return 'PaLMEmbeddings'
def gen_embeddings(self, text: str) -> List[float]:
return genai.generate_embeddings(self._model_name, text)
def _get_query_embedding(self, query: str) -> List[float]:
embeddings = self.gen_embeddings(query)
return embeddings['embedding']
def _get_text_embedding(self, text: str) -> List[float]:
embeddings = self.gen_embeddings(text)
return embeddings['embedding']
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
embeddings = [
self.gen_embeddings(text)['embedding'] for text in texts
]
return embeddings
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
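# Usage sketch (assumes genai.configure(api_key=...) was called beforehand):
#   emb_model = LlamaIndexPaLMEmbeddings()
#   vector = emb_model.get_text_embedding("hello world")  # -> List[float]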
class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
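    """PaLM text-generation model wrapped as a LlamaIndex ``CustomLLM``.

    Thin wrapper around ``genai.generate_text``; streaming is not implemented.
    """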
def __init__(
self,
model_name: str = 'models/text-bison-001',
        model_kwargs: Optional[dict] = None,
context_window: int = 8196,
num_output: int = 1024,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._model_name = model_name
        # avoid the shared-mutable-default pitfall
        self._model_kwargs = model_kwargs or {}
self._context_window = context_window
self._num_output = num_output
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=self._context_window,
num_output=self._num_output,
model_name=self._model_name
)
def gen_texts(self, prompt):
logging.debug(f"prompt: {prompt}")
response = genai.generate_text(
model=self._model_name,
prompt=prompt,
safety_settings=[
{
'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
},
],
**self._model_kwargs
)
logging.debug(f"response:\n{response}")
return response.candidates[0]['output']
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
text = self.gen_texts(prompt)
return CompletionResponse(text=text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
raise NotImplementedError()
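# Usage sketch (hypothetical kwargs; requires genai.configure(api_key=...)):
#   llm = LlamaIndexPaLMText(model_kwargs={"temperature": 0.7})
#   print(llm.complete("Say hello in one sentence.").text)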
class LlamaIndexPaLM:
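    """RAG pipeline glue: PaLM LLM, PaLM embeddings, a Pinecone vector index,
    and W&B trace logging, tied together through a LlamaIndex ServiceContext.
    """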
def __init__(
self,
        emb_model: Optional[LlamaIndexPaLMEmbeddings] = None,
        model: Optional[LlamaIndexPaLMText] = None,
# prompt_template: str = prompt_template
) -> None:
        # build defaults here rather than in the signature (mutable-default pitfall)
        self.emb_model = emb_model if emb_model is not None else LlamaIndexPaLMEmbeddings()
        self.llm = model if model is not None else LlamaIndexPaLMText()
self.prompt_template = prompt_template
# Google Generative AI
genai.configure(api_key=os.environ['PALM_API_KEY'])
# Pinecone
pinecone.init(
api_key=os.environ['PINECONE_API_KEY'],
environment=os.getenv('PINECONE_ENV')
)
# W&B
wandb.init(project=os.getenv('WANDB_PROJECT'))
        # model metadata (env vars come back as strings, so cast to the expected types)
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT')
        TEXT_CHUNK_SIZE_LIMIT = int(TEXT_CHUNK_SIZE_LIMIT) if TEXT_CHUNK_SIZE_LIMIT is not None else None
self.node_parser = SimpleNodeParser.from_defaults(
text_splitter=TokenTextSplitter(
chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
)
)
self.prompt_helper = PromptHelper(
context_window=CONTEXT_WINDOW,
num_output=NUM_OUTPUT,
chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
)
self.service_context = ServiceContext.from_defaults(
llm=self.llm,
embed_model=self.emb_model,
node_parser=self.node_parser,
prompt_helper=self.prompt_helper,
)
        self.emb_evaluator = SemanticSimilarityEvaluator(
            service_context=self.service_context,
            similarity_mode=SimilarityMode.DEFAULT,
            similarity_threshold=float(os.getenv('SIMILARITY_THRESHOLD', 0.7)),
        )
def get_index_from_pinecone(
self,
index_name: str = os.getenv('PINECONE_INDEX'),
index_namespace: str = os.getenv('PINECONE_NAMESPACE')
) -> None:
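        """Attach an existing Pinecone index as a LlamaIndex ``VectorStoreIndex``."""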
# Pinecone VectorStore
pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(
            pinecone_index=pinecone_index,
            add_sparse_vector=True,
            namespace=index_namespace,
        )
        self.pinecone_index = VectorStoreIndex.from_vector_store(
            self.vector_store, service_context=self.service_context
        )
self._index_name = index_name
self._index_namespace = index_namespace
return None
def retrieve_context(
self,
query: str
) -> Dict[str, Union[str, int]]:
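        """Query the vector index, returning the response text and start/end timestamps in ms."""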
start_time = round(datetime.datetime.now().timestamp() * 1000)
response = self.pinecone_index.as_query_engine(similarity_top_k=3).query(query)
end_time = round(datetime.datetime.now().timestamp() * 1000)
return {"result": response.response, "start": start_time, "end": end_time}
async def aretrieve_context(
self,
query: str
) -> Dict[str, Union[str, int]]:
start_time = round(datetime.datetime.now().timestamp() * 1000)
response = await self.pinecone_index.as_query_engine(similarity_top_k=3, use_async=True).aquery(query)
end_time = round(datetime.datetime.now().timestamp() * 1000)
return {"result": response.response, "start": start_time, "end": end_time}
async def aretrieve_context_multi(
self,
query_list: List[str]
) -> List[Dict]:
result = await asyncio.gather(*(self.aretrieve_context(query) for query in query_list))
return result
async def aevaluate_context(
self,
query: str,
returned_context: str
) -> Dict[str, Any]:
        result = await self.emb_evaluator.aevaluate(
response=returned_context,
reference=query,
)
return result
async def aevaluate_context_multi(
self,
query_list: List[str],
returned_context_list: List[str]
) -> List[Dict]:
        result = await asyncio.gather(
            *(
                self.aevaluate_context(query, returned_context)
                for query, returned_context in zip(query_list, returned_context_list)
            )
        )
return result
def format_history_as_context(
self,
history: List[str],
) -> str:
format_chat_history = "\n".join(list(filter(None, history)))
return format_chat_history
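    # e.g. format_history_as_context(["User: hi", "", "Gerard: hello!"])
    #   -> "User: hi\nGerard: hello!"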
def generate_text(
self,
query: str,
history: List[str],
) -> str:
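        """Answer ``query`` with retrieval-augmented PaLM generation.

        Retrieves context from Pinecone (evaluating query variants when history
        exists), fills ``prompt_template``, and logs the whole run to W&B as a
        trace tree rooted at ``MetaAgent``.
        """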
# get history
context_history = self.format_history_as_context(history=history)
# w&b trace start
start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
root_span = Trace(
name="MetaAgent",
kind="agent",
start_time_ms=start_time_ms,
metadata={"user": "🤗 Space"},
)
# get retrieval context(s) from llama-index vectorstore index
# w&b trace retrieval & select agent
agent_span = Trace(
name="LlamaIndexAgent",
kind="agent",
start_time_ms=start_time_ms,
)
try:
# No history, single context retrieval without evaluation
if not history:
# w&b trace retrieval context
result_query_only = self.retrieve_context(query)
# async version
# result_query_only = asyncio.run(self.retrieve_context(query))
context_from_index_selected = result_query_only["result"]
agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
retrieval_span = Trace(
name="QueryRetrieval",
kind="chain",
status_code="success",
metadata={
"framework": "Llama-Index",
"index_type": "VectorStoreIndex",
"vector_store": "Pinecone",
"vector_store_index": self._index_name,
"vector_store_namespace": self._index_namespace,
"model_name": self.llm._model_name,
"custom_kwargs": self.llm._model_kwargs,
},
start_time_ms=start_time_ms,
end_time_ms=agent_end_time_ms,
inputs={"query": query},
outputs={"response": context_from_index_selected},
)
agent_span.add_child(retrieval_span)
# Has history, multiple context retrieval with async, then evaluation to determine which context to choose
else:
extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"
# thread version
with ThreadPoolExecutor(2) as executor:
results = executor.map(self.retrieve_context, [query, extended_query])
result_query_only, result_extended_query = [rec for rec in results]
# async version - not working
# result_query_only, result_extended_query = asyncio.run(
# self.aretrieve_context_multi([query, extended_query])
# )
# w&b trace retrieval context query only
retrieval_query_span = Trace(
name="QueryRetrieval",
kind="chain",
status_code="success",
metadata={
"framework": "Llama-Index",
"index_type": "VectorStoreIndex",
"vector_store": "Pinecone",
"vector_store_index": self._index_name,
"vector_store_namespace": self._index_namespace,
"model_name": self.llm._model_name,
"custom_kwargs": self.llm._model_kwargs,
"start_time": result_query_only["start"],
"end_time": result_query_only["end"],
},
start_time_ms=result_query_only["start"],
end_time_ms=result_query_only["end"],
inputs={"query": query},
outputs={"response": result_query_only["result"]},
)
agent_span.add_child(retrieval_query_span)
# w&b trace retrieval context extended query
retrieval_extended_query_span = Trace(
name="ExtendedQueryRetrieval",
kind="chain",
status_code="success",
metadata={
"framework": "Llama-Index",
"index_type": "VectorStoreIndex",
"vector_store": "Pinecone",
"vector_store_index": self._index_name,
"vector_store_namespace": self._index_namespace,
"model_name": self.llm._model_name,
"custom_kwargs": self.llm._model_kwargs,
"start_time": result_extended_query["start"],
"end_time": result_extended_query["end"],
},
start_time_ms=result_extended_query["start"],
end_time_ms=result_extended_query["end"],
inputs={"query": extended_query},
outputs={"response": result_extended_query["result"]},
)
agent_span.add_child(retrieval_extended_query_span)
# w&b trace select context
eval_start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
eval_context_query_only, eval_context_extended_query = asyncio.run(
self.aevaluate_context_multi([query, extended_query], [result_query_only["result"], result_extended_query["result"]])
)
if eval_context_query_only.score > eval_context_extended_query.score:
query_selected, context_from_index_selected = query, result_query_only["result"]
else:
query_selected, context_from_index_selected = extended_query, result_extended_query["result"]
agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
eval_span = Trace(
name="EmbeddingsEvaluator",
kind="tool",
status_code="success",
metadata={
"framework": "Llama-Index",
"evaluator": "SemanticSimilarityEvaluator",
"similarity_mode": "DEFAULT",
"similarity_threshold": 0.7,
"similarity_results": {
"eval_context_query_only": eval_context_query_only["result"],
"eval_context_extended_query": eval_context_extended_query["result"],
},
"model_name": self.emb_model._model_name,
},
start_time_ms=eval_start_time_ms,
end_time_ms=agent_end_time_ms,
inputs={"query": query_selected},
outputs={"response": context_from_index_selected},
)
agent_span.add_child(eval_span)
        except Exception as e:
            logger.error(f"Exception {e} occurred when retrieving context\n")
            llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
            result = "Something went wrong. Please try again later."
            root_span.add_inputs_and_outputs(
                inputs={"query": query}, outputs={"result": result, "exception": str(e)}
            )
            root_span._span.status_code = "fail"
            root_span._span.end_time_ms = llm_end_time_ms
            root_span.log(name="llm_app_trace")
            return result
logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")
agent_span.add_inputs_and_outputs(
inputs={"query": query}, outputs={"result": context_from_index_selected}
)
        agent_span._span.status_code = "success"
        agent_span._span.end_time_ms = agent_end_time_ms
root_span.add_child(agent_span)
# generate text with prompt template to roleplay myself
        prompt_with_context = self.prompt_template.format(
            context_history=context_history,
            context_from_index=context_from_index_selected,
            user_query=query,
        )
try:
response = genai.generate_text(
prompt=prompt_with_context,
safety_settings=[
{
'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
},
],
temperature=0.9,
)
result = response.result
success_flag = "success"
            if result is None:
                result = "Something went wrong. Please try again later."
                logger.error("Received a 'None' result\n")
                success_flag = "fail"
        except Exception as e:
            result = "Something went wrong. Please try again later."
            logger.error(f"Exception {e} occurred\n")
            success_flag = "fail"
# w&b trace llm
llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
llm_span = Trace(
name="LLM",
kind="llm",
status_code=success_flag,
start_time_ms=agent_end_time_ms,
end_time_ms=llm_end_time_ms,
inputs={"input": prompt_with_context},
outputs={"result": result},
)
root_span.add_child(llm_span)
# w&b finalize trace
root_span.add_inputs_and_outputs(
inputs={"query": query}, outputs={"result": result}
)
root_span._span.end_time_ms = llm_end_time_ms
root_span.log(name="llm_app_trace")
return result
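# Minimal end-to-end sketch (assumes PALM_API_KEY, PINECONE_API_KEY, PINECONE_ENV,
# PINECONE_INDEX, PINECONE_NAMESPACE and WANDB_PROJECT are set in the environment):
if __name__ == "__main__":
    palm = LlamaIndexPaLM()
    palm.get_index_from_pinecone()
    print(palm.generate_text(query="What does Gerard do?", history=[]))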