import os
import datetime
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Union

from pydantic import Extra

import wandb
from wandb.sdk.data_types.trace_tree import Trace

import pinecone
import google.generativeai as genai

from llama_index import (
    ServiceContext,
    PromptHelper,
    VectorStoreIndex,
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.evaluation import SemanticSimilarityEvaluator
from llama_index.embeddings import SimilarityMode

import logging

logging.basicConfig(
    format='%(asctime)s %(message)s',
    datefmt='%Y-%m-%d %I:%M:%S %p',
    level=logging.INFO,
)
logger = logging.getLogger('llm')

prompt_template = """
[System]
You are in a role play of Gerard Lee. Gerard is a data enthusiast and humble about his success.
Reply in no more than 5 complete sentences unless [User Query] asks you to elaborate.
Use content from [Context] only, without prior knowledge, except for referring to [History] to keep the conversation seamless.

[History]
{context_history}

[Context]
{context_from_index}

[User Query]
{user_query}
"""


class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
    """LlamaIndex embedding adapter backed by the PaLM embedding API."""

    def __init__(
        self,
        model_name: str = 'models/embedding-gecko-001',
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name

    @classmethod
    def class_name(cls) -> str:
        return 'PaLMEmbeddings'

    def gen_embeddings(self, text: str) -> Dict[str, List[float]]:
        # the PaLM API returns a dict with an 'embedding' key
        return genai.generate_embeddings(self._model_name, text)

    def _get_query_embedding(self, query: str) -> List[float]:
        embeddings = self.gen_embeddings(query)
        return embeddings['embedding']

    def _get_text_embedding(self, text: str) -> List[float]:
        embeddings = self.gen_embeddings(text)
        return embeddings['embedding']

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        embeddings = [
            self.gen_embeddings(text)['embedding'] for text in texts
        ]
        return embeddings

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)
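
# A minimal sketch of using the embeddings adapter on its own, assuming
# PALM_API_KEY is set and genai.configure() has been called first; the sample
# sentence is an illustrative placeholder, not part of the pipeline below.
#
#   emb_model = LlamaIndexPaLMEmbeddings()
#   vector = emb_model.get_text_embedding("Gerard is a data enthusiast.")
#   print(len(vector))  # dimensionality of the gecko embedding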


class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
    """LlamaIndex CustomLLM adapter backed by the PaLM text completion API."""

    def __init__(
        self,
        model_name: str = 'models/text-bison-001',
        model_kwargs: Optional[dict] = None,
        context_window: int = 8192,
        num_output: int = 1024,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name
        # avoid a mutable default argument; fall back to an empty dict
        self._model_kwargs = model_kwargs or {}
        self._context_window = context_window
        self._num_output = num_output

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name,
        )

    def gen_texts(self, prompt):
        logger.debug(f"prompt: {prompt}")
        response = genai.generate_text(
            model=self._model_name,
            prompt=prompt,
            safety_settings=[
                {
                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                },
            ],
            **self._model_kwargs,
        )
        logger.debug(f"response:\n{response}")
        return response.candidates[0]['output']

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        text = self.gen_texts(prompt)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # streaming is not supported by this adapter
        raise NotImplementedError()
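
# Likewise, a hedged sketch of driving the text adapter directly, under the
# same configuration assumptions; the prompt is a placeholder.
#
#   llm = LlamaIndexPaLMText(model_kwargs={'temperature': 0.7})
#   print(llm.complete("Introduce yourself in one sentence.").text)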


class LlamaIndexPaLM:
    """RAG agent that retrieves context from a Pinecone-backed LlamaIndex
    vector store, picks the better context when chat history exists, and
    generates a persona reply with PaLM, tracing each step to W&B."""

    def __init__(
        self,
        emb_model: Optional[LlamaIndexPaLMEmbeddings] = None,
        model: Optional[LlamaIndexPaLMText] = None,
        # prompt_template: str = prompt_template
    ) -> None:
        self.emb_model = emb_model or LlamaIndexPaLMEmbeddings()
        self.llm = model or LlamaIndexPaLMText()
        self.prompt_template = prompt_template

        # Google Generative AI
        genai.configure(api_key=os.environ['PALM_API_KEY'])

        # Pinecone
        pinecone.init(
            api_key=os.environ['PINECONE_API_KEY'],
            environment=os.getenv('PINECONE_ENV')
        )

        # W&B
        wandb.init(project=os.getenv('WANDB_PROJECT'))

        # model metadata (environment variables arrive as strings, so cast them)
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8192))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        _chunk_size_limit = os.getenv('TEXT_CHUNK_SIZE_LIMIT')
        TEXT_CHUNK_SIZE_LIMIT = int(_chunk_size_limit) if _chunk_size_limit else None

        self.node_parser = SimpleNodeParser.from_defaults(
            text_splitter=TokenTextSplitter(
                chunk_size=TEXT_CHUNK_SIZE,
                chunk_overlap=TEXT_CHUNK_OVERLAP
            )
        )
        self.prompt_helper = PromptHelper(
            context_window=CONTEXT_WINDOW,
            num_output=NUM_OUTPUT,
            chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
            chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
        )
        self.service_context = ServiceContext.from_defaults(
            llm=self.llm,
            embed_model=self.emb_model,
            node_parser=self.node_parser,
            prompt_helper=self.prompt_helper,
        )
        self.emb_evaluator = SemanticSimilarityEvaluator(
            service_context=self.service_context,
            similarity_mode=SimilarityMode.DEFAULT,
            similarity_threshold=float(os.getenv('SIMILARITY_THRESHOLD', 0.7)),
        )

    def get_index_from_pinecone(
        self,
        index_name: str = os.getenv('PINECONE_INDEX'),
        index_namespace: str = os.getenv('PINECONE_NAMESPACE')
    ) -> None:
        # Pinecone VectorStore
        pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(
            pinecone_index=pinecone_index,
            add_sparse_vector=True,
            namespace=index_namespace
        )
        self.pinecone_index = VectorStoreIndex.from_vector_store(
            self.vector_store,
            service_context=self.service_context
        )
        self._index_name = index_name
        self._index_namespace = index_namespace

    def retrieve_context(
        self,
        query: str
    ) -> Dict[str, Union[str, int]]:
        """Query the index synchronously, recording start/end timestamps in ms."""
        start_time = round(datetime.datetime.now().timestamp() * 1000)
        response = self.pinecone_index.as_query_engine(similarity_top_k=3).query(query)
        end_time = round(datetime.datetime.now().timestamp() * 1000)
        return {"result": response.response, "start": start_time, "end": end_time}

    async def aretrieve_context(
        self,
        query: str
    ) -> Dict[str, Union[str, int]]:
        start_time = round(datetime.datetime.now().timestamp() * 1000)
        response = await self.pinecone_index.as_query_engine(
            similarity_top_k=3, use_async=True
        ).aquery(query)
        end_time = round(datetime.datetime.now().timestamp() * 1000)
        return {"result": response.response, "start": start_time, "end": end_time}

    async def aretrieve_context_multi(
        self,
        query_list: List[str]
    ) -> List[Dict]:
        result = await asyncio.gather(
            *(self.aretrieve_context(query) for query in query_list)
        )
        return result

    async def aevaluate_context(
        self,
        query: str,
        returned_context: str
    ) -> Dict[str, Any]:
        result = await self.emb_evaluator.aevaluate(
            response=returned_context,
            reference=query,
        )
        return result

    async def aevaluate_context_multi(
        self,
        query_list: List[str],
        returned_context_list: List[str]
    ) -> List[Dict]:
        result = await asyncio.gather(
            *(
                self.aevaluate_context(query, returned_context)
                for query, returned_context in zip(query_list, returned_context_list)
            )
        )
        return result

    def format_history_as_context(
        self,
        history: List[str],
    ) -> str:
        # drop empty turns and join the rest into a newline-separated block
        format_chat_history = "\n".join(list(filter(None, history)))
        return format_chat_history

    def generate_text(
        self,
        query: str,
        history: List[str],
    ) -> str:
        # get history
        context_history = self.format_history_as_context(history=history)

        # w&b trace start
        start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
        root_span = Trace(
            name="MetaAgent",
            kind="agent",
            start_time_ms=start_time_ms,
            metadata={"user": "🤗 Space"},
        )

        # get retrieval context(s) from llama-index vector store index
        # w&b trace retrieval & select agent
        agent_span = Trace(
            name="LlamaIndexAgent",
            kind="agent",
            start_time_ms=start_time_ms,
        )

        try:
            # No history: single context retrieval without evaluation
            if not history:
                # w&b trace retrieval context
                result_query_only = self.retrieve_context(query)
                # async version
                # result_query_only = asyncio.run(self.aretrieve_context(query))
                context_from_index_selected = result_query_only["result"]
                agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
                retrieval_span = Trace(
                    name="QueryRetrieval",
                    kind="chain",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "index_type": "VectorStoreIndex",
                        "vector_store": "Pinecone",
                        "vector_store_index": self._index_name,
                        "vector_store_namespace": self._index_namespace,
                        "model_name": self.llm._model_name,
                        "custom_kwargs": self.llm._model_kwargs,
                    },
                    start_time_ms=start_time_ms,
                    end_time_ms=agent_end_time_ms,
                    inputs={"query": query},
                    outputs={"response": context_from_index_selected},
                )
                agent_span.add_child(retrieval_span)
            # Has history: retrieve context for the plain and extended queries
            # concurrently, then evaluate which context to keep
            else:
                extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"
                # thread version
                with ThreadPoolExecutor(2) as executor:
                    results = executor.map(self.retrieve_context, [query, extended_query])
                result_query_only, result_extended_query = list(results)
                # async version - not working
                # result_query_only, result_extended_query = asyncio.run(
                #     self.aretrieve_context_multi([query, extended_query])
                # )

                # w&b trace retrieval context, query only
                retrieval_query_span = Trace(
                    name="QueryRetrieval",
                    kind="chain",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "index_type": "VectorStoreIndex",
                        "vector_store": "Pinecone",
                        "vector_store_index": self._index_name,
                        "vector_store_namespace": self._index_namespace,
                        "model_name": self.llm._model_name,
                        "custom_kwargs": self.llm._model_kwargs,
                        "start_time": result_query_only["start"],
                        "end_time": result_query_only["end"],
                    },
                    start_time_ms=result_query_only["start"],
                    end_time_ms=result_query_only["end"],
                    inputs={"query": query},
                    outputs={"response": result_query_only["result"]},
                )
                agent_span.add_child(retrieval_query_span)

                # w&b trace retrieval context, extended query
                retrieval_extended_query_span = Trace(
                    name="ExtendedQueryRetrieval",
                    kind="chain",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "index_type": "VectorStoreIndex",
                        "vector_store": "Pinecone",
                        "vector_store_index": self._index_name,
                        "vector_store_namespace": self._index_namespace,
                        "model_name": self.llm._model_name,
                        "custom_kwargs": self.llm._model_kwargs,
                        "start_time": result_extended_query["start"],
                        "end_time": result_extended_query["end"],
                    },
                    start_time_ms=result_extended_query["start"],
                    end_time_ms=result_extended_query["end"],
                    inputs={"query": extended_query},
                    outputs={"response": result_extended_query["result"]},
                )
                agent_span.add_child(retrieval_extended_query_span)

                # w&b trace select context
                eval_start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
                eval_context_query_only, eval_context_extended_query = asyncio.run(
                    self.aevaluate_context_multi(
                        [query, extended_query],
                        [result_query_only["result"], result_extended_query["result"]]
                    )
                )
                # keep whichever context is semantically closer to its query
                if eval_context_query_only.score > eval_context_extended_query.score:
                    query_selected, context_from_index_selected = query, result_query_only["result"]
                else:
                    query_selected, context_from_index_selected = extended_query, result_extended_query["result"]
                agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
                eval_span = Trace(
                    name="EmbeddingsEvaluator",
                    kind="tool",
                    status_code="success",
                    metadata={
                        "framework": "Llama-Index",
                        "evaluator": "SemanticSimilarityEvaluator",
                        "similarity_mode": "DEFAULT",
                        "similarity_threshold": 0.7,
                        "similarity_results": {
                            "eval_context_query_only": eval_context_query_only.score,
                            "eval_context_extended_query": eval_context_extended_query.score,
                        },
                        "model_name": self.emb_model._model_name,
                    },
                    start_time_ms=eval_start_time_ms,
                    end_time_ms=agent_end_time_ms,
                    inputs={"query": query_selected},
                    outputs={"response": context_from_index_selected},
                )
                agent_span.add_child(eval_span)
        except Exception as e:
            logger.error(f"Exception {e} occurred when retrieving context")
            llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
            result = "Something went wrong. Please try again later."
            root_span.add_inputs_and_outputs(
                inputs={"query": query},
                outputs={"result": result, "exception": str(e)}
            )
            root_span._span.status_code = "fail"
            root_span._span.end_time_ms = llm_end_time_ms
            root_span.log(name="llm_app_trace")
            return result

        logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")
        agent_span.add_inputs_and_outputs(
            inputs={"query": query},
            outputs={"result": context_from_index_selected}
        )
        agent_span._span.status_code = "success"
        agent_span._span.end_time_ms = agent_end_time_ms
        root_span.add_child(agent_span)

        # generate text with the prompt template to role-play myself
        prompt_with_context = self.prompt_template.format(
            context_history=context_history,
            context_from_index=context_from_index_selected,
            user_query=query
        )
        try:
            # uses the API's default text model
            response = genai.generate_text(
                prompt=prompt_with_context,
                safety_settings=[
                    {
                        'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                        'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                    },
                ],
                temperature=0.9,
            )
            result = response.result
            success_flag = "success"
            if result is None:
                result = "Seems something went wrong. Please try again later."
                logger.error("Result of 'None' received")
                success_flag = "fail"
        except Exception as e:
            result = "Seems something went wrong. Please try again later."
            logger.error(f"Exception {e} occurred")
            success_flag = "fail"

        # w&b trace llm
        llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
        llm_span = Trace(
            name="LLM",
            kind="llm",
            status_code=success_flag,
            start_time_ms=agent_end_time_ms,
            end_time_ms=llm_end_time_ms,
            inputs={"input": prompt_with_context},
            outputs={"result": result},
        )
        root_span.add_child(llm_span)

        # w&b finalize trace
        root_span.add_inputs_and_outputs(
            inputs={"query": query},
            outputs={"result": result}
        )
        root_span._span.end_time_ms = llm_end_time_ms
        root_span.log(name="llm_app_trace")

        return result
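

# A minimal end-to-end usage sketch, assuming PALM_API_KEY, PINECONE_API_KEY,
# PINECONE_ENV, PINECONE_INDEX, PINECONE_NAMESPACE and WANDB_PROJECT are all
# set before this module runs (the constructor reads them eagerly). The
# queries below are illustrative placeholders, not part of the original module.
if __name__ == '__main__':
    palm = LlamaIndexPaLM()
    palm.get_index_from_pinecone()  # must run before generate_text()

    chat_history: List[str] = []
    # First turn: empty history takes the single-retrieval path
    reply = palm.generate_text(query="What does Gerard do?", history=chat_history)
    print(reply)

    # Later turns: history triggers dual retrieval plus semantic evaluation
    chat_history.append(reply)
    print(palm.generate_text(query="Tell me more about his projects.", history=chat_history))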