fixed GLU implementation, added conversion of layer norms
- convert_v2_weights.py  +19 -1
- mlp.py  +3 -2
convert_v2_weights.py CHANGED

@@ -1,6 +1,6 @@
 import re
 from collections import OrderedDict
-from transformers import AutoModel
+from transformers import AutoModel, AutoTokenizer
 from .configuration_bert import JinaBertConfig
 import torch
 from .modeling_bert import BertModel
@@ -115,6 +115,12 @@ def remap_state_dict(state_dict, config: JinaBertConfig):
         decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
     )
 
+    # LayerNorm
+    def key_mapping_layernorm(key):
+        return re.sub(r'^encoder.layers.(\d+).mlp.layernorm.(weight|bias)', r"encoder.layers.\1.norm2.\2", key)
+
+    state_dict = OrderedDict((key_mapping_layernorm(k), v) for k, v in state_dict.items())
+
     return state_dict
 
 
@@ -124,3 +130,15 @@ state_dict = v2_model.state_dict()
 new_state_dict = remap_state_dict(state_dict, config)
 flash_model = BertModel(config)
 flash_model.load_state_dict(new_state_dict)
+
+tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
+inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
+v2_model.eval()
+flash_model.eval()
+v2_model = v2_model.to('cuda', torch.float16)
+flash_model = flash_model.to('cuda', torch.float16)
+output_v2 = v2_model(**inp)
+output_flash = flash_model(**inp)
+x = output_v2.last_hidden_state
+y = output_flash.last_hidden_state
+print(torch.abs(x - y))
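
For reference, the layer-norm key remapping added above can be exercised in isolation. The regex is taken verbatim from the diff; the toy state dict below (every key other than the layernorm ones, and the tensor shapes) is made up purely for illustration:

import re
from collections import OrderedDict

import torch

def key_mapping_layernorm(key):
    # same substitution as in remap_state_dict: the v2 checkpoint keeps the MLP
    # layer norm under mlp.layernorm, the flash BertModel expects it under norm2
    return re.sub(r'^encoder.layers.(\d+).mlp.layernorm.(weight|bias)',
                  r"encoder.layers.\1.norm2.\2", key)

# toy state dict, for illustration only
state_dict = OrderedDict([
    ("encoder.layers.0.mlp.layernorm.weight", torch.ones(768)),
    ("encoder.layers.0.mlp.layernorm.bias", torch.zeros(768)),
    ("encoder.layers.0.mlp.gated_layers.weight", torch.empty(0)),
])

remapped = OrderedDict((key_mapping_layernorm(k), v) for k, v in state_dict.items())
print(list(remapped))
# ['encoder.layers.0.norm2.weight', 'encoder.layers.0.norm2.bias',
#  'encoder.layers.0.mlp.gated_layers.weight']

The comparison at the end of the conversion script prints the element-wise absolute difference of the two fp16 outputs; if a pass/fail check is preferred, something like torch.allclose(x, y, atol=1e-2) could be used instead, with the tolerance chosen loosely for float16 (the value is an assumption, not taken from the diff).
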
mlp.py CHANGED

@@ -37,6 +37,7 @@ class GLUMLP(nn.Module):
         hidden_dropout_prob=0.1
     ):
         super().__init__()
+        self.hidden_features = hidden_features
         self.gated_layers = nn.Linear(
             in_features, hidden_features * 2, bias=False
         )
@@ -57,8 +58,8 @@
         residual_connection = hidden_states
         # compute the activation
         hidden_states = self.gated_layers(hidden_states)
-        gated = hidden_states[:,
-        non_gated = hidden_states[:,
+        gated = hidden_states[:, : self.hidden_features]
+        non_gated = hidden_states[:, self.hidden_features :]
         hidden_states = self.act(gated) * non_gated
         hidden_states = self.dropout(hidden_states)
         # multiply by the second matrix
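
As a sanity check of the fixed split, here is a minimal, self-contained sketch of the gating logic. It is not the full GLUMLP: the activation is assumed to be GELU, dropout and residual/layer-norm handling are omitted, and the output projection is named wo as a placeholder; only the slicing mirrors the change above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniGLU(nn.Module):
    """Stripped-down gated MLP; only the gate/value split follows mlp.py."""

    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        self.hidden_features = hidden_features
        # a single matmul produces both halves: the gate and the value
        self.gated_layers = nn.Linear(in_features, hidden_features * 2, bias=False)
        self.wo = nn.Linear(hidden_features, in_features, bias=False)  # placeholder name

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.gated_layers(hidden_states)
        # the fix: split at self.hidden_features along the feature dimension
        gated = hidden_states[:, : self.hidden_features]
        non_gated = hidden_states[:, self.hidden_features :]
        return self.wo(F.gelu(gated) * non_gated)

x = torch.randn(4, 16)            # (tokens, features), matching the 2-D slicing above
print(MiniGLU(16, 32)(x).shape)   # torch.Size([4, 16])
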