import os
import re
import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app
app = FastAPI()
# Load embeddings and vector database
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
try:
    db = FAISS.load_local("vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True)
    logger.info("Vector database loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load vector database: {e}")
    raise e
# Load LLM using ctransformers
try:
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        max_new_tokens=128,
        temperature=0.5,
    )
    logger.info("LLM model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load LLM model: {e}")
    raise e
# Define custom prompt template
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
qa_prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
# Set up RetrievalQA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": qa_prompt},
)
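# Illustrative direct use of the chain (the /query endpoint below wraps this;
# the example question is hypothetical):
#   result = qa_chain({"query": "What are the symptoms of anemia?"})
#   result["result"]            -> generated answer text
#   result["source_documents"]  -> retrieved Documents whose .metadata holds "source" and "page"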
class QuestionRequest(BaseModel):
    question: str

class AnswerResponse(BaseModel):
    answer: str
def clean_answer(answer):
    """Normalize raw LLM output into clean, markdown-friendly text."""
    # Remove unnecessary characters and symbols
    cleaned_answer = re.sub(r'[^\w\s.,-]', '', answer)
    # Collapse immediately repeated words (e.g. "word word") into one
    cleaned_answer = re.sub(r'\b(\w+)( \1\b)+', r'\1', cleaned_answer)
    # Remove any trailing or leading spaces
    cleaned_answer = cleaned_answer.strip()
    # Replace multiple spaces with a single space
    cleaned_answer = re.sub(r'\s+', ' ', cleaned_answer)
    # Replace literal "\n" sequences with real newline characters for markdown
    cleaned_answer = re.sub(r'\\n', '\n', cleaned_answer)
    # Check for bullet points and replace with markdown syntax
    cleaned_answer = re.sub(r'^\s*-\s+(.*)$', r'* \1', cleaned_answer, flags=re.MULTILINE)
    # Check for numbered lists and replace with markdown syntax
    cleaned_answer = re.sub(r'^\s*\d+\.\s+(.*)$', r'1. \1', cleaned_answer, flags=re.MULTILINE)
    # Check for headings and replace with markdown syntax
    cleaned_answer = re.sub(r'^\s*(#+)\s+(.*)$', r'\1 \2', cleaned_answer, flags=re.MULTILINE)
    return cleaned_answer
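# Example (illustrative) of the cleanup above:
#   clean_answer("The  heart pumps pumps blood.")  ->  "The heart pumps blood."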
def format_sources(sources):
    """Format retrieved source documents as "<source>, page <n>" lines."""
    formatted_sources = []
    for source in sources:
        metadata = source.metadata
        page = metadata.get('page', 'Unknown page')
        source_str = f"{metadata.get('source', 'Unknown source')}, page {page}"
        formatted_sources.append(source_str)
    return "\n".join(formatted_sources)
@app.post("/query", response_model=AnswerResponse)
async def query(question_request: QuestionRequest):
    """Answer a question using the RetrievalQA chain and append the sources used."""
    try:
        question = question_request.question
        if not question:
            raise HTTPException(status_code=400, detail="Question is required")
        result = qa_chain({"query": question})
        answer = result.get("result")
        sources = result.get("source_documents")
        if sources:
            formatted_sources = format_sources(sources)
            answer += "\nSources:\n" + formatted_sources
        else:
            answer += "\nNo sources found"
        # Clean up the answer
        cleaned_answer = clean_answer(answer)
        return {"answer": cleaned_answer}
    except HTTPException:
        # Re-raise client errors (e.g. the 400 above) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
@app.get("/")
async def root():
    return RedirectResponse(url="/docs")
# if __name__ == '__main__':
#     import uvicorn
#     uvicorn.run(app, host='0.0.0.0', port=7860)
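# Illustrative local usage, assuming this file is named app.py and served on port 7860
# as in the commented-out uvicorn call above (the example question is hypothetical):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   import requests
#   resp = requests.post("http://localhost:7860/query", json={"question": "What is hypertension?"})
#   resp.json()  ->  {"answer": "..."}   # shape defined by AnswerResponse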