QDrantRAG9 / app.py
dinhquangson's picture
Update app.py
22aa66a verified
raw
history blame
4.77 kB
# Loading
import os
import shutil
import uuid
from itertools import islice
from os import makedirs, getcwd
from os.path import join, exists, dirname

import torch
from datasets import load_dataset
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
# FileResponse is exported by fastapi.responses; the top-level fastapi
# package does not re-export it, so importing it from there fails.
from fastapi.responses import FileResponse
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# ASGI application object; routes and middleware in this module attach to it.
app = FastAPI()
FILEPATH_PATTERN = "structured_data_doc.parquet"
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 1 so arithmetic on NUM_PROC (e.g. NUM_PROC * 2) cannot raise.
NUM_PROC = os.cpu_count() or 1
# Scratch directory for uploaded files and the datasets cache. Note that it is
# created next to (not inside) the current working directory.
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
# exist_ok=True replaces the separate exists() check, so a directory created
# by another process between the check and the call no longer raises.
makedirs(temp_path, exist_ok=True)
# Run on the GPU when CUDA is usable, otherwise stay on the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Sentence-embedding model used for both indexing and querying
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)
# Helper that splits any iterable into consecutive lists of at most n items
def batched(iterable, n):
    """Yield successive lists of up to ``n`` items taken from ``iterable``.

    Items keep their original order. The last list may hold fewer than ``n``
    items, and nothing is yielded once the iterable is exhausted.
    """
    source = iter(iterable)
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            return
        yield chunk
# Number of points sent to Qdrant per upsert call
batch_size = 100
# Open Qdrant in local mode. With path=..., qdrant-client persists the data
# in that directory (here "database", relative to the working directory), so
# collections survive a restart of this process.
client2 = QdrantClient(path="database")
# Create the "law" collection only when it is missing: because the storage is
# persistent, an unconditional create_collection would fail on every start
# after the first one.
if "law" not in {c.name for c in client2.get_collections().collections}:
    client2.create_collection(
        collection_name="law",
        vectors_config=models.VectorParams(
            # Vector size must match the embedding model's output dimension.
            size=model.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
# Create function to generate embeddings (in batches) for a given dataset split
def generate_embeddings(dataset, batch_size=32):
    """Encode the ``content`` column of ``dataset`` with the module-level model.

    Args:
        dataset: Object supporting ``len()`` and ``dataset['content']``, where
            the returned column can be sliced.
        batch_size: Number of texts passed to ``model.encode`` per call.

    Returns:
        A list holding one embedding per row, in row order.
    """
    # Fetch the column once. Indexing dataset['content'] inside the loop
    # repeats the column lookup for every batch, which on datasets versions
    # that materialize the whole column makes the loop quadratic.
    contents = dataset['content']
    embeddings = []
    with tqdm(total=len(dataset), desc="Generating embeddings for dataset") as pbar:
        for start in range(0, len(dataset), batch_size):
            batch_sentences = contents[start:start + batch_size]
            batch_embeddings = model.encode(batch_sentences)
            embeddings.extend(batch_embeddings)
            pbar.update(len(batch_sentences))
    return embeddings
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
    """Save an uploaded dataset file, embed its rows and upsert them into Qdrant.

    Args:
        file: Uploaded ``.json`` or ``.parquet`` file. Its rows must expose a
            ``content`` column (read by ``generate_embeddings``). A ``uuid``
            column supplies the point ids when present; otherwise random
            UUID4 strings are generated.

    Returns:
        ``{"filename": <name sent by the client>, "message": "Done"}``.

    Raises:
        NotImplementedError: If the file extension is neither ``.json`` nor
            ``.parquet``. The check runs before anything is written to disk.
    """
    # The filename is client-controlled: keep only its final path component so
    # a name such as "../../x.parquet" cannot escape temp_path.
    safe_name = os.path.basename(file.filename or "")
    # Decide by the actual extension rather than by a substring of the path,
    # which would misclassify e.g. "data.json.parquet".
    extension = os.path.splitext(safe_name)[1].lower()
    if extension == ".json":
        builder = "json"
    elif extension == ".parquet":
        builder = "parquet"
    else:
        raise NotImplementedError("This feature is not supported yet")

    file_savePath = join(temp_path, safe_name)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    # Both formats load the file that was just saved and request the "train"
    # split, so a single Dataset (not a DatasetDict) reaches the code below.
    full_dataset = load_dataset(
        builder,
        data_files=file_savePath,
        split="train",
        cache_dir=temp_path,
        keep_in_memory=True,
        # Guard against NUM_PROC being None (os.cpu_count() may return None).
        num_proc=(NUM_PROC or 1) * 2,
    )

    # Generate and append embeddings to the train split
    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)
    if 'uuid' not in full_dataset.column_names:
        full_dataset = full_dataset.add_column(
            'uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))]
        )

    # Upsert the embeddings in batches
    for batch in batched(full_dataset, batch_size):
        # pop() strips the id and vector from each row, so only the remaining
        # columns are sent as the point payload.
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name="law",
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )
    return {"filename": file.filename, "message": "Done"}
# Attach CORS handling to the application.
# NOTE(review): every origin, method and header is accepted while credentials
# are enabled as well; confirm that this fully open policy is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/search")
def search(prompt: str):
    """Return up to five hits from the "law" collection for ``prompt``.

    The prompt is embedded with the module-level model and the resulting
    vector is used as the query. Each hit's payload and score are printed
    before the hits are returned.
    """
    query_vector = model.encode(prompt).tolist()
    results = client2.search(
        collection_name="law",
        query_vector=query_vector,
        limit=5,
    )
    for match in results:
        print(match.payload, "score:", match.score)
    return results
@app.get("/download-database/")
async def download_database():
    """Zip the ``database`` directory and return the archive as a download.

    The archive is written to ``database.zip`` in the current working
    directory on every call, then served with the download name
    ``database.zip``.
    """
    working_dir = os.getcwd()
    # Path to the database directory
    database_dir = join(working_dir, 'database')
    # make_archive appends ".zip" itself, so it is given the archive path
    # without the suffix. Building that base name directly avoids
    # str.replace('.zip', ''), which would also strip ".zip" from any parent
    # directory name in the absolute path.
    archive_base = join(working_dir, 'database')
    # make_archive returns the path of the archive it created.
    zip_path = shutil.make_archive(archive_base, 'zip', database_dir)
    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
@app.get("/")
def api_home():
    """Return the static welcome message served at the API root."""
    welcome = {'detail': 'Welcome to FastAPI Qdrant importer!'}
    return welcome