Upload ONNX weights

#1
Opened by Xenova (HF staff)
import torch
from sentence_transformers import SentenceTransformer

class WrappedModel(torch.nn.Module):
    """Adapt an ``embedding(indices, offsets)`` layer to padded-batch inputs.

    The wrapped layer is called with a flat 1-D tensor of token ids plus the
    start offset of each sequence inside it. ``forward`` builds both from a
    padded ``input_ids`` / ``attention_mask`` pair.
    """

    def __init__(self, m):
        """Keep a reference to the embedding layer of ``m``.

        Args:
            m: Indexable model whose first element exposes an ``embedding``
                attribute. Presumably a ``torch.nn.EmbeddingBag`` - confirm
                against the model actually loaded by the caller.
        """
        super().__init__()
        self.embedding = m[0].embedding

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Embed a padded batch by dropping padding and bagging per row.

        Args:
            input_ids: Token ids; indexed with a 2-D mask and reduced along
                ``dim=1``, so it is expected to be (batch, sequence_length).
            attention_mask: Same shape as ``input_ids``; positions equal to
                1 are kept, everything else is treated as padding.

        Returns:
            Whatever ``self.embedding(indices, offsets)`` returns - one row
            per input sequence when the layer is an ``EmbeddingBag``.
        """
        # Boolean-mask indexing flattens row by row, keeping only the
        # positions where the mask is 1.
        indices = input_ids[attention_mask == 1]

        # Number of kept tokens in each sequence.
        lengths = attention_mask.sum(dim=1)

        # Exclusive prefix sum: where each sequence starts inside `indices`.
        # Deriving it from `lengths` (instead of concatenating a freshly
        # built `torch.tensor([0])`) keeps the result on the same device as
        # the inputs, so the module also works when they are not on the CPU.
        offsets = lengths.cumsum(dim=0) - lengths

        return self.embedding(indices, offsets)

# Dummy batch used both for the sanity check and as the example input for
# ONNX export: 3 sequences padded to length 4. Masked-out positions hold
# arbitrary ids (-1 / 0); the mask removes them before the embedding lookup.
shape = (3, 4)
input_ids = torch.tensor([1, 2, 3, 4, 5, 6, -1, -1, 1, 1, 1, 0]).view(shape)
attention_mask = torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0]).view(shape)

model_id = "sentence-transformers/static-similarity-mrl-multilingual-v1"
model = SentenceTransformer(model_id)
wrapped = WrappedModel(model)

# Test the forward pass before exporting, so a broken wrapper fails here
# with a clear error instead of somewhere inside the exporter.
with torch.no_grad():
    sample_output = wrapped(input_ids, attention_mask)
if sample_output.shape[0] != shape[0]:
    raise RuntimeError(
        f"expected {shape[0]} sentence embeddings, got {sample_output.shape[0]}"
    )

# Export the model. Batch size and sequence length are declared dynamic so
# the exported graph is not tied to the dummy batch's (3, 4) shape.
torch.onnx.export(
    wrapped,
    (input_ids, attention_mask),
    "model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask'],
    output_names=['sentence_embedding'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'sentence_embedding': {0: 'batch_size'},
    },
)
tomaarsen changed pull request status to merged

Sign up or log in to comment