dinhquangson committed on
Commit
05c83f3
1 Parent(s): 3fb0b01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py CHANGED
@@ -1,12 +1,93 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  app = FastAPI()
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @app.post("/uploadfile/")
8
  async def create_upload_file(file: UploadFile = File(...)):
9
  # Here you can save the file and do other operations as needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  return {"filename": file.filename, "message": "Done"}
11
 
12
  app.add_middleware(
@@ -17,6 +98,18 @@ app.add_middleware(
17
  allow_headers=["*"],
18
  )
19
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  @app.get("/")
21
  def api_home():
22
  return {'detail': 'Welcome to FastAPI Qdrant importer!'}
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ # Loading
4
+ import os
5
+ from datasets import load_dataset
6
+ import torch
7
+ from tqdm import tqdm
8
+ from sentence_transformers import SentenceTransformer
9
+ import uuid
10
+ from qdrant_client import models, QdrantClient
11
+ from itertools import islice
12
+
13
# Helper: lazily split any iterable into consecutive lists of at most n items.
def batched(iterable, n):
    it = iter(iterable)
    chunk = list(islice(it, n))
    while chunk:
        yield chunk
        chunk = list(islice(it, n))
18
+
19
batch_size = 100  # points per Qdrant upsert batch

# Local on-disk Qdrant instance; no external server needed.
client2 = QdrantClient(path="database.db")

# Embedding dimension of sentence-transformers/all-MiniLM-L6-v2.
# NOTE: the original called model.get_sentence_embedding_dimension() here,
# but `model` is only assigned later in this file, so the original code
# raised NameError at import time. The dimension is fixed for this model,
# so it is hard-coded instead.
EMBEDDING_DIM = 384

# Create the collection that will hold the law-document embeddings
# (cosine similarity, as used by the /search endpoint).
client2.create_collection(
    collection_name="law",
    vectors_config=models.VectorParams(
        size=EMBEDDING_DIM,
        distance=models.Distance.COSINE,
    ),
)

# Prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

FILEPATH_PATTERN = "structured_data_doc.parquet"
CACHE_DIR = "/.cache"
# os.cpu_count() may return None on some platforms; fall back to 1 so the
# later `NUM_PROC * 2` arithmetic cannot raise TypeError.
NUM_PROC = os.cpu_count() or 1
39
+
40
 
41
  app = FastAPI()
42
 
43
 
44
+
45
# Sentence-embedding model used for both indexing and query encoding.
_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
model = SentenceTransformer(_MODEL_NAME, device=device)
50
# Create function to generate embeddings (in batches) for a given dataset split.
def generate_embeddings(dataset, batch_size=32, encoder=None):
    """Encode dataset['content'] in batches and return a flat list of embeddings.

    Parameters
    ----------
    dataset : mapping exposing a 'content' sequence; len(dataset) is the row count.
    batch_size : number of rows encoded per model call.
    encoder : object with .encode(list[str]) -> sequence of vectors.
        Defaults to the module-level SentenceTransformer `model`
        (backward-compatible with the original signature).
    """
    if encoder is None:
        encoder = model  # module-level SentenceTransformer
    embeddings = []
    # Original used an f-string with no placeholder; a plain string is equivalent.
    with tqdm(total=len(dataset), desc="Generating embeddings for dataset") as pbar:
        for start in range(0, len(dataset), batch_size):
            batch_sentences = dataset['content'][start:start + batch_size]
            embeddings.extend(encoder.encode(batch_sentences))
            pbar.update(len(batch_sentences))
    return embeddings
62
+
63
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
    """Ingest the parquet corpus into the 'law' Qdrant collection.

    NOTE(review): the uploaded file itself is never read or saved — this
    endpoint always (re)indexes FILEPATH_PATTERN from local disk and only
    echoes the upload's filename back. Confirm whether the upload was meant
    to replace that file.
    """
    # Load the whole parquet split into memory so embeddings can be appended.
    full_dataset = load_dataset(
        "parquet",
        data_files=FILEPATH_PATTERN,
        split="train",
        keep_in_memory=True,
        cache_dir=CACHE_DIR,
        # os.cpu_count() (and hence NUM_PROC) may be None; guard before doubling
        # so the original `NUM_PROC*2` TypeError cannot occur.
        num_proc=(NUM_PROC or 1) * 2,
    )
    # Generate and append embeddings to the train split.
    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)

    # Every Qdrant point needs a stable id; synthesize UUIDs when absent.
    if 'uuid' not in full_dataset.column_names:
        full_dataset = full_dataset.add_column(
            'uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))]
        )

    # Upsert the embeddings in batches.
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]

        client2.upsert(
            collection_name="law",
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,  # remaining columns become the point payload
            ),
        )
    return {"filename": file.filename, "message": "Done"}
92
 
93
  app.add_middleware(
 
98
  allow_headers=["*"],
99
  )
100
 
101
@app.get("/search")
def search(prompt: str):
    """Semantic search over the 'law' collection; returns the top-5 hits."""
    hits = client2.search(
        collection_name="law",
        query_vector=model.encode(prompt).tolist(),
        limit=5,
    )
    for hit in hits:
        print(hit.payload, "score:", hit.score)
    # Original return line was a SyntaxError (mixed dict/set literal), used the
    # string 'hit.payload' instead of the variable, and could only ever have
    # reported the last hit. Return every hit's payload and score instead.
    return {
        "detail": [
            {"payload": hit.payload, "score": hit.score} for hit in hits
        ]
    }
112
+
113
@app.get("/")
def api_home():
    """Landing endpoint: confirms the API is up."""
    welcome = {'detail': 'Welcome to FastAPI Qdrant importer!'}
    return welcome