ai-forever
commited on
Commit
•
3248018
1
Parent(s):
bbe14ec
Update README.md
Browse files
README.md
CHANGED
@@ -4,55 +4,111 @@ language:
|
|
4 |
- ru
|
5 |
- en
|
6 |
tags:
|
7 |
-
-
|
8 |
-
-
|
9 |
---
|
10 |
|
11 |
-
#
|
12 |
-
|
13 |
-
Russian
|
14 |
-
|
15 |
-
For
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
```python
|
23 |
-
from transformers import AutoTokenizer, AutoModel
|
24 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
#You might to use two variants of mode for embeddings creation:
|
27 |
-
#CLS token embs or MEAN Pooling.
|
28 |
-
#You can choose embs pooling with best quality for your downstream tasks.
|
29 |
-
|
30 |
-
#Mean Pooling example - Take attention mask into account for correct averaging
|
31 |
-
def mean_pooling(model_output, attention_mask):
|
32 |
-
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
33 |
-
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
34 |
-
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
|
35 |
-
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
36 |
-
return sum_embeddings / sum_mask
|
37 |
-
|
38 |
-
#Sentences we want sentence embeddings for
|
39 |
-
sentences = ['Привет! Как твои дела?',
|
40 |
-
'А правда, что 42 твое любимое число?']
|
41 |
-
#Load AutoModel from huggingface model repository
|
42 |
tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
|
43 |
model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")
|
44 |
-
#Tokenize sentences
|
45 |
-
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
46 |
|
47 |
-
|
|
|
48 |
with torch.no_grad():
|
49 |
-
|
50 |
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
|
55 |
-
last_hidden_states = model_output[0]
|
56 |
-
sentence_cls_embeddings = last_hidden_states[:,0]
|
57 |
|
58 |
-
|
|
|
4 |
- ru
|
5 |
- en
|
6 |
tags:
|
7 |
+
- transformers
|
8 |
+
- sentence-transformers
|
9 |
---
|
10 |
|
11 |
+
# Model Card for ru-en-RoSBERTa
|
12 |
+
|
13 |
+
The ru-en-RoSBERTa is a general text embedding model for Russian. The model is based on [ruRoBERTa](https://huggingface.co/ai-forever/ruRoberta-large) and fine-tuned with ~4M pairs of supervised, synthetic and unsupervised data in Russian and English. Tokenizer supports some English tokens from [RoBERTa](https://huggingface.co/FacebookAI/roberta-large) tokenizer.
|
14 |
+
|
15 |
+
For more model details please refer to our [article](arxiv).
|
16 |
+
|
17 |
+
## Usage
|
18 |
+
|
19 |
+
The model can be used as is with prefixes. It is recommended to use CLS pooling. The choice of prefix and pooling depends on the task.
|
20 |
+
|
21 |
+
We use the following basic rules to choose a prefix:
|
22 |
+
- `"search_query: "` and `"search_document: "` prefixes are for answer or relevant paragraph retrieval
|
23 |
+
- `"classification: "` prefix is for symmetric paraphrasing related tasks (STS, NLI, Bitext Mining)
|
24 |
+
- `"clustering: "` prefix is for any tasks that rely on thematic features (topic classification, title-body retrieval)
|
25 |
+
|
26 |
+
To better tailor the model to your needs, you can fine-tune it with relevant high-quality Russian and English datasets.
|
27 |
+
|
28 |
+
Below are examples of texts encoding using the Transformers and SentenceTransformers libraries.
|
29 |
+
|
30 |
+
### Transformers
|
31 |
+
|
32 |
```python
|
|
|
33 |
import torch
|
34 |
+
import torch.nn.functional as F
|
35 |
+
from transformers import AutoTokenizer, AutoModel
|
36 |
+
|
37 |
+
|
38 |
+
def pool(hidden_state, mask, pooling_method="cls"):
|
39 |
+
if pooling_method == "mean":
|
40 |
+
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
|
41 |
+
d = mask.sum(axis=1, keepdim=True).float()
|
42 |
+
return s / d
|
43 |
+
elif pooling_method == "cls":
|
44 |
+
return hidden_state[:, 0]
|
45 |
+
|
46 |
+
inputs = [
|
47 |
+
#
|
48 |
+
"classification: Он нам и <unk> не нужон ваш Интернет!",
|
49 |
+
"clustering: В Ярославской области разрешили работу бань, но без посетителей",
|
50 |
+
"search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
|
51 |
+
|
52 |
+
#
|
53 |
+
"classification: What a time to be alive!",
|
54 |
+
"clustering: Ярославским баням разрешили работать без посетителей",
|
55 |
+
"search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
|
56 |
+
]
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
|
59 |
model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")
|
|
|
|
|
60 |
|
61 |
+
tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")
|
62 |
+
|
63 |
with torch.no_grad():
|
64 |
+
outputs = model(**tokenized_inputs)
|
65 |
|
66 |
+
embeddings = pool(
|
67 |
+
outputs.last_hidden_state,
|
68 |
+
tokenized_inputs["attention_mask"],
|
69 |
+
pooling_method="cls" # or try "mean"
|
70 |
+
)
|
71 |
+
|
72 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
73 |
+
|
74 |
+
sim_scores = embeddings[:3] @ embeddings[3:].T
|
75 |
+
print(sim_scores.diag().tolist())
|
76 |
+
# [0.4796873927116394, 0.9409002065658569, 0.7761015892028809]
|
77 |
+
```
|
78 |
+
|
79 |
+
### SentenceTransformers
|
80 |
+
|
81 |
+
```python
|
82 |
+
from sentence_transformers import SentenceTransformer
|
83 |
+
|
84 |
+
|
85 |
+
inputs = [
|
86 |
+
#
|
87 |
+
"classification: Он нам и <unk> не нужон ваш Интернет!",
|
88 |
+
"clustering: В Ярославской области разрешили работу бань, но без посетителей",
|
89 |
+
"search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
|
90 |
+
|
91 |
+
#
|
92 |
+
"classification: What a time to be alive!",
|
93 |
+
"clustering: Ярославским баням разрешили работать без посетителей",
|
94 |
+
"search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
|
95 |
+
]
|
96 |
+
|
97 |
+
# loads model with CLS pooling
|
98 |
+
model = SentenceTransformer("ai-forever/ru-en-RoSBERTa")
|
99 |
+
|
100 |
+
# embeddings are normalized by default
|
101 |
+
embeddings = model.encode(inputs, convert_to_tensor=True)
|
102 |
+
|
103 |
+
sim_scores = embeddings[:3] @ embeddings[3:].T
|
104 |
+
print(sim_scores.diag().tolist())
|
105 |
+
# [0.47968706488609314, 0.940900444984436, 0.7761018872261047]
|
106 |
+
```
|
107 |
+
|
108 |
+
## Citation
|
109 |
+
|
110 |
+
TODO
|
111 |
|
112 |
+
## Limitations
|
|
|
|
|
113 |
|
114 |
+
The model is designed to process texts in Russian, the quality in English is unknown. Maximum input text length is limited to 512 tokens.
|