from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
# Standard-library imports
import os
import shutil
import time
import json
import uuid
from os import makedirs, getcwd
from os.path import join, exists, dirname
from itertools import islice
# Third-party imports
from datasets import load_dataset
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from qdrant_client import models, QdrantClient
# The file where NeuralSearcher is stored
from neural_searcher import NeuralSearcher
# The file where HybridSearcher is stored
from hybrid_searcher import HybridSearcher
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
FILEPATH_PATTERN = "structured_data_doc.parquet"
NUM_PROC = os.cpu_count()

parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
if not exists(temp_path):
    makedirs(temp_path)
# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the desired model
model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    device=device
)
# Helper that splits an iterable into lists of at most n items,
# used below to upsert embeddings in batches
def batched(iterable, n):
    iterator = iter(iterable)
    while batch := list(islice(iterator, n)):
        yield batch
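# Illustrative example (not part of the app):
#   list(batched([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]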
batch_size = 100
# Create a local Qdrant instance persisted on disk under ./database
client2 = QdrantClient(path="database")

# Create a Qdrant collection for the embeddings
client2.create_collection(
    collection_name="law",
    vectors_config=models.VectorParams(
        size=model.get_sentence_embedding_dimension(),
        distance=models.Distance.COSINE,
    ),
)
# Generate embeddings (in batches) for a given dataset split
def generate_embeddings(dataset, text_field, batch_size=32):
    embeddings = []
    with tqdm(total=len(dataset), desc="Generating embeddings for dataset") as pbar:
        for i in range(0, len(dataset), batch_size):
            batch_sentences = dataset[i:i + batch_size][text_field]
            batch_embeddings = model.encode(batch_sentences)
            embeddings.extend(batch_embeddings)
            pbar.update(len(batch_sentences))
    return embeddings
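# Illustrative example (not part of the app): all-MiniLM-L6-v2 produces one
# 384-dimensional vector per row, so
#   vecs = generate_embeddings(ds, "text")  # len(vecs) == len(ds)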
# Upload a JSON or Parquet file and index its embeddings into the "law" collection.
# The route decorator was restored here; the exact path is an assumption.
@app.post("/uploadfile/")
async def create_upload_file(text_field: str, file: UploadFile = File(...)):
    start_time = time.time()
    file_savePath = join(temp_path, file.filename)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    # Load the saved file into a Hugging Face dataset
    if file_savePath.endswith('.json'):
        full_dataset = load_dataset('json',
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC * 2)
    elif file_savePath.endswith('.parquet'):
        full_dataset = load_dataset("parquet",
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC * 2)
    else:
        raise NotImplementedError("Only .json and .parquet files are supported")
    # Generate and append embeddings to the train split
    law_embeddings = generate_embeddings(full_dataset, text_field)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)
    if 'uuid' not in full_dataset.column_names:
        full_dataset = full_dataset.add_column('uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))])
    # Upsert the embeddings in batches into the "law" collection created above
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name="law",
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
# Upload a JSON file and index it for hybrid (dense + sparse) search.
# The route decorator was restored here; the exact path is an assumption.
@app.post("/upload4hybridsearch/")
async def upload_file_4_hyper_search(collection_name: str, text_field: str, file: UploadFile = File(...)):
    start_time = time.time()
    file_savePath = join(temp_path, file.filename)
    client2.set_model("sentence-transformers/all-MiniLM-L6-v2")
    # comment this line to use dense vectors only
    client2.set_sparse_model("prithivida/Splade_PP_en_v1")
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
print(f"Uploaded complete!") | |
client2.recreate_collection( | |
collection_name=collection_name, | |
vectors_config=client2.get_fastembed_vector_params(), | |
# comment this line to use dense vectors only | |
sparse_vectors_config=client2.get_fastembed_sparse_vector_params(), | |
) | |
print(f"The collection is created complete!") | |
    # Parse the saved file: one JSON object per line, with the text field as the
    # document body and the remaining fields as metadata
    if file_savePath.endswith('.json'):
        metadata = []
        documents = []
        with open(file_savePath) as fd:
            for line in fd:
                obj = json.loads(line)
                documents.append(obj.pop(text_field))
                metadata.append(obj)
        # Generate UUIDs for each document
        document_ids = [str(uuid.uuid4()) for _ in range(len(documents))]
        # Split documents and metadata into batches and upsert each batch;
        # client2.add() embeds the documents with the models configured above
        for i in range(0, len(documents), batch_size):
            batch_documents = documents[i:i + batch_size]
            batch_metadata = metadata[i:i + batch_size]
            batch_ids = document_ids[i:i + batch_size]
            client2.add(
                collection_name=collection_name,
                documents=batch_documents,
                metadata=batch_metadata,
                ids=batch_ids,
            )
            print(f"Upserted a batch of {len(batch_ids)} documents with unique UUIDs")
        print(f"All documents and metadata upserted in batches of {batch_size}")
    else:
        raise NotImplementedError("Only .json files are supported for hybrid indexing")
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
# Search the "law" collection for the five nearest neighbors of a prompt.
# The route decorator was restored here; the exact path is an assumption.
@app.get("/search/")
def search(prompt: str):
    start_time = time.time()
    hits = client2.search(
        collection_name="law",
        query_vector=model.encode(prompt).tolist(),
        limit=5
    )
    for hit in hits:
        print(hit.payload, "score:", hit.score)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return hits
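# Example request (illustrative; host, port, route path, and prompt are assumptions):
#   import requests
#   requests.get("http://localhost:8000/search/", params={"prompt": "contract termination"}).json()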
# Zip the on-disk Qdrant database and return it for download.
# The route decorator was restored here; the exact path is an assumption.
@app.get("/download-database/")
async def download_database():
    start_time = time.time()
    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')
    # Path for the zip file
    zip_path = join(os.getcwd(), 'database.zip')
    # Create a zip file of the database directory
    shutil.make_archive(zip_path.replace('.zip', ''), 'zip', database_dir)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
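# Example request (illustrative; host, port, and route path are assumptions):
#   import requests
#   r = requests.get("http://localhost:8000/download-database/")
#   open("database.zip", "wb").write(r.content)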
# Dense (neural) search over an existing collection, filtered by city.
# The route decorator was restored here; the exact path is an assumption.
@app.get("/api/search")
def neural_search(q: str, city: str, collection_name: str):
    start_time = time.time()
    # Create a neural searcher instance and run the search, timing both
    # so the returned execution_time covers the query itself
    neural_searcher = NeuralSearcher(collection_name=collection_name)
    result = neural_searcher.search(text=q, city=city)
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"result": result, "execution_time": elapsed_time}
# Hybrid (dense + sparse) search over an existing collection, filtered by city.
# The route decorator was restored here; the exact path is an assumption.
@app.get("/api/hybrid-search")
def hybrid_search(q: str, city: str, collection_name: str):
    start_time = time.time()
    # Create a hybrid searcher instance and run the search, timing both
    # so the returned execution_time covers the query itself
    hybrid_searcher = HybridSearcher(collection_name=collection_name)
    result = hybrid_searcher.search(text=q, city=city)
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"result": result, "execution_time": elapsed_time}
# Root endpoint. The route decorator was restored here; the path is an assumption.
@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI Qdrant importer!'}
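# Run the app directly (entry-point sketch; module name, host, and port are assumptions)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)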