---
license: mit
language:
- en
- ru
- ar
- zh
- fr
- de
- it
- ja
- ko
- nl
- pl
- pt
- es
- th
- tr
library_name: sentence-transformers
pipeline_tag: feature-extraction
tags:
- mteb
- Sentence Transformers
- sentence-similarity
- arxiv:1803.11175
- arxiv:1907.04307
---

# Convert MUSE from TensorFlow to PyTorch

This repository contains code for using the mUSE (Multilingual Universal Sentence Encoder) transformer model from [TF Hub](https://www.kaggle.com/models/google/universal-sentence-encoder/tensorFlow2/multilingual-large) with **PyTorch**.

> [!IMPORTANT]
> **The PyTorch model can be used not only for inference, but also for additional training and fine-tuning!**
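To illustrate the fine-tuning point, here is a minimal sketch of a contrastive training step on sentence pairs using `torch.nn.CosineEmbeddingLoss`. It is not the project's training code. `ToyEncoder` and `fine_tune_step` are hypothetical names. The toy encoder stands in for `MUSE` so the snippet runs without the checkpoint, and the sketch assumes the model maps a batch of token ids to one 512-dimensional embedding per sentence.

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for MUSE: token ids (batch, seq_len) -> embeddings (batch, 512)."""

    def __init__(self, vocab_size: int = 1000, dim: int = 512) -> None:
        super().__init__()
        # Averages the token embeddings of each row into a single sentence vector
        self.embedding = nn.EmbeddingBag(vocab_size, dim, mode="mean")

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.embedding(token_ids)


def fine_tune_step(model, optimizer, left_ids, right_ids, labels) -> float:
    """One optimization step on sentence pairs; labels are 1 (similar) or -1 (dissimilar)."""
    model.train()
    optimizer.zero_grad()
    # Pulls similar pairs together, pushes dissimilar pairs below zero cosine similarity
    loss = nn.CosineEmbeddingLoss(margin=0.0)(model(left_ids), model(right_ids), labels)
    loss.backward()
    optimizer.step()
    return loss.item()  # loss measured before this step's update


# Toy data: 4 random "sentence" pairs of 6 token ids each
model = ToyEncoder()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
left = torch.randint(0, 1000, (4, 6))
right = torch.randint(0, 1000, (4, 6))
labels = torch.tensor([1.0, -1.0, 1.0, -1.0])

losses = [fine_tune_step(model, optimizer, left, right, labels) for _ in range(20)]
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

To use the real model, you would swap in the loaded `MUSE` instance and batch tokenized sentence pairs in whatever format `tokenize` produces.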

Read more about the project: [GitHub](https://github.com/dayyass/muse_tf2pt/tree/main).

# Usage

The model is available on [HF Models](https://huggingface.co/dayyass/universal-sentence-encoder-multilingual-large-3-pytorch/tree/main) and can be loaded directly with `torch` (*currently without native support in the `transformers` library*).

Model initialization and usage code:
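```python
import torch
from functools import partial

# Project modules: model architecture and tokenizer helpers
from architecture import MUSE
from tokenizer import get_tokenizer, tokenize

PATH_TO_PT_MODEL = "model.pt"  # converted PyTorch weights
PATH_TO_TF_MODEL = "universal-sentence-encoder-multilingual-large-3"  # original TF Hub checkpoint, used for tokenization

# Build the tokenizer from the TF Hub checkpoint and bind it to `tokenize`
tokenizer = get_tokenizer(PATH_TO_TF_MODEL)
tokenize = partial(tokenize, tokenizer=tokenizer)

# Instantiate the architecture with the hyperparameters of the converted checkpoint
model_torch = MUSE(
    num_embeddings=128010,
    embedding_dim=512,
    d_model=512,
    num_heads=8,
)
model_torch.load_state_dict(
    torch.load(PATH_TO_PT_MODEL, map_location="cpu")
)
model_torch.eval()  # inference mode; call .train() before fine-tuning

sentence = "Hello, world!"
with torch.no_grad():  # no gradients needed for inference
    res = model_torch(tokenize(sentence))
```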

> [!NOTE]
> Tokenization currently relies on the checkpoint of the original TF Hub model, which is why the code above loads it from `PATH_TO_TF_MODEL`.
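The model is tagged for sentence similarity, and cosine similarity is a common way to compare the sentence embeddings it produces. The sketch below ranks candidate embeddings against a query with `torch.nn.functional.cosine_similarity`. The helper `rank_by_similarity` is hypothetical. It assumes each embedding is a 1-D tensor, for example one row of `res` if the model returns one embedding per input sentence. Toy 3-dimensional vectors stand in for real 512-dimensional outputs.

```python
import torch
import torch.nn.functional as F


def rank_by_similarity(query: torch.Tensor, candidates: torch.Tensor) -> list[tuple[int, float]]:
    """Return (candidate index, cosine similarity) pairs, most similar first.

    query: (dim,) embedding; candidates: (n, dim) embeddings.
    """
    # Broadcast the (1, dim) query against every (dim,) candidate row -> shape (n,)
    scores = F.cosine_similarity(query.unsqueeze(0), candidates, dim=1)
    order = torch.argsort(scores, descending=True)
    return [(int(i), float(scores[i])) for i in order]


query = torch.tensor([1.0, 0.0, 0.0])
candidates = torch.tensor([
    [0.0, 1.0, 0.0],   # orthogonal to the query
    [2.0, 0.0, 0.0],   # same direction, different length
    [-1.0, 0.0, 0.0],  # opposite direction
])
ranking = rank_by_similarity(query, candidates)
print(ranking)
```

Cosine similarity ignores vector length, so the candidate `[2.0, 0.0, 0.0]` ranks first with the same score as an exact match.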