from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
# Imports for file handling, dataset loading, embedding, and vector storage
import os
import shutil
from os import makedirs, getcwd
from os.path import join, exists, dirname
from datasets import load_dataset
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import uuid
from qdrant_client import models, QdrantClient
from itertools import islice
app = FastAPI()
FILEPATH_PATTERN = "structured_data_doc.parquet"
NUM_PROC = os.cpu_count()
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
if not exists(temp_path):
    makedirs(temp_path)
# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the desired model
model = SentenceTransformer(
    'sentence-transformers/all-MiniLM-L6-v2',
    device=device
)
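# all-MiniLM-L6-v2 maps text to 384-dimensional vectors; the dimension is
# read from the model when the collection is created below rather than hard-coded.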
# Helper to split an iterable into batches of size n (the last batch may be shorter)
def batched(iterable, n):
    iterator = iter(iterable)
    while batch := list(islice(iterator, n)):
        yield batch
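# Example: list(batched([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]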
batch_size = 100
# Create a local Qdrant instance persisted on disk under ./database
client2 = QdrantClient(path="database")
# Create a Qdrant collection for the embeddings
client2.create_collection(
    collection_name="law",
    vectors_config=models.VectorParams(
        size=model.get_sentence_embedding_dimension(),
        distance=models.Distance.COSINE,
    ),
)
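# Note: create_collection raises if the "law" collection already exists
# (e.g. on a restart against the same ./database path); recreate_collection
# would drop and rebuild it instead, at the cost of losing stored points.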
# Create function to generate embeddings (in batches) for a given dataset split
def generate_embeddings(dataset, batch_size=32):
    # Assumes the dataset exposes a 'content' text column
    embeddings = []
    with tqdm(total=len(dataset), desc="Generating embeddings for dataset") as pbar:
        for i in range(0, len(dataset), batch_size):
            batch_sentences = dataset['content'][i:i+batch_size]
            batch_embeddings = model.encode(batch_sentences)
            embeddings.extend(batch_embeddings)
            pbar.update(len(batch_sentences))
    return embeddings
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
    # Persist the upload to the temp directory
    file_savePath = join(temp_path, file.filename)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    # Load the saved file into a Hugging Face dataset based on its extension
    if file_savePath.endswith('.json'):
        full_dataset = load_dataset('json',
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC*2)
    elif file_savePath.endswith('.parquet'):
        full_dataset = load_dataset("parquet",
                                    data_files=file_savePath,
                                    split="train",
                                    cache_dir=temp_path,
                                    keep_in_memory=True,
                                    num_proc=NUM_PROC*2)
    else:
        raise NotImplementedError("Only .json and .parquet uploads are supported")
    # Generate embeddings and append them as a new column
    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)
    # Assign a UUID per row (used as the Qdrant point id) if none is present
    if 'uuid' not in full_dataset.column_names:
        full_dataset = full_dataset.add_column('uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))])
    # Upsert the embeddings in batches: each row contributes its uuid as the
    # point id, its embedding as the vector, and the remaining columns as payload
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name="law",
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )
    return {"filename": file.filename, "message": "Done"}
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/search")
def search(prompt: str):
    # Embed the prompt and return the five closest matches from the "law" collection
    hits = client2.search(
        collection_name="law",
        query_vector=model.encode(prompt).tolist(),
        limit=5
    )
    for hit in hits:
        print(hit.payload, "score:", hit.score)
    return hits
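# Example query (same host/port assumption as above):
#   curl "http://localhost:8000/search?prompt=data%20protection"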
@app.get("/download-database/")
async def download_database():
    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')
    # Path for the zip file
    zip_path = join(os.getcwd(), 'database.zip')
    # make_archive appends the '.zip' extension itself, so strip it from the base name
    shutil.make_archive(zip_path.replace('.zip', ''), 'zip', database_dir)
    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI Qdrant importer!'}
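# To run locally (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000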