|
import time |
|
import os |
|
import gradio as gr |
|
import torch |
|
from transformers import AutoModel, AutoTokenizer |
|
import meilisearch |
|
|
|
# Hugging Face Hub id of the embedding model; shared by the tokenizer and the
# encoder so the two can never drift apart.
_EMBEDDING_MODEL_ID = 'Snowflake/snowflake-arctic-embed-m'

# Load the tokenizer and the encoder once, at import time, so every request
# reuses the same in-memory instances.
tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_ID)
model = AutoModel.from_pretrained(_EMBEDDING_MODEL_ID, add_pooling_layer=False)
model.eval()  # switch the module to evaluation mode; this script only runs inference

# Startup diagnostic only: the result is printed, the model is not moved to
# another device anywhere in this block.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# Meilisearch handle used by the search function. The API key is read from the
# MEILISEARCH_KEY environment variable; a missing variable raises KeyError at
# import time.
meilisearch_index_name = "docs-embed"
meilisearch_client = meilisearch.Client("https://edge.meilisearch.com", os.environ["MEILISEARCH_KEY"])
meilisearch_index = meilisearch_client.index(meilisearch_index_name)
|
|
|
def search_embeddings(query_text):
    """Embed a query and render the nearest documentation chunks as Markdown.

    The query is prefixed with the retrieval instruction string, encoded with
    the module-level ``tokenizer`` and ``model``, L2-normalized, and sent to
    the module-level ``meilisearch_index`` as a vector search
    (``semanticRatio`` 1.0, at most 5 hits).

    Args:
        query_text: Free-text query entered by the user.

    Returns:
        A Markdown string. For a non-empty query: a timing header followed by
        one section per hit (the hit text and a link built from its
        ``library`` and ``source`` fields). For an empty or whitespace-only
        query: a short prompt asking for input.

    Raises:
        KeyError: If the search response has no ``"hits"`` entry or a hit
            lacks ``"text"``, ``"source"`` or ``"library"``.
    """
    if not query_text or not query_text.strip():
        # Nothing meaningful to embed: skip the model forward pass and the
        # network round-trip instead of searching for the bare prefix.
        return "Please enter a query."

    # perf_counter is monotonic, so the reported durations cannot be skewed
    # by system clock adjustments.
    embedding_started = time.perf_counter()
    query_prefix = 'Represent this sentence for searching code documentation: '
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
    )
    with torch.no_grad():
        # First model output, index 0 along its second axis
        # (presumably the [CLS] token position -- confirm against the model card).
        query_embeddings = model(**query_tokens)[0][:, 0]
    query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
    # A single query was encoded, so row 0 is the query vector.
    query_vector = query_embeddings[0].tolist()
    elapsed_time_embedding = time.perf_counter() - embedding_started

    search_started = time.perf_counter()
    response = meilisearch_index.search(
        "",
        opt_params={
            "vector": query_vector,
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": ["text", "source", "library"],
        },
    )
    elapsed_time_meilisearch = time.perf_counter() - search_started
    hits = response["hits"]

    header = (
        f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\n"
        f"meilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
    )
    sections = []
    for hit in hits:
        text, source, library = hit["text"], hit["source"], hit["library"]
        link = f"[source](https://huggingface.co/docs/{library}/{source})"
        sections.append(text + f"\n\n{link}\n\n---\n\n")
    return header + "".join(sections)
|
|
|
|
|
# Gradio UI: a multi-line query box wired to search_embeddings, whose Markdown
# return value is rendered directly. Flagging is disabled.
demo = gr.Interface(
    fn=search_embeddings,
    # The box takes a search query, so the placeholder says so.
    inputs=gr.Textbox(label="enter your query", placeholder="Type your query here...", lines=10),
    outputs=gr.Markdown(),
    title="HF Docs Embeddings Explorer",
    allow_flagging="never",
)

# Start the web server only when executed as a script, so importing this
# module does not launch the app.
if __name__ == "__main__":
    demo.launch()
|
|