"""FastAPI service that embeds uploaded documents and stores them in Qdrant.

Endpoints:
- POST /uploadfile/         upload a .json/.jsonl or .parquet file whose rows have
                            a ``content`` column; every row is embedded and
                            upserted into the ``law`` collection.
- GET  /search              semantic search over the ``law`` collection.
- GET  /download-database/  download the on-disk Qdrant database as a zip.
- GET  /                    welcome message.
"""
import os
import shutil
import uuid
from itertools import islice
from os import getcwd, makedirs
from os.path import basename, dirname, exists, join

import torch
from datasets import load_dataset
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

app = FastAPI()

# NOTE(review): wildcard origins combined with credentials is very permissive;
# restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

FILEPATH_PATTERN = "structured_data_doc.parquet"
# os.cpu_count() may return None when the count cannot be determined.
NUM_PROC = os.cpu_count() or 1
COLLECTION_NAME = "law"
# Relative to the working directory; shared by the client and the download endpoint.
DATABASE_PATH = "database"

# Uploaded files are written to <parent of cwd>/temp.
parent_path = dirname(getcwd())
temp_path = join(parent_path, "temp")
if not exists(temp_path):
    makedirs(temp_path, exist_ok=True)

# Determine device based on GPU availability.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the sentence-embedding model used for both indexing and querying.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)


def batched(iterable, n):
    """Yield successive lists of up to ``n`` items from ``iterable``."""
    iterator = iter(iterable)
    while batch := list(islice(iterator, n)):
        yield batch


# Number of points sent to Qdrant per upsert call.
batch_size = 100

# Local-mode Qdrant client persisted on disk under ./database (not in-memory).
client2 = QdrantClient(path=DATABASE_PATH)

# The store is persistent, so the collection survives restarts; creating it
# unconditionally would fail on every start after the first.
_existing_collections = {c.name for c in client2.get_collections().collections}
if COLLECTION_NAME not in _existing_collections:
    client2.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(
            size=model.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )


def generate_embeddings(dataset, batch_size=32):
    """Encode ``dataset['content']`` in batches and return one embedding per row.

    Args:
        dataset: a ``datasets.Dataset`` with a ``content`` column.
        batch_size: number of sentences passed to ``model.encode`` at once.

    Returns:
        A list with one embedding (as returned by ``model.encode``) per row.
    """
    # Fetch the column once; indexing the dataset inside the loop can
    # re-materialize the whole column on every iteration.
    contents = dataset["content"]
    embeddings = []
    with tqdm(total=len(dataset), desc="Generating embeddings for dataset") as pbar:
        for start in range(0, len(contents), batch_size):
            batch_sentences = contents[start:start + batch_size]
            embeddings.extend(model.encode(batch_sentences))
            pbar.update(len(batch_sentences))
    return embeddings


def _dataset_format(filename):
    """Return the ``load_dataset`` builder name for ``filename``, or None if unsupported."""
    lowered = filename.lower()
    if lowered.endswith((".json", ".jsonl")):
        return "json"
    if lowered.endswith(".parquet"):
        return "parquet"
    return None


@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
    """Save an uploaded dataset file, embed its rows, and upsert them into Qdrant.

    The file must be .json/.jsonl or .parquet and contain a ``content`` column.
    Rows without a ``uuid`` column get random UUID4 ids. Remaining columns are
    stored as the point payload.

    Raises:
        HTTPException 400: the upload has no usable filename.
        HTTPException 415: the file extension is not supported.
        HTTPException 422: the dataset has no ``content`` column.
    """
    # The client-supplied filename is untrusted: keep only its final component
    # so it cannot contain "../" and escape temp_path.
    filename = basename(file.filename or "")
    if not filename:
        raise HTTPException(status_code=400, detail="Uploaded file has no filename.")

    dataset_format = _dataset_format(filename)
    if dataset_format is None:
        raise HTTPException(
            status_code=415,
            detail="This feature is not supported yet: only .json, .jsonl and .parquet files are accepted.",
        )

    file_savePath = join(temp_path, filename)
    with open(file_savePath, "wb") as f:
        shutil.copyfileobj(file.file, f)

    # split="train" returns a Dataset rather than a DatasetDict, which the
    # code below (len, column access, add_column) requires.
    full_dataset = load_dataset(
        dataset_format,
        data_files=file_savePath,
        split="train",
        cache_dir=temp_path,
        keep_in_memory=True,
        num_proc=NUM_PROC * 2,
    )
    if "content" not in full_dataset.column_names:
        raise HTTPException(
            status_code=422, detail="Uploaded dataset must have a 'content' column."
        )

    # Generate and append embeddings to the dataset.
    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)
    if "uuid" not in full_dataset.column_names:
        full_dataset = full_dataset.add_column(
            "uuid", [str(uuid.uuid4()) for _ in range(len(full_dataset))]
        )

    # Upsert in batches; "uuid" and "embeddings" are popped so the remaining
    # columns become the payload of each point.
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name=COLLECTION_NAME,
            points=models.Batch(ids=ids, vectors=vectors, payloads=batch),
        )

    return {"filename": file.filename, "message": "Done"}


@app.get("/search")
def search(prompt: str):
    """Return the 5 points in the collection most similar to ``prompt``."""
    # Embed the query with the same model used at indexing time.
    hits = client2.search(
        collection_name=COLLECTION_NAME,
        query_vector=model.encode(prompt).tolist(),
        limit=5,
    )
    for hit in hits:
        print(hit.payload, "score:", hit.score)
    return hits


@app.get("/download-database/")
async def download_database():
    """Zip the on-disk Qdrant database directory and return it as a download."""
    database_dir = join(os.getcwd(), DATABASE_PATH)
    # make_archive appends ".zip" to the base name and returns the full path;
    # use that instead of string-editing a path that may itself contain ".zip".
    zip_path = shutil.make_archive(join(os.getcwd(), "database"), "zip", database_dir)
    return FileResponse(zip_path, media_type="application/zip", filename="database.zip")


@app.get("/")
def api_home():
    """Return a static welcome message."""
    return {"detail": "Welcome to FastAPI Qdrant importer!"}