# Spaces:
# Runtime error
# Runtime error
#!/usr/bin/env python | |
import json | |
import logging | |
import os | |
import sys | |
import s3fs | |
import torch | |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings | |
from llama_index import (ServiceContext, StorageContext, | |
load_index_from_storage, set_global_service_context) | |
from llama_index.agent import ContextRetrieverOpenAIAgent, OpenAIAgent | |
from llama_index.indices.vector_store import VectorStoreIndex | |
from llama_index.llms import ChatMessage, MessageRole, OpenAI | |
from llama_index.prompts import ChatPromptTemplate, PromptTemplate | |
from llama_index.query_engine import RetrieverQueryEngine | |
from llama_index.response_synthesizers import get_response_synthesizer | |
from llama_index.retrievers import RecursiveRetriever | |
from llama_index.tools import QueryEngineTool, ToolMetadata | |
from llama_index.vector_stores import PGVectorStore | |
from sqlalchemy import make_url | |
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) | |
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout)) | |
def get_embed_model():
    """Load the ``thenlper/gte-small`` embedding model on the best device.

    Device selection starts at ``cpu``, switches to ``cuda`` when CUDA is
    available, and finally to ``mps`` when the MPS backend is available
    (so MPS wins if both are reported).

    Returns:
        The constructed ``HuggingFaceEmbeddings`` instance, configured to
        emit normalized embeddings.

    Raises:
        SystemExit: With exit status 1 when the model cannot be
            constructed. The underlying error is written to stderr.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    # Older torch builds do not ship ``torch.backends.mps`` at all, so probe
    # for the attribute instead of assuming it exists.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        device = "mps"

    model_kwargs = {"device": device}
    # Normalized vectors make the dot product equal to cosine similarity.
    encode_kwargs = {"normalize_embeddings": True}

    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name="thenlper/gte-small",
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        # There is no fallback model, so report the failure honestly and
        # terminate with a non-zero status. ``sys.exit`` is used rather than
        # the interactive-only ``exit()`` helper injected by ``site``.
        print(f"Failed to load embedding model: {exception}", file=sys.stderr)
        sys.exit(1)
    print("Model loaded.")
    return model_norm
# Question-answering prompt text. It carries two ``str.format``-style
# placeholders: ``{context_str}`` (the retrieved manual excerpts) and
# ``{query_str}`` (the user's question). The value starts and ends with a
# newline, exactly as a triple-quoted literal would produce.
QA_TEMPLATE = (
    "\n"
    "You are a chatbot, able to have normal interactions as well as respond to question about my Ford F150.\n"
    "Below are excerpts from my F150's user manual. You must only use the information in the context below to formulate your response.\n"
    "If there is not enough information to formulate a response, you must respond with: \"I'm sorry, I can't find the answer to your question.\"\n"
    "{context_str}\n"
    "{query_str}\n"
)
def main():
    """Run an interactive retrieval prompt over the stored manual indexes.

    Steps:
      1. Build a service context from the GPT-4 LLM and the local embedding
         model, and install it as the global default.
      2. List the entries under ``f150-user-manual/recursive-agent/`` on S3
         and, for the first five entries (skipping ``vector_index``), load
         each entry's persisted vector index as a query engine.
      3. Load the top-level vector index and wrap it, together with the
         per-entry query engines, in a ``RecursiveRetriever``.
      4. Read prompts from stdin until the user types ``exit`` or stdin is
         closed, printing the text of the first retrieved node each time.

    AWS credentials are taken from the ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables. When they are unset,
    ``None`` is passed and s3fs falls back to its own credential lookup.
    """
    embed_model = get_embed_model()
    llm = OpenAI("gpt-4")
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    set_global_service_context(service_context)

    # SECURITY: access keys must never live in source control. Any key that
    # was previously embedded in this file should be treated as leaked and
    # rotated.
    s3 = s3fs.S3FileSystem(
        key=os.environ.get("AWS_ACCESS_KEY_ID"),
        secret=os.environ.get("AWS_SECRET_ACCESS_KEY"),
    )

    index_root = "f150-user-manual/recursive-agent"
    titles = [path.split("/")[-1] for path in s3.ls(f"{index_root}/")]

    agents = {}
    # NOTE(review): only the first five listed entries are considered, and
    # the cap is applied before "vector_index" is skipped. This looks like a
    # development shortcut -- confirm whether every section should be loaded.
    for title in titles[:5]:
        if title == "vector_index":
            # The top-level index lives here; it is loaded separately below.
            continue
        print(title)
        # Load the persisted per-section vector index.
        storage_context = StorageContext.from_defaults(
            persist_dir=f"{index_root}/{title}/vector_index",
            fs=s3,
        )
        vector_index = load_index_from_storage(storage_context)
        # Expose the section index as a query engine keyed by its title.
        agents[title] = vector_index.as_query_engine(
            similarity_top_k=2,
            verbose=True,
        )
    print(f"Agents: {len(agents)}")

    storage_context = StorageContext.from_defaults(
        persist_dir=f"{index_root}/vector_index",
        fs=s3,
    )
    top_level_vector_index = load_index_from_storage(storage_context)
    vector_retriever = top_level_vector_index.as_retriever(similarity_top_k=1)
    recursive_retriever = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever},
        query_engine_dict=agents,
        verbose=True,
        query_response_tmpl="{response}",
    )

    while True:
        # Read. EOF (closed or non-interactive stdin) and Ctrl-C end the
        # session; treating EOF as a generic error would loop forever because
        # every subsequent input() call raises again.
        try:
            user_input = input(">>> ")
        except (EOFError, KeyboardInterrupt):
            print()
            break

        if user_input == "exit":
            break

        # Evaluate. Retrieval failures are reported and the prompt continues.
        try:
            results = recursive_retriever.retrieve(user_input)
            answer = results[0].get_text() if results else None
        except Exception as e:
            print("Error:", e)
            continue

        # Print. An empty result list is reported explicitly rather than
        # surfacing as an IndexError.
        print(answer if answer is not None else "No results found.")
# Script entry point: start the interactive session only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()