import re
import logging

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI()

# Load embeddings and the FAISS vector database
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)
try:
    db = FAISS.load_local(
        "vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True
    )
    logger.info("Vector database loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load vector database: {e}")
    raise
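
# The index at vectorstore/db_faiss is assumed to have been built offline with the
# same embedding model, e.g. (sketch):
#   FAISS.from_documents(chunks, embeddings).save_local("vectorstore/db_faiss")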

# Load the LLM via ctransformers (CPU-friendly GGML weights)
try:
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        max_new_tokens=128,
        temperature=0.5,
    )
    logger.info("LLM model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load LLM model: {e}")
    raise

# Define the custom prompt template
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
qa_prompt = PromptTemplate(
    template=custom_prompt_template, input_variables=["context", "question"]
)

# Set up the RetrievalQA chain ("stuff" packs the top-k chunks into one prompt)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": qa_prompt},
)
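
# Quick smoke test of the chain (the question is a placeholder):
#   result = qa_chain({"query": "What topics does the document cover?"})
#   result["result"]            -> the generated answer string
#   result["source_documents"]  -> the k=2 retrieved Documents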

class QuestionRequest(BaseModel):
    question: str


class AnswerResponse(BaseModel):
    answer: str

def clean_answer(answer):
    # Drop stray symbols, keeping word characters, whitespace, basic punctuation,
    # backslashes (so literal "\n" sequences survive), and the markdown markers
    # (#, *) that the passes below rely on
    cleaned_answer = re.sub(r'[^\w\s.,\-#*\\]', '', answer)
    # Collapse immediate word repetitions ("the the" -> "the")
    cleaned_answer = re.sub(r'\b(\w+)( \1\b)+', r'\1', cleaned_answer)
    # Collapse runs of whitespace (including real newlines) into single spaces
    cleaned_answer = re.sub(r'\s+', ' ', cleaned_answer).strip()
    # Turn literal "\n" escape sequences into real newlines for markdown
    cleaned_answer = re.sub(r'\\n', '\n', cleaned_answer)
    # Normalize bullet points to markdown syntax
    cleaned_answer = re.sub(r'^\s*-\s+(.*)$', r'* \1', cleaned_answer, flags=re.MULTILINE)
    # Normalize numbered-list items to markdown syntax
    cleaned_answer = re.sub(r'^\s*\d+\.\s+(.*)$', r'1. \1', cleaned_answer, flags=re.MULTILINE)
    # Normalize heading spacing ("##   Title" -> "## Title")
    cleaned_answer = re.sub(r'^\s*(#+)\s+(.*)$', r'\1 \2', cleaned_answer, flags=re.MULTILINE)
    return cleaned_answer
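
# Example (hypothetical model output):
#   clean_answer("The capital is Paris Paris!\\nSee page 3.")
#   -> "The capital is Paris\nSee page 3."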

def format_sources(sources):
    # Render each retrieved document as "source, page N"
    formatted_sources = []
    for source in sources:
        metadata = source.metadata
        page = metadata.get('page', 'Unknown page')
        source_str = f"{metadata.get('source', 'Unknown source')}, page {page}"
        formatted_sources.append(source_str)
    return "\n".join(formatted_sources)

# Question-answering endpoint (the "/query" path is an assumption; the route
# decorator was missing from the original)
@app.post("/query", response_model=AnswerResponse)
async def query(question_request: QuestionRequest):
    try:
        question = question_request.question
        if not question:
            raise HTTPException(status_code=400, detail="Question is required")
        result = qa_chain({"query": question})
        # Clean up the answer before appending sources, so source paths and
        # newlines survive the cleaning regexes
        answer = clean_answer(result.get("result"))
        sources = result.get("source_documents")
        if sources:
            answer += "\nSources:\n" + format_sources(sources)
        else:
            answer += "\nNo sources found"
        return {"answer": answer}
    except HTTPException:
        # Let deliberate HTTP errors (e.g., the 400 above) pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")

# Redirect the root URL to the interactive API docs
@app.get("/")
async def root():
    return RedirectResponse(url="/docs")

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=7860)
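
# Example request once the server is running (the question is a placeholder):
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the document say about X?"}'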