tmm1 committed
Commit 72a6fe1
Parent: 5fe30b1

use flash_attn rmsnorm when available (#526)


* use flash_attn xentropy when available

* use flash_attn.ops.rms_norm when available

* log when xentropy is not found

* log how to install RMSNorm

* add quotes so pip install works

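For context on the first and third bullets: the xentropy patch sits just above the hunk shown below, and it follows the same guarded-import pattern as the new RMSNorm patch. A minimal sketch of that pattern (assuming flash-attn ships flash_attn.losses.cross_entropy.CrossEntropyLoss; this is an illustration, not the exact code from the file):

import logging
from functools import partial

import transformers

LOG = logging.getLogger(__name__)

try:
    # fused cross-entropy kernel from flash-attn's optional xentropy extension
    from flash_attn.losses.cross_entropy import CrossEntropyLoss

    LOG.info("patching with flash_attn.losses.cross_entropy")
    # swap the loss class referenced by modeling_llama for the fused kernel;
    # inplace_backward computes the gradient in place of the logits to save memory
    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
        CrossEntropyLoss, inplace_backward=True
    )
except ImportError:
    LOG.info("optimized flash-attention CrossEntropyLoss not found")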
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -58,7 +58,24 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
         )
     except ImportError:
         LOG.info(
-            "optimized flash-attention CrossEntropyLoss not found (run `pip install git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy`)"
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+        )
+
+    try:
+        from flash_attn.ops.rms_norm import RMSNorm
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
         )
 
 
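Two notes on the change above. First, the quoting fix in both install hints matters because the URLs contain `&`, which an unquoted shell treats as a command separator, silently truncating the pip argument. Second, the monkeypatch has to run before the model is instantiated so the patched LlamaRMSNorm (and the attention/loss replacements) are picked up when the layers are built. An illustrative call order, not taken from this commit (the checkpoint name is a placeholder):

from transformers import LlamaForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

# apply the flash-attention, xentropy, and RMSNorm patches first
replace_llama_attn_with_flash_attn()

# load any Llama checkpoint after patching so the patched classes are used
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")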