from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
# Data loading and embedding dependencies
import os
import shutil
from os import makedirs, getcwd
from os.path import join, exists, dirname
from datasets import load_dataset
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import uuid
from qdrant_client import models, QdrantClient
from itertools import islice
app = FastAPI()

FILEPATH_PATTERN = "structured_data_doc.parquet"
NUM_PROC = os.cpu_count() or 1  # cpu_count() can return None, so fall back to 1
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
if not exists(temp_path):
    makedirs(temp_path)

# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the desired embedding model
model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    device=device
)
# Helper to split an iterable into fixed-size batches (used for the upserts below)
def batched(iterable, n):
    iterator = iter(iterable)
    while batch := list(islice(iterator, n)):
        yield batch
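# Example: list(batched(range(5), 2)) -> [[0, 1], [2, 3], [4]]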
batch_size = 100
# Create a local Qdrant instance persisted on disk under ./database
# (QdrantClient(path=...) is file-backed, not in-memory)
client2 = QdrantClient(path="database")
# Create a Qdrant collection for the embeddings, skipping creation if it
# already exists from a previous run
if not client2.collection_exists("law"):
    client2.create_collection(
        collection_name="law",
        vectors_config=models.VectorParams(
            size=model.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
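# all-MiniLM-L6-v2 maps text to 384-dimensional vectors, so
# get_sentence_embedding_dimension() resolves to 384 here.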
# Generate embeddings (in batches) for a given dataset split
def generate_embeddings(dataset, batch_size=32):
    embeddings = []
    with tqdm(total=len(dataset), desc="Generating embeddings for dataset") as pbar:
        for i in range(0, len(dataset), batch_size):
            batch_sentences = dataset['content'][i:i+batch_size]
            batch_embeddings = model.encode(batch_sentences)
            embeddings.extend(batch_embeddings)
            pbar.update(len(batch_sentences))
    return embeddings
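# Note: generate_embeddings assumes the dataset exposes a 'content' text
# column; a dataset without that column will raise a KeyError here.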
# Upload endpoint (route path assumed; the decorator was missing here)
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
    file_savePath = join(temp_path, file.filename)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    # Load the saved file into a Hugging Face dataset (split="train" so the
    # result is a Dataset rather than a DatasetDict)
    if file_savePath.endswith('.json'):
        full_dataset = load_dataset('json',
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC*2)
    elif file_savePath.endswith('.parquet'):
        full_dataset = load_dataset("parquet",
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC*2)
    else:
        raise NotImplementedError("This file type is not supported yet")
    # Generate and append embeddings to the dataset
    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)
    if 'uuid' not in full_dataset.column_names:
        full_dataset = full_dataset.add_column('uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))])
    # Upsert the embeddings in batches; iterating the dataset yields one dict
    # per row, so pop() splits off the id and vector and the remaining fields
    # become the point's payload
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name="law",
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )
    return {"filename": file.filename, "message": "Done"}
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Search endpoint (route path assumed; the decorator was missing here)
@app.get("/search")
def search(prompt: str):
    # Embed the prompt and return the five most similar points
    hits = client2.search(
        collection_name="law",
        query_vector=model.encode(prompt).tolist(),
        limit=5
    )
    for hit in hits:
        print(hit.payload, "score:", hit.score)
    return hits
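# Example usage (assuming the route above):
#   curl "http://localhost:8000/search?prompt=data%20privacy"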
# Download endpoint (route path assumed; the decorator was missing here)
@app.get("/download")
async def download_database():
    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')
    # Path for the zip file
    zip_path = join(os.getcwd(), 'database.zip')
    # Create a zip file of the database directory
    shutil.make_archive(zip_path.replace('.zip', ''), 'zip', database_dir)
    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
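# shutil.make_archive expects the base name without the extension, which is
# why '.zip' is stripped before archiving; the function appends it back.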
# Root endpoint (route path assumed; the decorator was missing here)
@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI Qdrant importer!'}
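
# Minimal local entrypoint (a sketch; a deployed Space would normally start
# uvicorn through its own configured command):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)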