QDrantRAG9 / app.py
dinhquangson's picture
Update app.py
9fade90 verified
raw
history blame
7.89 kB
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from datasets import load_dataset
from fastapi.middleware.cors import CORSMiddleware
import pdfplumber
import pytesseract
# Loading
import os
import zipfile
import shutil
from os import makedirs,getcwd
from os.path import join,exists,dirname
import torch
import json
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
# Application object; every route in this module is registered on it.
app = FastAPI()

# CORS: any origin, method and header is accepted, with credentials enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

NUM_PROC = os.cpu_count()

# Uploaded files are staged in a "temp" directory located in the parent of
# the current working directory.
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
# exist_ok=True replaces the separate exists() check, which could race when
# several processes start at the same time.
makedirs(temp_path, exist_ok=True)

# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

import logging

# Root logger stays at WARNING; only the "haystack" logger is raised to INFO.
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)

# Shared document store used by both the indexing and the search endpoints.
# NOTE(review): recreate_index=True presumably discards any previously stored
# index every time the module is loaded — confirm this is intended.
# NOTE(review): embedding_dim=384 presumably matches the dense embedding model
# used for indexing (BAAI/bge-small-en-v1.5) — confirm if the model changes.
document_store = QdrantDocumentStore(
    path="database",
    recreate_index=True,
    use_sparse_embeddings=True,
    embedding_dim=384,
)
def extract_zip(zip_path, target_folder):
    """
    Unpack a ZIP archive and report where its members were placed.

    Args:
        zip_path (str): Location of the archive to open.
        target_folder (str): Directory that receives the archive contents.

    Returns:
        List[str]: ``target_folder`` joined with every member name, in the
        order the archive lists them.
    """
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(target_folder)
        member_names = archive.namelist()
    return [os.path.join(target_folder, member) for member in member_names]
def extract_text_from_pdf(pdf_path):
    """
    Return the embedded text of every page of a PDF, concatenated in page order.

    Args:
        pdf_path (str): Path of the PDF to read.

    Returns:
        str: The text of all pages joined without a separator. Pages that
        yield no text contribute nothing.
    """
    with pdfplumber.open(pdf_path) as pdf:
        # extract_text() can yield None for a page without a text layer in
        # some pdfplumber versions; "or ''" keeps such a page from raising
        # TypeError during concatenation.
        return "".join(page.extract_text() or "" for page in pdf.pages)
def extract_ocr_text_from_pdf(pdf_path):
    """
    OCR a PDF: render it to images and run Tesseract on each one.

    Args:
        pdf_path (str): Path of the PDF to convert.

    Returns:
        str: The recognised text of all images, concatenated in order with
        no separator. Recognition uses the 'vie' language pack.
    """
    from pdf2image import convert_from_path

    page_images = convert_from_path(pdf_path)
    recognised = [
        pytesseract.image_to_string(page_image, lang='vie')
        for page_image in page_images
    ]
    return "".join(recognised)
@app.post("/uploadfile/")
async def create_upload_file(text_field: str, file: UploadFile = File(...), ocr:bool=False):
    """
    Save an uploaded file, convert it into documents and index them.

    Supported uploads:
      * ``.json`` / ``.jsonl`` — read line by line; every non-blank line must
        be a JSON object. ``obj[text_field]`` becomes the document content and
        the whole object becomes its metadata.
      * ``.zip`` — extracted into the temp directory; each ``.pdf`` member
        becomes one document whose metadata holds the text under
        ``text_field`` and the extracted path under ``"file_path"``.

    Args:
        text_field: Key that holds the document text (read from each JSON
            object, or used as the metadata key for PDF text).
        file: The uploaded file.
        ocr: For PDFs inside a ZIP, use OCR instead of the embedded text.

    Returns:
        dict: ``filename`` (as sent by the client), ``message`` and
        ``execution_time`` in seconds.

    Raises:
        NotImplementedError: If the upload has no usable filename or an
            unsupported extension.
        KeyError: If a JSON object lacks ``text_field``.
    """
    # Imports
    import time
    from haystack import Document, Pipeline
    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedSparseDocumentEmbedder
    )

    start_time = time.time()

    # The filename is client-controlled. Keep only its last path component so
    # "../" sequences or an absolute path cannot write outside temp_path.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise NotImplementedError("This feature is not supported yet")

    file_savePath = join(temp_path, safe_name)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    documents = []
    # Decide on the real suffix instead of a substring of the whole path, so a
    # directory or stem containing ".json"/".zip" cannot select a branch.
    lowered_name = safe_name.lower()
    if lowered_name.endswith(('.json', '.jsonl')):
        # One JSON object per line; JSON text is UTF-8 regardless of locale.
        with open(file_savePath, encoding="utf-8") as fd:
            for line in fd:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                obj = json.loads(line)
                documents.append(Document(content=obj[text_field], meta=obj))
    elif lowered_name.endswith('.zip'):
        extracted_files_list = extract_zip(file_savePath, temp_path)
        print("Extracted files:")
        for file_path in extracted_files_list:
            if not file_path.lower().endswith('.pdf'):
                continue
            if ocr:
                text = extract_ocr_text_from_pdf(file_path)
            else:
                text = extract_text_from_pdf(file_path)
            # The metadata key is the literal "file_path"; the original used
            # the path value itself as the key.
            obj = {text_field: text, "file_path": file_path}
            documents.append(Document(content=text, meta=obj))
    else:
        raise NotImplementedError("This feature is not supported yet")

    # Indexing: sparse embedder -> dense embedder -> writer (overwrite duplicates)
    indexing = Pipeline()
    indexing.add_component("sparse_doc_embedder", FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1"))
    indexing.add_component("dense_doc_embedder", FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5"))
    indexing.add_component("writer", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")

    indexing.run({"sparse_doc_embedder": {"documents": documents}})

    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
@app.get("/search")
def search(prompt: str):
    """
    Run a hybrid (sparse + dense) search for ``prompt`` and rerank the hits.

    The query is embedded with both a sparse and a dense embedder, fed to the
    Qdrant hybrid retriever, and the retrieved documents are reordered by a
    cross-encoder ranker.

    Args:
        prompt: The user's search query.

    Returns:
        The documents produced by the ranker component.
    """
    import time
    from haystack import Pipeline
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
    from haystack_integrations.components.embedders.fastembed import (
        FastembedTextEmbedder,
        FastembedSparseTextEmbedder
    )
    from haystack.components.rankers import TransformersSimilarityRanker
    from haystack.components.joiners import DocumentJoiner

    start_time = time.time()

    # Querying
    # NOTE(review): the pipeline (and its models) is rebuilt on every request;
    # consider building it once if latency matters.
    querying = Pipeline()
    querying.add_component("sparse_text_embedder", FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1"))
    querying.add_component("dense_text_embedder", FastembedTextEmbedder(
        model="BAAI/bge-small-en-v1.5", prefix="Represent this sentence for searching relevant passages: ")
    )
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.add_component("document_joiner", DocumentJoiner())
    querying.add_component("ranker", TransformersSimilarityRanker(model="BAAI/bge-reranker-base"))

    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")
    querying.connect("retriever", "document_joiner")
    querying.connect("document_joiner", "ranker")

    # Use the caller's prompt (the original ignored it in favour of a
    # hard-coded question). The ranker also needs the query text to score
    # the retrieved documents against.
    results = querying.run(
        {"dense_text_embedder": {"text": prompt},
         "sparse_text_embedder": {"text": prompt},
         "ranker": {"query": prompt}}
    )

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")

    # The retriever's output is consumed by the downstream components, so the
    # pipeline result only exposes the final stage: read from the ranker.
    return results["ranker"]["documents"]
@app.get("/download-database/")
async def download_database():
    """
    Zip the on-disk ``database`` directory and send it as a download.

    The archive is written next to the directory as ``database.zip`` in the
    current working directory, replacing any earlier archive of that name.

    Returns:
        FileResponse: The zip file, served as ``database.zip``.
    """
    import time

    start_time = time.time()

    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')

    # make_archive appends ".zip" to the base name itself and returns the
    # final archive path. Using that return value avoids deriving the base
    # name with str.replace('.zip', ''), which would also strip a ".zip"
    # occurring anywhere else in the working-directory path.
    zip_path = shutil.make_archive(database_dir, 'zip', database_dir)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")

    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
@app.post("/pdf2text/")
async def create_upload_file(file: UploadFile = File(...)):
    """
    OCR an uploaded PDF and return the recognised text.

    The upload is saved into the temp directory, rendered to images, and each
    image is passed through Tesseract with the 'vie' language pack.

    Args:
        file: The uploaded PDF.

    Returns:
        str: The recognised text of every image, each followed by a newline.
    """
    import pytesseract
    from pdf2image import convert_from_path

    # The filename is client-controlled. Keep only its last path component so
    # "../" sequences or an absolute path cannot write outside temp_path, and
    # fall back to a fixed name when nothing usable remains.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload.pdf"

    file_savePath = join(temp_path, safe_name)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    # convert PDF to image
    images = convert_from_path(file_savePath)

    # Extract text from images, one newline-terminated chunk per image
    return "".join(
        pytesseract.image_to_string(image, lang='vie') + '\n'
        for image in images
    )
@app.get("/")
def api_home():
    """Landing route: reply with a fixed welcome payload."""
    welcome = dict(detail='Welcome to FastAPI Qdrant importer!')
    return welcome