mishig's picture
mishig HF staff
Create app.py
38585cf verified
raw
history blame
2.38 kB
import time
import os
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import meilisearch
# Meilisearch credentials are checked first so a misconfigured deployment fails
# immediately, before the (slow) model download below.
_meilisearch_key = os.environ.get("MEILISEARCH_KEY")
if not _meilisearch_key:
    raise RuntimeError(
        "MEILISEARCH_KEY environment variable is not set; "
        "it is required to query the Meilisearch index."
    )

# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.
MODEL_NAME = 'Snowflake/snowflake-arctic-embed-m'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# add_pooling_layer=False: the query vector is pooled manually from the first
# token of the model output when embedding queries.
model = AutoModel.from_pretrained(MODEL_NAME, add_pooling_layer=False)
model.eval()  # inference only (disables dropout)

cuda_available = torch.cuda.is_available()
# NOTE(review): this flag is only reported; nothing in this file moves the model to GPU.
print(f"CUDA available: {cuda_available}")

meilisearch_client = meilisearch.Client("https://edge.meilisearch.com", _meilisearch_key)
meilisearch_index_name = "docs-embed"
meilisearch_index = meilisearch_client.index(meilisearch_index_name)
def _embed_query(query_text):
    """Return the L2-normalised embedding of *query_text* as a list of floats.

    The text is prefixed with a fixed retrieval instruction, tokenized
    (truncated to 512 tokens), run through ``model`` without gradient
    tracking, and pooled by taking the first token's vector.
    """
    query_prefix = 'Represent this sentence for searching code documentation: '
    # step1: tokenize the query
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
    )
    # Keep the inputs on the same device as the model (a no-op while it stays on CPU).
    query_tokens = query_tokens.to(model.device)
    with torch.no_grad():
        # First-token pooling over the model's first output.
        query_embeddings = model(**query_tokens)[0][:, 0]
        # Unit-length vectors, so dot product equals cosine similarity.
        query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
    # Single query -> take row 0 and convert to plain floats for the JSON request.
    return query_embeddings[0].tolist()


def _format_results(hits, elapsed_time_embedding, elapsed_time_meilisearch):
    """Render timing stats followed by each hit's text and source link as Markdown.

    Missing fields in a hit no longer abort the whole response: a hit without
    ``text`` renders as empty text, and the source link is omitted unless both
    ``source`` and ``library`` are present.
    """
    parts = [
        f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\n"
        f"meilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
    ]
    for hit in hits:
        parts.append(hit.get("text") or "")
        source, library = hit.get("source"), hit.get("library")
        if source is not None and library is not None:
            parts.append(f"\n\n[source](https://huggingface.co/docs/{library}/{source})")
        parts.append("\n\n---\n\n")
    return "".join(parts)


def search_embeddings(query_text):
    """Embed *query_text*, run a vector search on Meilisearch, and return Markdown.

    Args:
        query_text: Free-text search query entered by the user.

    Returns:
        A Markdown string with embedding/search timings followed by up to 5
        hits (text plus a link into https://huggingface.co/docs).
    """
    start = time.perf_counter()
    query_vector = _embed_query(query_text)
    elapsed_time_embedding = time.perf_counter() - start

    # step2: search meilisearch. The text query is empty and semanticRatio is 1.0,
    # so ranking relies on the supplied vector.
    start = time.perf_counter()
    response = meilisearch_index.search(
        "",
        opt_params={
            "vector": query_vector,
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": ["text", "source", "library"],
        },
    )
    elapsed_time_meilisearch = time.perf_counter() - start

    # step3: present the results in markdown
    return _format_results(response["hits"], elapsed_time_embedding, elapsed_time_meilisearch)
# Gradio UI: a multi-line text box whose value is passed to search_embeddings;
# the returned string is rendered as Markdown.
demo = gr.Interface(
    fn=search_embeddings,
    inputs=gr.Textbox(label="enter your query", placeholder="Type Markdown here...", lines=10),
    outputs=gr.Markdown(),
    title="HF Docs Embeddings Explorer",
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour of
    # flagging_mode — confirm against the pinned Gradio version before changing.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()