# QDrantRAG9 / app.py
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
# Loading
import os
import shutil
from os import makedirs, getcwd
from os.path import join, exists, dirname
from datasets import load_dataset
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import uuid
from qdrant_client import models, QdrantClient
from itertools import islice
# The file where NeuralSearcher is stored
from neural_searcher import NeuralSearcher
# The file where HybridSearcher is stored
from hybrid_searcher import HybridSearcher
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
FILEPATH_PATTERN = "structured_data_doc.parquet"
NUM_PROC = os.cpu_count() or 1  # os.cpu_count() can return None
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
if not exists(temp_path):
    makedirs(temp_path)
# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the desired model
model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    device=device
)
# Helper to split an iterable into batches of n items (used for batched upserts)
def batched(iterable, n):
    iterator = iter(iterable)
    while batch := list(islice(iterator, n)):
        yield batch
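
# A minimal usage sketch (illustration only, not called directly below):
#   list(batched(range(5), 2))  ->  [[0, 1], [2, 3], [4]]
# i.e. fixed-size chunks, with a shorter final chunk for the remainder.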
batch_size = 100
# Create a local, on-disk Qdrant instance (persisted under ./database)
client2 = QdrantClient(path="database")

# Create a Qdrant collection for the embeddings, skipping creation if it
# already exists so the app can restart against the same on-disk database
if not client2.collection_exists("law"):
    client2.create_collection(
        collection_name="law",
        vectors_config=models.VectorParams(
            size=model.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
# Generate embeddings (in batches) for a given dataset split
def generate_embeddings(dataset, text_field, batch_size=32):
    embeddings = []
    with tqdm(total=len(dataset), desc="Generating embeddings for dataset") as pbar:
        for i in range(0, len(dataset), batch_size):
            batch_sentences = dataset[text_field][i:i + batch_size]
            batch_embeddings = model.encode(batch_sentences)
            embeddings.extend(batch_embeddings)
            pbar.update(len(batch_sentences))
    return embeddings
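
# A usage sketch with a hypothetical dataset and column name "text":
#   ds = load_dataset("json", data_files="docs.json", split="train")
#   embs = generate_embeddings(ds, "text")
#   ds = ds.add_column("embeddings", embs)
# all-MiniLM-L6-v2 yields 384-dimensional vectors, matching the size the
# "law" collection is created with above.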
@app.post("/uploadfile/")
async def create_upload_file(text_field: str, file: UploadFile = File(...)):
    import time
    start_time = time.time()
    file_savePath = join(temp_path, file.filename)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    # Load the saved file into a datasets.Dataset
    if file_savePath.endswith('.json'):
        full_dataset = load_dataset('json',
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC * 2)
    elif file_savePath.endswith('.parquet'):
        full_dataset = load_dataset("parquet",
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC * 2)
    else:
        raise NotImplementedError("This feature is not supported yet")
    # Generate and append embeddings to the dataset
    law_embeddings = generate_embeddings(full_dataset, text_field)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)
    if 'uuid' not in full_dataset.column_names:
        full_dataset = full_dataset.add_column('uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))])
    # Upsert the embeddings in batches; the remaining columns become the payload
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name="law",  # this endpoint always writes to the "law" collection
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
@app.post("/uploadfile4hypersearch/")
async def upload_file_4_hyper_search(collection_name: str, text_field: str, file: UploadFile = File(...)):
    import time
    start_time = time.time()
    file_savePath = join(temp_path, file.filename)
    # Register the dense and sparse fastembed models used by client2.add()
    client2.set_model("sentence-transformers/all-MiniLM-L6-v2")
    # comment this line to use dense vectors only
    client2.set_sparse_model("prithivida/Splade_PP_en_v1")
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    print("Upload complete!")
    client2.recreate_collection(
        collection_name=collection_name,
        vectors_config=client2.get_fastembed_vector_params(),
        # comment this line to use dense vectors only
        sparse_vectors_config=client2.get_fastembed_sparse_vector_params(),
    )
    print("Collection created!")
    # Parse the uploaded file (expected as JSON Lines: one object per line)
    if file_savePath.endswith('.json'):
        import json

        # Define your batch size
        batch_size = 100
        metadata = []
        documents = []
        with open(file_savePath) as fd:
            for line in fd:
                obj = json.loads(line)
                documents.append(obj.pop(text_field))
                metadata.append(obj)
        # Generate a UUID for each document
        document_ids = [str(uuid.uuid4()) for _ in range(len(documents))]
        # Split documents and metadata into batches; client2.add() embeds the
        # documents with the registered fastembed models and upserts them
        for i in range(0, len(documents), batch_size):
            batch_documents = documents[i:i + batch_size]
            batch_metadata = metadata[i:i + batch_size]
            batch_ids = document_ids[i:i + batch_size]
            client2.add(
                collection_name=collection_name,
                documents=batch_documents,
                metadata=batch_metadata,
                ids=batch_ids,
            )
        print(f"Documents and metadata upserted in batches of {batch_size} with unique UUIDs!")
    else:
        raise NotImplementedError("This feature is not supported yet")
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
@app.get("/search")
def search(prompt: str):
    import time
    start_time = time.time()
    # Embed the prompt and retrieve the closest points from the "law" collection
    hits = client2.search(
        collection_name="law",
        query_vector=model.encode(prompt).tolist(),
        limit=5
    )
    for hit in hits:
        print(hit.payload, "score:", hit.score)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return hits
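
# A query sketch (same local-server assumption): returns the five nearest
# points from the "law" collection as JSON. The prompt is a placeholder:
#   curl "http://localhost:8000/search?prompt=property%20law"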
@app.get("/download-database/")
async def download_database():
    import time
    start_time = time.time()
    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')
    # Path for the zip file
    zip_path = join(os.getcwd(), 'database.zip')
    # Create a zip file of the database directory
    shutil.make_archive(zip_path.replace('.zip', ''), 'zip', database_dir)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
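
# Download sketch (same local-server assumption); curl's -OJ saves the
# archive under the server-supplied name database.zip:
#   curl -OJ "http://localhost:8000/download-database/"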
@app.get("/neural_search")
def neural_search(q: str, city: str, collection_name: str):
    import time
    start_time = time.time()
    # Create a neural searcher instance and run the query
    neural_searcher = NeuralSearcher(collection_name=collection_name)
    result = neural_searcher.search(text=q, city=city)
    # Stop the timer after the search so execution_time covers the query itself
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"result": result, "execution_time": elapsed_time}
@app.get("/hybrid_search")
def hybrid_search(q: str, city: str, collection_name: str):
    import time
    start_time = time.time()
    # Create a hybrid searcher instance and run the query
    hybrid_searcher = HybridSearcher(collection_name=collection_name)
    result = hybrid_searcher.search(text=q, city=city)
    # Stop the timer after the search so execution_time covers the query itself
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"result": result, "execution_time": elapsed_time}
@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI Qdrant importer!'}