f150 / src /inference.py
Adrian Cowham
initial commit
cbdf795
#!/usr/bin/env python
#!/usr/bin/env python
import json
import logging
import os
import sys
import psycopg2
import s3fs
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import ServiceContext, set_global_service_context
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.vector_stores import PGVectorStore
from sqlalchemy import make_url
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
# Prompt text wrapped in a PromptTemplate and passed to the query engine as
# its ``text_qa_template`` (see main()). It contains two placeholders,
# ``{context_str}`` and ``{query_str}`` -- presumably filled in by the query
# engine with the retrieved manual excerpts and the user's question; confirm
# against the llama_index version in use.
QA_TEMPLATE = """
You are an intelligent and helpful AI Assistant, able to have normal interactions as well as answer questions about my 2023 Ford F150.
Below are excerpts from my F150's User Manual. You must only use the information in the context below to formulate your response.
If there is not enough information to formulate a response, you must respond with: "I'm sorry, I can't find the answer to your question in the user manual."
{context_str}
{query_str}
"""
def get_embed_model():
    """Load the ``thenlper/gte-small`` HuggingFace embedding model.

    The model is placed on ``mps`` when that backend is available, otherwise
    on ``cuda`` when available, otherwise on ``cpu``. Embeddings are
    normalized so that similarity can be computed as cosine similarity.

    Returns:
        The constructed ``HuggingFaceEmbeddings`` instance.

    Raises:
        SystemExit: with exit status 1 if the model cannot be constructed.
    """
    # MPS is checked first so it wins over CUDA if both report available,
    # matching the precedence this function has always had.
    if torch.backends.mps.is_available():
        device = 'mps'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    model_kwargs = {'device': device}
    encode_kwargs = {'normalize_embeddings': True}  # set True to compute cosine similarity

    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name="thenlper/gte-small",
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        # There is no fallback model: report the real failure on stderr and
        # terminate with a non-zero status so callers/shells can detect it.
        # sys.exit is used instead of the site-provided exit() builtin, which
        # is not guaranteed to exist in every interpreter configuration.
        print(f"Failed to load embedding model: {exception}", file=sys.stderr)
        sys.exit(1)
    print("Model loaded.")
    return model_norm
def get_vector_store():
    """Build the PGVectorStore backed by the ``f150_manual`` table.

    Host, port, user and password are parsed from a local PostgreSQL
    connection URL; the database name in that URL is not used -- the store
    is opened against the ``helm`` database instead. The store is configured
    for 384-dimensional embeddings with hybrid search using the ``english``
    text search configuration.

    Returns:
        The ``PGVectorStore`` created by ``PGVectorStore.from_params``.
    """
    parsed = make_url("postgresql://adrian@localhost:5432/postgres")
    store_params = {
        "database": "helm",
        "host": parsed.host,
        "password": parsed.password,
        "port": parsed.port,
        "user": parsed.username,
        "table_name": "f150_manual",
        "embed_dim": 384,
        "hybrid_search": True,
        "text_search_config": "english",
    }
    return PGVectorStore.from_params(**store_params)
def _run_repl(query_engine):
    """Read questions from stdin and print the query engine's answers.

    The loop ends when the user types ``exit``, when stdin reaches end of
    input (Ctrl-D or a closed pipe), or when Ctrl-C is pressed at the prompt.
    Blank lines are ignored. Errors raised while answering a question are
    printed and the loop continues.
    """
    while True:
        try:
            user_input = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() keeps raising EOFError once stdin is exhausted, so
            # retrying would spin forever; leave the loop instead.
            print()
            break

        if not user_input:
            # Nothing to ask; avoid sending an empty query.
            continue
        if user_input == 'exit':
            break

        try:
            response = query_engine.query(user_input)
            print(response)
        except Exception as e:
            # Best effort: report the failure and keep the session alive.
            print("Error:", e)


def main():
    """Set up the F150 manual query engine and start an interactive prompt.

    Loads the embedding model, creates an OpenAI ``gpt-4`` LLM, installs both
    as the global service context, opens the vector store, and builds a query
    engine that uses ``QA_TEMPLATE`` with the top 2 similar results. Then
    hands control to the interactive loop.
    """
    embed_model = get_embed_model()
    llm = OpenAI("gpt-4")
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    set_global_service_context(service_context)

    vector_store = get_vector_store()
    vector_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
    query_engine = vector_index.as_query_engine(
        text_qa_template=PromptTemplate(QA_TEMPLATE),
        similarity_top_k=2,
        verbose=True,
    )

    # Example questions to try:
    #   Recommended tire pressure
    #   Recommended oil
    #   Instructions on how to change a flat tire
    #   Fuel tank capacity and fuel grade
    #   How to change the keypad code.
    _run_repl(query_engine)
# Run the interactive assistant only when executed as a script, not on import.
if __name__ == "__main__":
    main()