Edit model card

This is an ONNX export of sentence-transformers/gtr-t5-xl, the model from the paper "Large Dual Encoders Are Generalizable Retrievers". Hugging Face currently does not support downloading ONNX models whose weights are stored in external data files. As a workaround, I use sentence-transformers (SBERT) and Optimum together to generate embeddings.

pip install onnx
pip install onnxruntime==1.10.0
# Quote the specifier: an unquoted ">" is shell output redirection, not a version constraint.
pip install "transformers>4.6.1"
pip install sentencepiece
pip install sentence-transformers
pip install optimum
pip install torch==1.9.0

Then you can use the model like this:


import os
from sentence_transformers.util import snapshot_download
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction
from sentence_transformers.models import Transformer, Pooling, Dense
import torch
from transformers.modeling_outputs import BaseModelOutput
import torch.nn.functional as F
import shutil

# Hugging Face Hub repository holding the exported ONNX model.
model_name = 'vamsibanda/sbert-onnx-gtr-t5-xl'
# Directory the snapshot is downloaded into.
cache_folder = './'
# Local folder for the snapshot: the repo id with "/" turned into "_", under cache_folder.
# NOTE(review): assumed to match the folder layout snapshot_download creates — confirm
# against the installed sentence-transformers version.
model_path = os.path.join(cache_folder, "_".join(model_name.split("/")))

# generate_embedding is defined once, after its mean_pooling helper, further below.
    
def download_onnx_model(model_name, cache_folder, model_path, force_download = False):
    """Fetch the model snapshot from the Hub unless a local copy already exists.

    Args:
        model_name: Hub repository id passed to ``snapshot_download``.
        cache_folder: Directory given to ``snapshot_download`` as ``cache_dir``.
        model_path: Local folder checked for an existing copy (and removed when
            ``force_download`` is set). NOTE(review): presumed to be the folder
            ``snapshot_download`` writes into; this function does not verify that.
        force_download: When truthy, delete an existing ``model_path`` and download again.

    Returns:
        None in every case.
    """
    have_local_copy = os.path.exists(model_path)
    if have_local_copy and not force_download:
        # Reuse what is already on disk.
        return
    if have_local_copy:
        # A fresh download was requested: drop the old copy first.
        shutil.rmtree(model_path)
    snapshot_download(model_name, cache_dir=cache_folder, library_name='sentence-transformers')
      
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions the attention mask keeps.

    Args:
        model_output: Indexable model output whose element ``[0]`` holds the token
            embeddings. The mask is broadcast to that tensor's full size, so it is
            expected to be ``(batch, seq_len, hidden)``.
        attention_mask: ``(batch, seq_len)`` tensor, nonzero for real tokens and
            0 for padding.

    Returns:
        ``(batch, hidden)`` tensor of masked means. The divisor is clamped to at
        least 1e-9, so a row with no kept tokens yields zeros instead of NaN.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * weights).sum(dim=1)
    kept_count = weights.sum(dim=1).clamp(min=1e-9)
    return masked_total / kept_count

def generate_embedding(text):
    """Embed *text* with the ONNX GTR-T5 encoder and the sentence-transformers Dense head.

    Pipeline: tokenize -> ONNX encoder -> attention-masked mean pooling ->
    Dense projection -> L2 normalisation.

    Relies on the module-level ``tokenizer``, ``model`` and ``dense_layer`` objects
    and on the ``mean_pooling`` helper defined in this file.

    Args:
        text: The string to embed.

    Returns:
        The unit-length embedding of *text* as a list of floats.
    """
    token = tokenizer(text, return_tensors='pt')
    # Inference only: without no_grad, the torch Dense layer would build an
    # autograd graph (its weights require grad) that is never used.
    with torch.no_grad():
        embedding = model(input_ids=token['input_ids'], attention_mask=token['attention_mask'])
        embedding = mean_pooling(embedding, token['attention_mask'])
        embedding = dense_layer.forward({'sentence_embedding': embedding})
        embedding = F.normalize(embedding['sentence_embedding'], p=2, dim=1)
    # Batch dimension is 1 for a single string; return that one row.
    return embedding.tolist()[0]
                    
     
# Fetch the model snapshot unless model_path already exists (download_onnx_model returns None).
_ = download_onnx_model(model_name, cache_folder, model_path)
# Tokenizer loaded from the local snapshot.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# ONNX encoder run through optimum's onnxruntime wrapper, loaded from the local snapshot.
model = ORTModelForFeatureExtraction.from_pretrained(model_path, force_download=False)
# NOTE(review): pooling_layer is not used by generate_embedding, which calls the
# hand-written mean_pooling instead; it is kept only as a module-level name.
pooling_layer = Pooling.load(f"{model_path}/1_Pooling")
# Dense head applied by generate_embedding after pooling.
dense_layer = Dense.load(f"{model_path}/2_Dense")

# Example call. The returned list is discarded here; it is only displayed when run
# interactively (e.g. as the last expression of a notebook cell).
generate_embedding('That is a happy person')
                    
Downloads last month
9
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.