# QDrantRAG9 / app.py
# Last commit: 5746b0c "Update app.py" (verified) by dinhquangson; file size 11.3 kB.
# Loading
import functools
import json
import os
import shutil
import zipfile
from os import makedirs, getcwd
from os.path import join, exists, dirname

import pdfplumber
import pytesseract
import torch
from datasets import load_dataset
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

from models import Invoice
# FastAPI application instance; every endpoint in this module is registered on it.
app = FastAPI()
# Accept cross-origin requests from any origin, with any HTTP method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True may let
# arbitrary sites make credentialed requests to this API — restrict the origin
# list if cookies or auth headers are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Number of CPU cores; note os.cpu_count() returns None when it cannot be determined.
NUM_PROC = os.cpu_count()
# Scratch directory for uploads and extracted archives: a "temp" folder that is a
# sibling of (not inside) the current working directory.
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
# exist_ok=True replaces the `if not exists(...): makedirs(...)` pattern, which
# races when two processes import this module at the same time.
makedirs(temp_path, exist_ok=True)
# Determine device based on GPU availability.
# NOTE(review): `device` is only printed here; no visible component is given it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
import logging
# Root logging: only WARNING and above by default...
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
# ...but let the "haystack" logger through at INFO level.
logging.getLogger("haystack").setLevel(logging.INFO)
# Local, on-disk Qdrant store in ./database (relative to the current working
# directory), written by the upload endpoint and queried by the search endpoint.
# NOTE(review): recreate_index=True presumably drops and recreates the collection
# each time this module is imported, i.e. indexed documents are lost on restart —
# confirm this is intended.
# use_sparse_embeddings=True: the indexing pipeline writes sparse (BM42) vectors
# and the search pipeline uses a hybrid (sparse + dense) retriever.
# embedding_dim=384 must match the dense model used by both pipelines
# (sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2).
document_store = QdrantDocumentStore(
    path="database",
    recreate_index=True,
    use_sparse_embeddings=True,
    embedding_dim = 384
)
def extract_zip(zip_path, target_folder):
    """
    Extract every member of a ZIP archive and return the paths of the extracted files.

    Members are extracted one at a time with ``ZipFile.extract`` and the path it
    returns is recorded. That path is where the file was actually written:
    zipfile strips absolute prefixes and ``..`` components from member names, so
    joining ``target_folder`` with the raw member name can point outside
    ``target_folder`` or at a file that does not exist.

    Args:
        zip_path (str): Path to the ZIP file.
        target_folder (str): Folder where the files will be extracted.

    Returns:
        List[str]: Paths of the extracted regular files, in archive order.
        Directory entries are created on disk but are not included.
    """
    extracted_files = []
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for member in zip_ref.infolist():
            # extract() returns the normalized path it created (file or directory).
            extracted_path = zip_ref.extract(member, target_folder)
            if not member.is_dir():
                extracted_files.append(extracted_path)
    return extracted_files
def extract_text_from_pdf(pdf_path):
    """
    Extract the embedded text layer of a PDF with pdfplumber (no OCR).

    Pages are joined with a newline so the last word of one page does not run
    into the first word of the next. A page without extractable text
    contributes an empty string: ``page.extract_text()`` can return ``None``
    (depending on the pdfplumber version), which would otherwise break the
    concatenation with a TypeError.

    Args:
        pdf_path (str): Path to the PDF file.

    Returns:
        str: Text of all pages in page order, separated by "\\n".
    """
    with pdfplumber.open(pdf_path) as pdf:
        return "\n".join(page.extract_text() or "" for page in pdf.pages)
def extract_ocr_text_from_pdf(pdf_path, lang='vie'):
    """
    OCR every page of a PDF and return the concatenated text.

    Each page is rasterized with pdf2image and passed through Tesseract via
    pytesseract.

    Args:
        pdf_path (str): Path to the PDF file.
        lang (str): Tesseract language code(s); defaults to Vietnamese ('vie').
            Several languages can be combined with '+', e.g. 'vie+eng'.

    Returns:
        str: OCR text of all pages in page order, concatenated exactly as
        Tesseract returned it.
    """
    # Local import: pdf2image is only needed on the OCR path.
    from pdf2image import convert_from_path
    images = convert_from_path(pdf_path)
    return "".join(pytesseract.image_to_string(image, lang=lang) for image in images)
# Extensions handled as JSON Lines uploads (one JSON object per line).
_JSON_EXTENSIONS = (".json", ".jsonl")


def _save_upload_to_temp(file: UploadFile) -> str:
    """Write an uploaded file into ``temp_path`` and return the saved path.

    Only the final component of the client-supplied filename is kept, so a name
    such as "../../app.py" (or a Windows-style "C:\\x\\..\\y.pdf") cannot escape
    ``temp_path``.

    Raises:
        HTTPException: 400 if no usable filename remains.
    """
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        from fastapi import HTTPException
        raise HTTPException(status_code=400, detail="Invalid upload filename")
    save_path = join(temp_path, safe_name)
    with open(save_path, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    return save_path


def _documents_from_jsonl(path, text_field):
    """Build one Document per non-blank line of a JSON Lines file.

    The value under ``text_field`` becomes the content; the whole object is
    kept as metadata. A line missing ``text_field`` raises KeyError.
    """
    from haystack import Document
    documents = []
    with open(path, encoding="utf-8") as fd:
        for line in fd:
            if not line.strip():
                continue  # tolerate blank/trailing lines instead of a JSONDecodeError
            obj = json.loads(line)
            documents.append(Document(content=obj[text_field], meta=obj))
    return documents


def _documents_from_zip(path, text_field, ocr):
    """Extract a ZIP upload into ``temp_path`` and build one Document per PDF inside it."""
    from haystack import Document
    documents = []
    extracted_files = extract_zip(path, temp_path)
    print("Extracted files:", extracted_files)
    for file_path in extracted_files:
        if os.path.splitext(file_path)[1].lower() != ".pdf":
            continue  # only PDFs are indexed; other archive members are ignored
        text = extract_ocr_text_from_pdf(file_path) if ocr else extract_text_from_pdf(file_path)
        meta = {text_field: text, "file_path": file_path}
        documents.append(Document(content=text, meta=meta))
    return documents


def _build_indexing_pipeline():
    """Assemble the indexing pipeline: join -> clean -> split -> sparse embed -> dense embed -> write."""
    from haystack import Pipeline
    from haystack.components.joiners import DocumentJoiner
    from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedSparseDocumentEmbedder,
    )

    indexing = Pipeline()
    indexing.add_component("document_joiner", DocumentJoiner())
    indexing.add_component("document_cleaner", DocumentCleaner())
    # NOTE(review): split_length counts passages here, i.e. up to 1000 passages
    # per chunk — confirm this is the intended granularity.
    indexing.add_component(
        "document_splitter",
        DocumentSplitter(split_by="passage", split_length=1000, split_overlap=0),
    )
    indexing.add_component(
        "sparse_doc_embedder",
        FastembedSparseDocumentEmbedder(model="Qdrant/bm42-all-minilm-l6-v2-attentions"),
    )
    indexing.add_component(
        "dense_doc_embedder",
        FastembedDocumentEmbedder(model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"),
    )
    # OVERWRITE: a document whose id already exists in the store is replaced.
    indexing.add_component(
        "writer",
        DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE),
    )
    indexing.connect("document_joiner", "document_cleaner")
    indexing.connect("document_cleaner", "document_splitter")
    indexing.connect("document_splitter", "sparse_doc_embedder")
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")
    return indexing


@app.post("/uploadfile/")
async def create_upload_file(text_field: str, file: UploadFile = File(...), ocr: bool = False):
    """Index an uploaded file into the Qdrant document store.

    Supported uploads (chosen by the file's final extension, case-insensitive):
      * ``.json`` / ``.jsonl``: JSON Lines, one object per line; ``obj[text_field]``
        is the document text and the whole object is stored as metadata.
      * ``.zip``: every PDF inside is converted to text and indexed, with
        metadata ``{text_field: text, "file_path": <extracted path>}``.

    Args:
        text_field: JSON key holding the text (JSON) / metadata key for the text (ZIP).
        file: The uploaded file; it is saved under ``temp_path`` first.
        ocr: For PDFs in a ZIP, use Tesseract OCR instead of the embedded text layer.

    Returns:
        dict: ``{"filename", "message": "Done", "execution_time": <seconds>}``.

    Raises:
        NotImplementedError: for any other file type (surfaces as HTTP 500).
        HTTPException: 400 if the upload has no usable filename.

    NOTE(review): this async endpoint performs blocking file I/O, PDF parsing/OCR
    and embedding inline, which blocks the event loop while it runs.
    """
    import time
    start_time = time.time()
    save_path = _save_upload_to_temp(file)
    # Match on the real final extension; a substring test would also match
    # "x.json.zip" or a ".json" inside a directory name.
    extension = os.path.splitext(save_path)[1].lower()
    if extension in _JSON_EXTENSIONS:
        documents = _documents_from_jsonl(save_path, text_field)
    elif extension == ".zip":
        documents = _documents_from_zip(save_path, text_field, ocr)
    else:
        raise NotImplementedError("This feature is not supported yet")
    _build_indexing_pipeline().run({"document_joiner": {"documents": documents}})
    elapsed_time = time.time() - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
# Prompt sent to the LLM. English equivalent: "Given the following information,
# answer the question in Vietnamese. Context: <documents> Question: <question> Answer:".
_SEARCH_PROMPT_TEMPLATE = """
Với thông tin sau, hãy trả lời câu hỏi bằng tiếng Việt.
Bối cảnh: {% for document in documents %}
{{ document.content }}
{% endfor %}
Câu hỏi: {{ question }}
Trả lời:
"""


@functools.lru_cache(maxsize=1)
def _get_query_pipeline():
    """Build the hybrid-retrieval RAG pipeline once and reuse it for every request.

    Flow: sparse + dense query embeddings -> QdrantHybridRetriever -> joiner ->
    BAAI/bge-m3 reranker -> prompt builder -> OpenAI-compatible LLM (OctoAI).

    Building it per request re-instantiated the embedders and the bge-m3 ranker
    (whose model loading happens when the pipeline warms its components up) on
    every query. If construction fails (e.g. OCTOAI_TOKEN is unset), the
    exception propagates and nothing is cached.

    NOTE(review): the cached Pipeline is shared by concurrent requests (``search``
    is a sync endpoint, so FastAPI runs it in a threadpool); confirm the
    components tolerate concurrent ``run`` calls.
    """
    from haystack import Pipeline
    from haystack.components.builders import PromptBuilder
    from haystack.components.generators import OpenAIGenerator
    from haystack.components.joiners import DocumentJoiner
    from haystack.components.rankers import TransformersSimilarityRanker
    from haystack.utils import Secret
    from haystack_integrations.components.embedders.fastembed import (
        FastembedSparseTextEmbedder,
        FastembedTextEmbedder,
    )
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever

    generator = OpenAIGenerator(
        api_key=Secret.from_env_var("OCTOAI_TOKEN"),
        api_base_url="https://text.octoai.run/v1",
        model="meta-llama-3-70b-instruct",
        generation_kwargs={"max_tokens": 512},
    )
    querying = Pipeline()
    querying.add_component(
        "sparse_text_embedder",
        FastembedSparseTextEmbedder(model="Qdrant/bm42-all-minilm-l6-v2-attentions"),
    )
    querying.add_component(
        "dense_text_embedder",
        FastembedTextEmbedder(
            model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            prefix="Đại diện cho câu này để tìm kiếm các đoạn văn có liên quan: ",
        ),
    )
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.add_component("document_joiner", DocumentJoiner())
    querying.add_component("ranker", TransformersSimilarityRanker(model="BAAI/bge-m3"))
    querying.add_component("prompt_builder", PromptBuilder(template=_SEARCH_PROMPT_TEMPLATE))
    querying.add_component("llm", generator)
    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")
    querying.connect("retriever", "document_joiner")
    querying.connect("document_joiner", "ranker")
    querying.connect("ranker.documents", "prompt_builder.documents")
    querying.connect("prompt_builder", "llm")
    querying.debug = True
    return querying


@app.get("/search")
def search(prompt: str):
    """Answer ``prompt`` using hybrid retrieval, reranking and an LLM.

    Args:
        prompt: The user question; used as the query for both embedders, the
            ranker, and as ``question`` in the prompt template.

    Returns:
        dict: The Haystack pipeline output (including the LLM replies).
    """
    import time
    start_time = time.time()
    results = _get_query_pipeline().run(
        {
            "dense_text_embedder": {"text": prompt},
            "sparse_text_embedder": {"text": prompt},
            "ranker": {"query": prompt},
            "prompt_builder": {"question": prompt},
        }
    )
    elapsed_time = time.time() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return results
@app.get("/download-database/")
async def download_database():
    """Zip the on-disk Qdrant database directory and return it as a download.

    The archive is (re)written to ``<cwd>/database.zip`` on every call and is
    served with the filename ``database.zip``.

    NOTE(review): the Qdrant store opened at import time keeps using ./database
    while it is being archived; confirm the snapshot is consistent enough.
    """
    import time
    start_time = time.time()
    cwd = os.getcwd()
    # Path to the database directory.
    database_dir = join(cwd, 'database')
    # make_archive appends ".zip" to base_name and returns the archive's path.
    # Using that return value, instead of deriving the base name with
    # str.replace('.zip', ''), avoids mangling any ".zip" that happens to occur
    # in a parent directory name.
    zip_path = shutil.make_archive(join(cwd, 'database'), 'zip', database_dir)
    elapsed_time = time.time() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    # Return the zip file as a response for download.
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
def truncate_text(text: str, max_chars: int = 3000) -> str:
    """Return at most the first ``max_chars`` characters of ``text``.

    Used by the /pdf2text/ endpoint to cap the OCR text of the first page before
    it is sent to the LLM.

    Args:
        text: Input string.
        max_chars: Maximum number of characters to keep (default 3000).

    Returns:
        ``text`` unchanged if it is short enough, otherwise its prefix.

    Raises:
        ValueError: if ``max_chars`` is negative; a negative slice bound would
            silently drop characters from the end instead.
    """
    if max_chars < 0:
        raise ValueError(f"max_chars must be non-negative, got {max_chars}")
    # Slicing past the end is safe and returns the whole string.
    return text[:max_chars]
@app.post("/query2metadata/")
async def extract_metadata_from_query(query: str):
    """Pull metadata values mentioned in a free-text query.

    Delegates to the project-local ``QueryMetadataExtractor``, requesting the
    ``publisher``, ``publish_date`` and ``document_type`` fields, and returns
    whatever its ``run`` method produces.
    """
    # Deferred import: the extractor module is only loaded when this endpoint is called.
    from QueryMetadataExtractor import QueryMetadataExtractor

    requested_fields = {"publisher", "publish_date", "document_type"}
    return QueryMetadataExtractor().run(query, requested_fields)
@app.post("/pdf2text/")
async def convert_upload_file(file: UploadFile = File(...)):
    """OCR an uploaded PDF and ask the LLM for structured metadata from its first page.

    Every page is rasterized and OCR'd with Tesseract (Vietnamese). The first
    page that yields any OCR text, truncated by ``truncate_text``, is sent to
    OctoAI's meta-llama-3-70b-instruct with a JSON response format constrained
    to the ``Invoice`` schema.

    Returns:
        dict: ``{'content': <OCR text of all pages, each followed by '\\n'>,
        'metadate': <the LLM message content>}``. The ``'metadate'`` key name is
        kept as-is for compatibility with existing clients.

    Raises:
        HTTPException: 400 if the upload has no usable filename.
    """
    import pytesseract
    from pdf2image import convert_from_path
    from octoai.client import OctoAI
    from octoai.text_gen import ChatCompletionResponseFormat, ChatMessage

    # Keep only the final component of the client-controlled filename so the
    # upload cannot be written outside temp_path (e.g. "../../app.py").
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        from fastapi import HTTPException
        raise HTTPException(status_code=400, detail="Invalid upload filename")
    save_path = join(temp_path, safe_name)
    with open(save_path, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    # Convert PDF pages to images, then OCR each page.
    images = convert_from_path(save_path)
    page_texts = [pytesseract.image_to_string(image, lang='vie') for image in images]
    text = "".join(page_text + '\n' for page_text in page_texts)
    # The LLM sees the first page that produced any OCR text ("" if none did).
    first_page = next((truncate_text(page_text) for page_text in page_texts if page_text != ""), "")

    client = OctoAI()
    completion = client.text_gen.create_chat_completion(
        model="meta-llama-3-70b-instruct",
        messages=[
            ChatMessage(role="system", content="You are a helpful assistant."),
            ChatMessage(role="user", content=first_page),
        ],
        presence_penalty=0,
        temperature=0.1,
        top_p=0.9,
        response_format=ChatCompletionResponseFormat(
            type="json_object",
            schema=Invoice.model_json_schema(),
        ),
    )
    return {'content': text, 'metadate': completion.choices[0].message.content}
@app.get("/")
def api_home():
    """Root endpoint: returns a fixed welcome payload."""
    welcome_message = 'Welcome to FastAPI Qdrant importer!'
    return {'detail': welcome_message}