""" | |
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914 | |
""" | |
import functools

import torch
import transformers.models.llama.modeling_llama
from transformers.utils import logging
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
    """Monkey-patch ``LlamaModel.post_init`` to add uniform noise to embeddings.

    After this call, every ``LlamaModel`` whose ``post_init`` runs gets its
    ``embed_tokens.forward`` wrapped. While the model is in training mode
    (``model.training`` is truthy) the embedding output is perturbed with
    noise drawn uniformly from ``[-m, m]``, where
    ``m = noise_alpha / sqrt(out.size(1) * out.size(2))`` and ``out`` is the
    embedding output. Outside training the embedding output is returned
    unchanged.

    Calling this function more than once replaces the previously installed
    patch (using the latest ``noise_alpha``) instead of stacking another
    noise wrapper on top of it.

    Args:
        noise_alpha: Scale factor for the noise magnitude. Defaults to 5.

    Note:
        The patch is applied to the class, so only models whose ``post_init``
        runs after this call have their embeddings wrapped.
    """
    # pylint: disable=duplicate-code
    llama_model_cls = transformers.models.llama.modeling_llama.LlamaModel

    def noised_embed(orig_embed, noise_alpha, model):
        """Wrap ``orig_embed`` so it adds noise while ``model`` is training."""

        def new_func(input_ids):
            embed_init = orig_embed(input_ids)
            # during training, we add noise to the embedding
            # during generation, we don't add noise to the embedding
            if not model.training:
                return embed_init
            # Uses dims 1 and 2 of the embedding output, so it must have at
            # least 3 dims -- presumably (batch, seq_len, hidden); TODO confirm.
            dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
            mag_norm = noise_alpha / torch.sqrt(dims)
            noise = torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
            return embed_init + noise

        return new_func

    def post_init(orig_post_init):
        """Wrap ``orig_post_init`` so it also installs the noisy embedding."""

        @functools.wraps(orig_post_init)
        def new_func(self):
            orig_post_init(self)
            self.embed_tokens.forward = noised_embed(
                self.embed_tokens.forward, noise_alpha, self
            )

        # Remember the unpatched function so a later call can unwrap this
        # wrapper instead of wrapping it a second time.
        new_func._noised_embed_orig_post_init = (  # pylint: disable=protected-access
            orig_post_init
        )
        return new_func

    current_post_init = llama_model_cls.post_init
    # If a previous call already patched post_init, start again from the
    # pristine function; otherwise each call would stack another noise layer.
    pristine_post_init = getattr(
        current_post_init, "_noised_embed_orig_post_init", current_post_init
    )
    llama_model_cls.post_init = post_init(pristine_post_init)