feat: when converting v2, save the converted model weights and print the config
Browse files - convert_v2_weights.py +8 -1
convert_v2_weights.py
CHANGED
@@ -131,6 +131,12 @@ new_state_dict = remap_state_dict(state_dict, config)
|
|
131 |
flash_model = BertModel(config)
|
132 |
flash_model.load_state_dict(new_state_dict)
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
|
135 |
inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
|
136 |
v2_model.eval()
|
@@ -141,4 +147,5 @@ output_v2 = v2_model(**inp)
|
|
141 |
output_flash = flash_model(**inp)
|
142 |
x = output_v2.last_hidden_state
|
143 |
y = output_flash.last_hidden_state
|
144 |
-
print(torch.abs(x - y))
|
|
|
|
131 |
flash_model = BertModel(config)
|
132 |
flash_model.load_state_dict(new_state_dict)
|
133 |
|
134 |
+
|
135 |
+
torch.save(new_state_dict, 'converted_weights.bin')
|
136 |
+
print(config.to_json_string())
|
137 |
+
|
138 |
+
|
139 |
+
"""
|
140 |
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
|
141 |
inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
|
142 |
v2_model.eval()
|
|
|
147 |
output_flash = flash_model(**inp)
|
148 |
x = output_v2.last_hidden_state
|
149 |
y = output_flash.last_hidden_state
|
150 |
+
print(torch.abs(x - y))
|
151 |
+
"""
|