|
--- |
|
license: mit |
|
language: |
|
- en |
|
- ru |
|
- ar |
|
- zh |
|
- fr |
|
- de |
|
- it |
|
- ja |
|
- ko |
|
- nl |
|
- pl |
|
- pt |
|
- es |
|
- th |
|
- tr |
|
library_name: sentence-transformers |
|
pipeline_tag: feature-extraction |
|
tags: |
|
- mteb |
|
- Sentence Transformers |
|
- sentence-similarity |
|
- arxiv:1803.11175 |
|
- arxiv:1907.04307 |
|
--- |
|
|
|
# Convert MUSE from TensorFlow to PyTorch |
|
|
|
This repository contains code to use the mUSE (Multilingual Universal Sentence Encoder) transformer model from [TF Hub](https://www.kaggle.com/models/google/universal-sentence-encoder/tensorFlow2/multilingual-large) with **PyTorch**.
|
|
|
> [!IMPORTANT] |
|
> **The PyTorch model can be used not only for inference, but also for additional training and fine-tuning!** |
|
|
|
Read more about the project: [GitHub](https://github.com/dayyass/muse_tf2pt/tree/main). |
|
|
|
# Usage |
|
|
|
The model is available in [HF Models](https://huggingface.co/dayyass/universal-sentence-encoder-multilingual-large-3-pytorch/tree/main) and can be used directly through `torch` (*currently without native support from the `transformers` library*).
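
The weights can be fetched with the `huggingface_hub` client. A minimal sketch: the repository id comes from the link above and `model.pt` is the filename used in the initialization snippet below (adjust if the repository layout differs):

```python
from huggingface_hub import hf_hub_download

# Download the converted PyTorch weights from the Hugging Face Hub.
# The repo_id comes from the model page linked above; "model.pt" is the
# filename referenced in the initialization snippet below.
path_to_pt_model = hf_hub_download(
    repo_id="dayyass/universal-sentence-encoder-multilingual-large-3-pytorch",
    filename="model.pt",
)
```

The returned `path_to_pt_model` points to the locally cached file and can be used in place of `PATH_TO_PT_MODEL` in the snippet below.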
|
|
|
Model initialization and usage code: |
|
```python
import torch
from functools import partial

from architecture import MUSE
from tokenizer import get_tokenizer, tokenize


PATH_TO_PT_MODEL = "model.pt"
PATH_TO_TF_MODEL = "universal-sentence-encoder-multilingual-large-3"

# The tokenizer is built from the original TF Hub checkpoint (see the note below).
tokenizer = get_tokenizer(PATH_TO_TF_MODEL)
tokenize = partial(tokenize, tokenizer=tokenizer)

# Instantiate the PyTorch architecture and load the converted weights.
model_torch = MUSE(
    num_embeddings=128010,
    embedding_dim=512,
    d_model=512,
    num_heads=8,
)
model_torch.load_state_dict(
    torch.load(PATH_TO_PT_MODEL)
)

# Encode a sentence into a fixed-size embedding.
sentence = "Hello, world!"
res = model_torch(tokenize(sentence))
```
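
The resulting embeddings can be compared directly, e.g. with cosine similarity. A minimal sketch, assuming the model returns one fixed-size 512-dimensional embedding per input sentence (as in the original mUSE):

```python
import torch.nn.functional as F

# Encode two sentences (mUSE is multilingual, so mixed languages are fine).
emb_en = model_torch(tokenize("Hello, world!"))
emb_ru = model_torch(tokenize("Привет, мир!"))

# Cosine similarity between the two sentence embeddings.
similarity = F.cosine_similarity(emb_en.flatten(), emb_ru.flatten(), dim=0)
print(similarity.item())
```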
|
> [!NOTE] |
|
> Tokenization currently relies on the checkpoint of the original TF Hub model, which is why it is loaded (via `PATH_TO_TF_MODEL`) in the initialization code above.
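
As noted above, the converted model is a regular `torch.nn.Module`, so it can also be fine-tuned with a standard PyTorch training loop. A minimal sketch with an illustrative cosine-embedding objective; the pairs, labels, and hyperparameters below are placeholders, not part of the repository:

```python
import torch

# Hypothetical training pairs: (sentence_a, sentence_b, label),
# where label is 1.0 for similar pairs and -1.0 for dissimilar ones.
pairs = [
    ("A cat sits on the mat.", "A cat is sitting on a mat.", 1.0),
    ("A cat sits on the mat.", "Stock prices fell sharply.", -1.0),
]

optimizer = torch.optim.AdamW(model_torch.parameters(), lr=2e-5)
criterion = torch.nn.CosineEmbeddingLoss()

model_torch.train()
for sent_a, sent_b, label in pairs:
    optimizer.zero_grad()
    emb_a = model_torch(tokenize(sent_a)).reshape(1, -1)
    emb_b = model_torch(tokenize(sent_b)).reshape(1, -1)
    loss = criterion(emb_a, emb_b, torch.tensor([label]))
    loss.backward()
    optimizer.step()
```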