#!/usr/bin/env python
import logging
import os
import sys

import s3fs
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (ServiceContext, StorageContext,
                         load_index_from_storage, set_global_service_context)
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.retrievers import RecursiveRetriever

# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

def get_embed_model():
    """Load the gte-small embedding model on the best available device."""
    # Prefer GPU (CUDA), then Apple Silicon (MPS), falling back to CPU.
    model_kwargs = {'device': 'cpu'}
    if torch.cuda.is_available():
        model_kwargs['device'] = 'cuda'
    if torch.backends.mps.is_available():
        model_kwargs['device'] = 'mps'
    # Normalize embeddings so dot product equals cosine similarity.
    encode_kwargs = {'normalize_embeddings': True}
    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name="thenlper/gte-small",
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        print(f"Failed to load embedding model: {exception}")
        sys.exit(1)
    print("Model loaded.")
    return model_norm
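
# A minimal usage sketch (illustrative only, not part of this script's flow):
# the returned object implements LangChain's Embeddings interface, so a
# single string can be embedded directly.
#
#   embed_model = get_embed_model()
#   vec = embed_model.embed_query("How do I check tire pressure?")
#   print(len(vec))  # gte-small produces 384-dimensional vectors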

QA_TEMPLATE = """
You are a chatbot, able to have normal interactions as well as respond to questions about my Ford F150.
Below are excerpts from my F150's user manual. You must use only the information in the context below to formulate your response.
If there is not enough information to formulate a response, you must respond with: "I'm sorry, I can't find the answer to your question."

{context_str}

{query_str}
"""

def main():
    embed_model = get_embed_model()
    llm = OpenAI(model="gpt-4")
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    set_global_service_context(service_context)

    # Read AWS credentials from the environment instead of hardcoding them.
    s3 = s3fs.S3FileSystem(
        key=os.environ["AWS_ACCESS_KEY_ID"],
        secret=os.environ["AWS_SECRET_ACCESS_KEY"],
    )

    # Each prefix under recursive-agent/ is a manual section with its own
    # persisted vector index; keep just the final path component as the title.
    titles = s3.ls("f150-user-manual/recursive-agent/")
    titles = [path.split("/")[-1] for path in titles]
    agents = {}
    # Build a query engine per manual section (limited to the first five
    # titles, presumably to keep startup fast).
    for title in titles[:5]:
        if title == "vector_index":
            continue  # the top-level index is loaded separately below
        print(title)
        # Load the persisted vector index for this section.
        storage_context = StorageContext.from_defaults(
            persist_dir=f"f150-user-manual/recursive-agent/{title}/vector_index",
            fs=s3,
        )
        vector_index = load_index_from_storage(storage_context)
        # Wrap it in a query engine keyed by section title.
        vector_query_engine = vector_index.as_query_engine(
            similarity_top_k=2,
            verbose=True,
        )
        agents[title] = vector_query_engine
    print(f"Agents: {len(agents)}")
    # Load the top-level index, whose nodes route queries to the per-section
    # engines.
    storage_context = StorageContext.from_defaults(
        persist_dir="f150-user-manual/recursive-agent/vector_index",
        fs=s3,
    )
    top_level_vector_index = load_index_from_storage(storage_context)
    vector_retriever = top_level_vector_index.as_retriever(similarity_top_k=1)
    recursive_retriever = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever},
        query_engine_dict=agents,
        verbose=True,
        # Pass the sub-engine's answer through verbatim as the node text.
        query_response_tmpl="{response}",
    )
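
    # How the recursion resolves (a sketch of RecursiveRetriever semantics):
    # the top-level index is expected to hold IndexNodes whose index_id
    # matches a key in `agents`. A query first hits the root retriever
    # ("vector"); each retrieved IndexNode then forwards the query to the
    # matching per-section engine, e.g.:
    #
    #   nodes = recursive_retriever.retrieve("How do I enable tow/haul mode?")
    #   print(nodes[0].get_text())  # the routed engine's answer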

    # The top-level query aggregator was removed; retrieved responses are
    # printed directly in the REPL below.
    # response_synthesizer = get_response_synthesizer(
    #     response_mode="compact_accumulate",
    # )
    # query_engine = RetrieverQueryEngine.from_args(
    #     recursive_retriever,
    #     similarity_top_k=1,
    #     response_synthesizer=response_synthesizer,
    #     service_context=service_context,
    # )

    # Simple REPL: read a question, route it through the recursive retriever,
    # and print the first retrieved node (the sub-engine's response text).
    while True:
        try:
            user_input = input(">>> ")
            if user_input == 'exit':
                break
            response = recursive_retriever.retrieve(user_input)
            print(response[0].get_text())
        except Exception as e:
            print("Error:", e)


if __name__ == '__main__':
    main()