Use BAAI/bge-base-en-v1.5 as the query embedding model
Browse files
app.py
CHANGED
@@ -5,8 +5,8 @@ import torch
|
|
5 |
from transformers import AutoModel, AutoTokenizer
|
6 |
import meilisearch
|
7 |
|
8 |
-
tokenizer = AutoTokenizer.from_pretrained('
|
9 |
-
model = AutoModel.from_pretrained('
|
10 |
model.eval()
|
11 |
|
12 |
cuda_available = torch.cuda.is_available()
|
@@ -23,16 +23,17 @@ def search_embeddings(query_text):
|
|
23 |
# step1: tokenize the query
|
24 |
with torch.no_grad():
|
25 |
# Compute token embeddings
|
26 |
-
|
|
|
27 |
# normalize embeddings
|
28 |
-
|
29 |
-
|
30 |
elapsed_time_embedding = time.time() - start_time_embedding
|
31 |
|
32 |
# step2: search meilisearch
|
33 |
start_time_meilisearch = time.time()
|
34 |
response = meilisearch_index.search(
|
35 |
-
"", opt_params={"vector":
|
36 |
)
|
37 |
elapsed_time_meilisearch = time.time() - start_time_meilisearch
|
38 |
hits = response["hits"]
|
|
|
5 |
from transformers import AutoModel, AutoTokenizer
|
6 |
import meilisearch
|
7 |
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-base-en-v1.5')
|
9 |
+
model = AutoModel.from_pretrained('BAAI/bge-base-en-v1.5')
|
10 |
model.eval()
|
11 |
|
12 |
cuda_available = torch.cuda.is_available()
|
|
|
23 |
# step1: tokenize the query
|
24 |
with torch.no_grad():
|
25 |
# Compute token embeddings
|
26 |
+
model_output = model(**query_tokens)
|
27 |
+
sentence_embeddings = model_output[0][:, 0]
|
28 |
# normalize embeddings
|
29 |
+
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
|
30 |
+
sentence_embeddings_list = sentence_embeddings[0].tolist()
|
31 |
elapsed_time_embedding = time.time() - start_time_embedding
|
32 |
|
33 |
# step2: search meilisearch
|
34 |
start_time_meilisearch = time.time()
|
35 |
response = meilisearch_index.search(
|
36 |
+
"", opt_params={"vector": sentence_embeddings_list, "hybrid": {"semanticRatio": 1.0}, "limit": 5, "attributesToRetrieve": ["text", "source", "library"]}
|
37 |
)
|
38 |
elapsed_time_meilisearch = time.time() - start_time_meilisearch
|
39 |
hits = response["hits"]
|