Maxime committed on
Commit 3bd9528
1 Parent(s): 2aa1f71

add noisy embedding (#721)


* add noisy embedding

* fix format

* Update README.md

* Update README.md

* linter issues

* caseus fixes

---------

Co-authored-by: Maxime <maxime@nope.no>

README.md CHANGED
@@ -672,6 +672,11 @@ adam_epsilon:
  # Gradient clipping max norm
  max_grad_norm:

+ # Augmentation techniques
+ # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
+ # currently only supported on Llama and Mistral
+ noisy_embedding_alpha:
+
  # Whether to bettertransformers
  flash_optimum:
  # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
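For reference, noisy_embedding_alpha is the NEFT alpha from the paper: during training, each batch of token embeddings gets uniform noise drawn from [-alpha/sqrt(L*d), alpha/sqrt(L*d)], where L is the sequence length and d the hidden size. A minimal standalone sketch of that scaling rule (the helper name is illustrative, not part of this commit):

import torch

def neft_noise(embeds: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    # embeds: (batch, seq_len, hidden_dim) token embeddings
    seq_len, hidden_dim = embeds.size(1), embeds.size(2)
    mag_norm = alpha / (seq_len * hidden_dim) ** 0.5
    # uniform noise in [-mag_norm, mag_norm], same shape as the embeddings
    return embeds + torch.zeros_like(embeds).uniform_(-mag_norm, mag_norm)

Larger alpha means stronger regularization; the paper's default of 5 is the value the README comment suggests.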
src/axolotl/monkeypatch/llama_embeddings_hijack.py ADDED
@@ -0,0 +1,40 @@
+ """
+ patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+ """
+
+ import torch
+ import transformers.models.llama.modeling_llama
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
+     # pylint: disable=duplicate-code
+     def noised_embed(orig_embed, noise_alpha, model):
+         def new_func(input_ids):
+             # during training, we add noise to the embedding
+             # during generation, we don't add noise to the embedding
+             if model.training:
+                 embed_init = orig_embed(input_ids)
+                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                 mag_norm = noise_alpha / torch.sqrt(dims)
+                 return embed_init + torch.zeros_like(embed_init).uniform_(
+                     -mag_norm, mag_norm
+                 )
+             return orig_embed(input_ids)
+
+         return new_func
+
+     def post_init(orig_post_init):
+         def new_func(self):
+             orig_post_init(self)
+             self.embed_tokens.forward = noised_embed(
+                 self.embed_tokens.forward, noise_alpha, self
+             )
+
+         return new_func
+
+     transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
+         transformers.models.llama.modeling_llama.LlamaModel.post_init
+     )
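A usage sketch, assuming only what the file above defines: the patch wraps LlamaModel.post_init, so it must be called before the model is constructed, and noise is only drawn while model.training is True. The tiny config below is illustrative, just enough to instantiate a model:

import torch
from transformers import LlamaConfig, LlamaModel

from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

# patch first: post_init runs at the end of LlamaModel.__init__
replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)

config = LlamaConfig(
    vocab_size=128, hidden_size=16, intermediate_size=32,
    num_hidden_layers=1, num_attention_heads=2,
)
model = LlamaModel(config)
input_ids = torch.randint(0, 128, (1, 8))

model.train()
a = model.embed_tokens(input_ids)  # noisy: fresh uniform noise on each call
b = model.embed_tokens(input_ids)
assert not torch.equal(a, b)

model.eval()
c = model.embed_tokens(input_ids)  # plain embedding lookup, no noise
d = model.embed_tokens(input_ids)
assert torch.equal(c, d)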
src/axolotl/monkeypatch/mistral_embeddings_hijack.py ADDED
@@ -0,0 +1,40 @@
+ """
+ patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+ """
+
+ import torch
+ import transformers.models.mistral.modeling_mistral
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
+     # pylint: disable=duplicate-code
+     def noised_embed(orig_embed, noise_alpha, model):
+         def new_func(input_ids):
+             # during training, we add noise to the embedding
+             # during generation, we don't add noise to the embedding
+             if model.training:
+                 embed_init = orig_embed(input_ids)
+                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                 mag_norm = noise_alpha / torch.sqrt(dims)
+                 return embed_init + torch.zeros_like(embed_init).uniform_(
+                     -mag_norm, mag_norm
+                 )
+             return orig_embed(input_ids)
+
+         return new_func
+
+     def post_init(orig_post_init):
+         def new_func(self):
+             orig_post_init(self)
+             self.embed_tokens.forward = noised_embed(
+                 self.embed_tokens.forward, noise_alpha, self
+             )
+
+         return new_func
+
+     transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
+         transformers.models.mistral.modeling_mistral.MistralModel.post_init
+     )
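The Mistral patch is identical apart from the module it targets. A quick sanity check of the noise bound, under the same illustrative tiny-config assumption as above: the deviation from the raw embedding lookup should never exceed noise_alpha / sqrt(seq_len * hidden_size).

import torch
from transformers import MistralConfig, MistralModel

from axolotl.monkeypatch.mistral_embeddings_hijack import (
    replace_mistral_embeddings_with_uniform_distribution,
)

replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5)
model = MistralModel(
    MistralConfig(
        vocab_size=128, hidden_size=16, intermediate_size=32,
        num_hidden_layers=1, num_attention_heads=2, num_key_value_heads=2,
    )
)
model.train()

input_ids = torch.randint(0, 128, (2, 16))
with torch.no_grad():
    clean = model.embed_tokens.weight[input_ids]  # raw lookup, no noise
    noisy = model.embed_tokens(input_ids)         # patched forward adds noise

bound = 5 / (16 * 16) ** 0.5  # noise_alpha / sqrt(seq_len * hidden_size)
assert (noisy - clean).abs().max() <= bound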
src/axolotl/utils/models.py CHANGED
@@ -180,6 +180,26 @@ def load_model(
  LOG.info("patching with flash attention")
  replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

+ if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+     from axolotl.monkeypatch.llama_embeddings_hijack import (
+         replace_llama_embeddings_with_uniform_distribution,
+     )
+
+     LOG.info("patching with noisy embeddings")
+     replace_llama_embeddings_with_uniform_distribution(
+         noise_alpha=cfg.noisy_embedding_alpha
+     )
+
+ if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
+     from axolotl.monkeypatch.mistral_embeddings_hijack import (
+         replace_mistral_embeddings_with_uniform_distribution,
+     )
+
+     LOG.info("patching with noisy embeddings")
+     replace_mistral_embeddings_with_uniform_distribution(
+         noise_alpha=cfg.noisy_embedding_alpha
+     )
+
  if cfg.is_llama_derived_model and cfg.xpos_rope:
      from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
          replace_llama_rope_with_xpos_rope,
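Ordering note on the load_model change: both hijacks rewrite post_init, which fires while the model is being built, so the patch has to be applied before the model is constructed later in load_model. A rough equivalent outside axolotl (the model id and dtype are illustrative):

import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

# 1) patch first, so LlamaModel.post_init is already wrapped ...
replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)

# 2) ... then build the model; post_init hooks embed_tokens during construction
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)

model.train()  # noise is applied only while training; generation is unaffected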