dinhquangson committed on
Commit
4beb7b0
1 Parent(s): b68f37e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -27
app.py CHANGED
@@ -1,37 +1,124 @@
1
- from fastapi import FastAPI
2
- from fastapi import File, UploadFile
3
- import shutil
4
- from modelscope.pipelines import pipeline
5
- from modelscope.utils.constant import Tasks
6
  from os import makedirs,getcwd
7
  from os.path import join,exists,dirname
8
- from modelscope.models import Model
9
- from modelscope.pipelines import pipeline
10
-
11
- model = Model.from_pretrained('damo/multi-modal_convnext-roberta-base_vldoc-embedding')
12
- doc_VL_emb_pipeline = pipeline(task='document-vl-embedding', model=model)
 
 
13
 
14
  app = FastAPI()
 
 
 
 
15
  parent_path = dirname(getcwd())
16
 
17
- temp_path = join(parent_path,'temp')
18
- if not exists(temp_path):
19
- makedirs(temp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- @app.post("/analyze")
22
- def pdf2images(file: UploadFile=File(...)):
23
- file_savePath = join(temp_path,file.filename)
24
 
25
- with open(file_savePath,'wb') as f:
26
- shutil.copyfileobj(file.file, f)
27
 
28
- inp = {
29
- 'images': ['./demo.png'],
30
- 'ocr_info_paths': ['./demo.json']
31
- }
32
- result = doc_VL_emb_pipeline(inp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- print('Results of VLDoc: ')
35
- for k, v in result.items():
36
- print(f'{k}: {v}')
37
- return result["img_embedding"],result["text_embedding"]
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ # Loading
4
+ import os
 
5
  from os import makedirs,getcwd
6
  from os.path import join,exists,dirname
7
+ from datasets import load_dataset
8
+ import torch
9
+ from tqdm import tqdm
10
+ from sentence_transformers import SentenceTransformer
11
+ import uuid
12
+ from qdrant_client import models, QdrantClient
13
+ from itertools import islice
14
 
15
  app = FastAPI()
16
+
17
+
18
+ FILEPATH_PATTERN = "structured_data_doc.parquet"
19
+ NUM_PROC = os.cpu_count()
20
  parent_path = dirname(getcwd())
21
 
22
+ temp_path = join(parent_path,'temp')
23
+ if not exists(temp_path ):
24
+ makedirs(temp_path )
25
+
26
+ # Determine device based on GPU availability
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ print(f"Using device: {device}")
29
+ # Load the desired model
30
+ model = SentenceTransformer(
31
+ 'sentence-transformers/all-MiniLM-L6-v2',
32
+ device=device
33
+ )
34
+
35
+ # Create function to upsert embeddings in batches
36
+ def batched(iterable, n):
37
+ iterator = iter(iterable)
38
+ while batch := list(islice(iterator, n)):
39
+ yield batch
40
+
41
+ batch_size = 100
42
+ # Create an in-memory Qdrant instance
43
+ client2 = QdrantClient(":memory:")
44
+
45
+ # Create a Qdrant collection for the embeddings
46
+ client2.create_collection(
47
+ collection_name="law",
48
+ vectors_config=models.VectorParams(
49
+ size=model.get_sentence_embedding_dimension(),
50
+ distance=models.Distance.COSINE,
51
+ ),
52
+ )
53
+
54
+
55
 
 
 
 
56
 
 
 
57
 
58
+
59
+ # Create function to generate embeddings (in batches) for a given dataset split
60
+ def generate_embeddings(dataset, batch_size=32):
61
+ embeddings = []
62
+
63
+ with tqdm(total=len(dataset), desc=f"Generating embeddings for dataset") as pbar:
64
+ for i in range(0, len(dataset), batch_size):
65
+ batch_sentences = dataset['content'][i:i+batch_size]
66
+ batch_embeddings = model.encode(batch_sentences)
67
+ embeddings.extend(batch_embeddings)
68
+ pbar.update(len(batch_sentences))
69
+
70
+ return embeddings
71
+
72
+ @app.post("/uploadfile/")
73
+ async def create_upload_file(file: UploadFile = File(...)):
74
+ # Here you can save the file and do other operations as needed
75
+ full_dataset = load_dataset("parquet",
76
+ data_files=FILEPATH_PATTERN,
77
+ split="train",
78
+ cache_path=temp_path,
79
+ keep_in_memory=True,
80
+ num_proc=NUM_PROC*2)
81
+ # Generate and append embeddings to the train split
82
+ law_embeddings = generate_embeddings(full_dataset)
83
+ full_dataset= full_dataset.add_column("embeddings", law_embeddings)
84
+
85
+ if not 'uuid' in full_dataset.column_names:
86
+ full_dataset = full_dataset.add_column('uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))])
87
+ # Upsert the embeddings in batches
88
+ for batch in batched(full_dataset, batch_size):
89
+ ids = [point.pop("uuid") for point in batch]
90
+ vectors = [point.pop("embeddings") for point in batch]
91
+
92
+ client2.upsert(
93
+ collection_name="law",
94
+ points=models.Batch(
95
+ ids=ids,
96
+ vectors=vectors,
97
+ payloads=batch,
98
+ ),
99
+ )
100
+ return {"filename": file.filename, "message": "Done"}
101
+
102
+ app.add_middleware(
103
+ CORSMiddleware,
104
+ allow_origins=["*"],
105
+ allow_credentials=True,
106
+ allow_methods=["*"],
107
+ allow_headers=["*"],
108
+ )
109
+
110
+ @app.get("/search")
111
+ def search(prompt: str):
112
+ # Let's see what senators are saying about immigration policy
113
+ hits = client2.search(
114
+ collection_name="law",
115
+ query_vector=model.encode(prompt).tolist(),
116
+ limit=5
117
+ )
118
+ for hit in hits:
119
+ print(hit.payload, "score:", hit.score)
120
+ return {'detail': 'hit.payload', 'score:': hit.score}
121
 
122
+ @app.get("/")
123
+ def api_home():
124
+ return {'detail': 'Welcome to FastAPI Qdrant importer!'}