import os
import datetime
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Dict, Union

from pydantic import Extra

import wandb
from wandb.sdk.data_types.trace_tree import Trace
import pinecone
import google.generativeai as genai
from llama_index import (
    ServiceContext,
    PromptHelper,
    VectorStoreIndex
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.evaluation import SemanticSimilarityEvaluator
from llama_index.embeddings import SimilarityMode

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger('llm')
prompt_template = """
[System]
You are in a role play of Gerard Lee. Gerard is a data enthusiast and humble about his success.
Reply in no more than 5 complete sentences unless [User Query] requests you to elaborate. Use content from [Context] only, without prior knowledge, except for referring to [History] for a seamless conversation.
[History]
{context_history}
[Context]
{context_from_index}
[User Query]
{user_query}
"""

class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
    """Custom LlamaIndex embedding model backed by the PaLM embeddings API."""
    def __init__(
        self,
        model_name: str = 'models/embedding-gecko-001',
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name

    @classmethod
    def class_name(cls) -> str:
        return 'PaLMEmbeddings'

    def gen_embeddings(self, text: str) -> Dict[str, List[float]]:
        # returns a dict with an 'embedding' key from the PaLM API
        return genai.generate_embeddings(self._model_name, text)

    def _get_query_embedding(self, query: str) -> List[float]:
        embeddings = self.gen_embeddings(query)
        return embeddings['embedding']

    def _get_text_embedding(self, text: str) -> List[float]:
        embeddings = self.gen_embeddings(text)
        return embeddings['embedding']

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        embeddings = [
            self.gen_embeddings(text)['embedding'] for text in texts
        ]
        return embeddings

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
    """Custom LlamaIndex LLM backed by the PaLM text generation API."""
    def __init__(
        self,
        model_name: str = 'models/text-bison-001',
        model_kwargs: dict = {},
        context_window: int = 8196,
        num_output: int = 1024,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name
        self._model_kwargs = model_kwargs
        self._context_window = context_window
        self._num_output = num_output

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name
        )

    def gen_texts(self, prompt):
        logging.debug(f"prompt: {prompt}")
        response = genai.generate_text(
            model=self._model_name,
            prompt=prompt,
            safety_settings=[
                {
                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                },
            ],
            **self._model_kwargs
        )
        logging.debug(f"response:\n{response}")
        return response.candidates[0]['output']

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        text = self.gen_texts(prompt)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()

class LlamaIndexPaLM():
    """Wires the PaLM embedding and text models into a LlamaIndex + Pinecone retrieval pipeline."""
    def __init__(
        self,
        emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
        model: LlamaIndexPaLMText = LlamaIndexPaLMText(),
        # prompt_template: str = prompt_template
    ) -> None:
        self.emb_model = emb_model
        self.llm = model
        self.prompt_template = prompt_template

        # Google Generative AI
        genai.configure(api_key=os.environ['PALM_API_KEY'])

        # Pinecone
        pinecone.init(
            api_key=os.environ['PINECONE_API_KEY'],
            environment=os.getenv('PINECONE_ENV')
        )

        # W&B
        wandb.init(project=os.getenv('WANDB_PROJECT'))

        # model metadata (os.getenv returns strings, so cast numeric settings explicitly)
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        chunk_size_limit = os.getenv('TEXT_CHUNK_SIZE_LIMIT')
        TEXT_CHUNK_SIZE_LIMIT = int(chunk_size_limit) if chunk_size_limit else None

        self.node_parser = SimpleNodeParser.from_defaults(
            text_splitter=TokenTextSplitter(
                chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
            )
        )
        self.prompt_helper = PromptHelper(
            context_window=CONTEXT_WINDOW,
            num_output=NUM_OUTPUT,
            chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
            chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
        )
        self.service_context = ServiceContext.from_defaults(
            llm=self.llm,
            embed_model=self.emb_model,
            node_parser=self.node_parser,
            prompt_helper=self.prompt_helper,
        )
        self.emd_evaluator = SemanticSimilarityEvaluator(
            service_context=self.service_context,
            similarity_mode=SimilarityMode.DEFAULT,
            similarity_threshold=float(os.getenv('SIMILARITY_THRESHOLD', 0.7)),
        )

    def get_index_from_pinecone(
        self,
        index_name: str = os.getenv('PINECONE_INDEX'),
        index_namespace: str = os.getenv('PINECONE_NAMESPACE')
    ) -> None:
        # Pinecone VectorStore
        pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(pinecone_index=pinecone_index, add_sparse_vector=True, namespace=index_namespace)
        self.pinecone_index = VectorStoreIndex.from_vector_store(self.vector_store, service_context=self.service_context)
        self._index_name = index_name
        self._index_namespace = index_namespace
        return None

    def retrieve_context(
        self,
        query: str
    ) -> Dict[str, Union[str, int]]:
        start_time = round(datetime.datetime.now().timestamp() * 1000)
        response = self.pinecone_index.as_query_engine(similarity_top_k=3).query(query)
        end_time = round(datetime.datetime.now().timestamp() * 1000)
        return {"result": response.response, "start": start_time, "end": end_time}

    async def aretrieve_context(
        self,
        query: str
    ) -> Dict[str, Union[str, int]]:
        start_time = round(datetime.datetime.now().timestamp() * 1000)
        response = await self.pinecone_index.as_query_engine(similarity_top_k=3, use_async=True).aquery(query)
        end_time = round(datetime.datetime.now().timestamp() * 1000)
        return {"result": response.response, "start": start_time, "end": end_time}

    async def aretrieve_context_multi(
        self,
        query_list: List[str]
    ) -> List[Dict]:
        result = await asyncio.gather(*(self.aretrieve_context(query) for query in query_list))
        return result

    async def aevaluate_context(
        self,
        query: str,
        returned_context: str
    ) -> Dict[str, Any]:
        result = await self.emd_evaluator.aevaluate(
            response=returned_context,
            reference=query,
        )
        return result

    async def aevaluate_context_multi(
        self,
        query_list: List[str],
        returned_context_list: List[str]
    ) -> List[Dict]:
        result = await asyncio.gather(*(self.aevaluate_context(query, returned_context) for query, returned_context in zip(query_list, returned_context_list)))
        return result

    def format_history_as_context(
        self,
        history: List[str],
    ) -> str:
        format_chat_history = "\n".join(list(filter(None, history)))
        return format_chat_history

    def generate_text(
        self,
        query: str,
        history: List[str],
    ) -> str:
        # get history
        context_history = self.format_history_as_context(history=history)

        # w&b trace start
        start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
        root_span = Trace(
            name="MetaAgent",
            kind="agent",
            start_time_ms=start_time_ms,
            metadata={"user": "🤗 Space"},
        )

        # get retrieval context(s) from llama-index vectorstore index
        # w&b trace retrieval & select agent
        agent_span = Trace(
            name="LlamaIndexAgent",
            kind="agent",
            start_time_ms=start_time_ms,
        )
        try:
            # No history: single context retrieval without evaluation
            if not history:
                # w&b trace retrieval context
                result_query_only = self.retrieve_context(query)
                # async version
                # result_query_only = asyncio.run(self.retrieve_context(query))
                context_from_index_selected = result_query_only["result"]
                agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
                retrieval_span = Trace(
                    name="QueryRetrieval",
                    kind="chain",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "index_type": "VectorStoreIndex",
                        "vector_store": "Pinecone",
                        "vector_store_index": self._index_name,
                        "vector_store_namespace": self._index_namespace,
                        "model_name": self.llm._model_name,
                        "custom_kwargs": self.llm._model_kwargs,
                    },
                    start_time_ms=start_time_ms,
                    end_time_ms=agent_end_time_ms,
                    inputs={"query": query},
                    outputs={"response": context_from_index_selected},
                )
                agent_span.add_child(retrieval_span)
            # Has history: retrieve contexts for both queries concurrently, then evaluate which context to choose
            else:
                extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"
                # thread version
                with ThreadPoolExecutor(2) as executor:
                    results = executor.map(self.retrieve_context, [query, extended_query])
                    result_query_only, result_extended_query = [rec for rec in results]
                # async version - not working
                # result_query_only, result_extended_query = asyncio.run(
                #     self.aretrieve_context_multi([query, extended_query])
                # )
                # w&b trace retrieval context query only
                retrieval_query_span = Trace(
                    name="QueryRetrieval",
                    kind="chain",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "index_type": "VectorStoreIndex",
                        "vector_store": "Pinecone",
                        "vector_store_index": self._index_name,
                        "vector_store_namespace": self._index_namespace,
                        "model_name": self.llm._model_name,
                        "custom_kwargs": self.llm._model_kwargs,
                        "start_time": result_query_only["start"],
                        "end_time": result_query_only["end"],
                    },
                    start_time_ms=result_query_only["start"],
                    end_time_ms=result_query_only["end"],
                    inputs={"query": query},
                    outputs={"response": result_query_only["result"]},
                )
                agent_span.add_child(retrieval_query_span)
                # w&b trace retrieval context extended query
                retrieval_extended_query_span = Trace(
                    name="ExtendedQueryRetrieval",
                    kind="chain",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "index_type": "VectorStoreIndex",
                        "vector_store": "Pinecone",
                        "vector_store_index": self._index_name,
                        "vector_store_namespace": self._index_namespace,
                        "model_name": self.llm._model_name,
                        "custom_kwargs": self.llm._model_kwargs,
                        "start_time": result_extended_query["start"],
                        "end_time": result_extended_query["end"],
                    },
                    start_time_ms=result_extended_query["start"],
                    end_time_ms=result_extended_query["end"],
                    inputs={"query": extended_query},
                    outputs={"response": result_extended_query["result"]},
                )
                agent_span.add_child(retrieval_extended_query_span)
                # w&b trace select context
                eval_start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
                eval_context_query_only, eval_context_extended_query = asyncio.run(
                    self.aevaluate_context_multi([query, extended_query], [result_query_only["result"], result_extended_query["result"]])
                )
                if eval_context_query_only.score > eval_context_extended_query.score:
                    query_selected, context_from_index_selected = query, result_query_only["result"]
                else:
                    query_selected, context_from_index_selected = extended_query, result_extended_query["result"]
                agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
                eval_span = Trace(
                    name="EmbeddingsEvaluator",
                    kind="tool",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "evaluator": "SemanticSimilarityEvaluator",
                        "similarity_mode": "DEFAULT",
                        "similarity_threshold": 0.7,
                        "similarity_results": {
                            # EvaluationResult exposes scores as attributes, not dict keys
                            "eval_context_query_only": eval_context_query_only.score,
                            "eval_context_extended_query": eval_context_extended_query.score,
                        },
                        "model_name": self.emb_model._model_name,
                    },
                    start_time_ms=eval_start_time_ms,
                    end_time_ms=agent_end_time_ms,
                    inputs={"query": query_selected},
                    outputs={"response": context_from_index_selected},
                )
                agent_span.add_child(eval_span)
        except Exception as e:
            logger.error(f"Exception {e} occurred when retrieving context\n")
            llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
            result = "Something went wrong. Please try again later."
            root_span.add_inputs_and_outputs(
                inputs={"query": query}, outputs={"result": result, "exception": e}
            )
            root_span._span.status_code = "fail"
            root_span._span.end_time_ms = llm_end_time_ms
            root_span.log(name="llm_app_trace")
            return result

        logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")
        agent_span.add_inputs_and_outputs(
            inputs={"query": query}, outputs={"result": context_from_index_selected}
        )
        agent_span._span.status_code = "success"
        agent_span._span.end_time_ms = agent_end_time_ms
        root_span.add_child(agent_span)
        # generate text with prompt template to roleplay myself
        prompt_with_context = self.prompt_template.format(context_history=context_history, context_from_index=context_from_index_selected, user_query=query)
        try:
            response = genai.generate_text(
                prompt=prompt_with_context,
                safety_settings=[
                    {
                        'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                        'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                    },
                ],
                temperature=0.9,
            )
            result = response.result
            success_flag = "success"
            if result is None:
                result = "Seems something went wrong. Please try again later."
                logger.error("Result of 'None' received\n")
                success_flag = "fail"
        except Exception as e:
            result = "Seems something went wrong. Please try again later."
            logger.error(f"Exception {e} occurred\n")
            success_flag = "fail"

        # w&b trace llm
        llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
        llm_span = Trace(
            name="LLM",
            kind="llm",
            status_code=success_flag,
            start_time_ms=agent_end_time_ms,
            end_time_ms=llm_end_time_ms,
            inputs={"input": prompt_with_context},
            outputs={"result": result},
        )
        root_span.add_child(llm_span)

        # w&b finalize trace
        root_span.add_inputs_and_outputs(
            inputs={"query": query}, outputs={"result": result}
        )
        root_span._span.end_time_ms = llm_end_time_ms
        root_span.log(name="llm_app_trace")
        return result
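

# The block below is a hypothetical usage sketch, not the Space's entry point.
# It assumes PALM_API_KEY, PINECONE_API_KEY, PINECONE_ENV, PINECONE_INDEX,
# PINECONE_NAMESPACE and WANDB_PROJECT are already set in the environment, and
# the sample query and history format are illustrative only.
if __name__ == '__main__':
    palm_agent = LlamaIndexPaLM()
    # connect to the existing Pinecone index/namespace configured via env vars
    palm_agent.get_index_from_pinecone()

    chat_history: List[str] = []
    user_query = "What projects has Gerard worked on recently?"
    # retrieve context, evaluate it against the query, and generate the reply
    reply = palm_agent.generate_text(query=user_query, history=chat_history)
    chat_history.append(f"[User]\n{user_query}\n[Gerard]\n{reply}")
    print(reply)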