update to replit's latest changes
Files changed:
- attention.py +11 -18
- blocks.py +2 -2
- configuration_mpt.py +1 -1
- modeling_mpt.py +1 -0
attention.py CHANGED

@@ -5,7 +5,6 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange
-from packaging import version
 from torch import nn
 from .norm import LPLayerNorm

@@ -88,17 +87,9 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None

 def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
-        from
+        from flash_attn import flash_attn_triton
     except:
-
-        if version.parse(torch.__version__) < version.parse('2.0.0'):
-            _installed = True
-            try:
-                from flash_attn.flash_attn_triton import flash_attn_func
-            except:
-                _installed = False
-        if not _installed:
-            raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
+        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
     check_valid_inputs(query, key, value)
     if dropout_p:
         raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')

@@ -117,7 +108,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
     key = key.expand(*key.shape[:2], n_heads, key.size(-1))
     value = value.expand(*value.shape[:2], n_heads, value.size(-1))
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
+    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
     output = attn_output.view(*attn_output.shape[:2], -1)
     return (output, None)

@@ -128,7 +119,7 @@ class MultiheadAttention(nn.Module):
     additive bias.
     """

-    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
+    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
         super().__init__()
         self.attn_impl = attn_impl
         self.clip_qkv = clip_qkv

@@ -150,10 +141,11 @@ class MultiheadAttention(nn.Module):
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == 'triton':
             self.attn_fn = triton_flash_attn_fn
-
+            if verbose:
+                warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
         elif self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
-            if torch.cuda.is_available():
+            if torch.cuda.is_available() and verbose:
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
         else:
             raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')

@@ -187,7 +179,7 @@ class MultiQueryAttention(nn.Module):
     additive bias.
     """

-    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
+    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
         super().__init__()
         self.attn_impl = attn_impl
         self.clip_qkv = clip_qkv

@@ -210,10 +202,11 @@ class MultiQueryAttention(nn.Module):
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == 'triton':
             self.attn_fn = triton_flash_attn_fn
-
+            if verbose:
+                warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
         elif self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
-            if torch.cuda.is_available():
+            if torch.cuda.is_available() and verbose:
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
         else:
             raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
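Net effect of this file: the triton path now imports `flash_attn.flash_attn_triton` directly and fails with a short install hint (flash-attn==1.0.3.post0, triton==2.0.0.dev20221202), and the implementation-choice warnings are gated behind a new `verbose` argument that defaults to 0. The snippet below is a minimal, hedged sketch of picking the attention implementation when loading the checkpoint; the repo id is a placeholder, and the `attn_config` override pattern is the usual MPT-style one rather than something introduced by this commit.

    # Sketch: choosing the attention implementation at load time (repo id is a placeholder).
    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = 'replit/replit-code-v1-3b'  # substitute this repo's id

    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    # 'triton' needs flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202 (per the new error message);
    # 'torch' is the dependency-free but slower fallback.
    config.attn_config['attn_impl'] = 'triton'

    model = AutoModelForCausalLM.from_pretrained(repo_id, config=config, trust_remote_code=True)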
blocks.py CHANGED

@@ -19,13 +19,13 @@ class MPTMLP(nn.Module):

 class MPTBlock(nn.Module):

-    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
+    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
         del kwargs
         super().__init__()
         norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
         attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
         self.norm_1 = norm_class(d_model, device=device)
-        self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
+        self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
         self.norm_2 = norm_class(d_model, device=device)
         self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
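blocks.py only threads the new `verbose` flag from `MPTBlock` down to the attention class. Because `MPTModel` builds each block with `MPTBlock(device=config.init_device, **config.to_dict())` (see modeling_mpt.py below), a `verbose` value stored on the config would reach the blocks. The snippet below is a hypothetical way to turn the warnings on; it assumes extra attributes set on a `PretrainedConfig` are exported by `to_dict()`, and note that this commit does not add a `verbose` field to `MPTConfig` itself.

    # Hypothetical: enable the attention-implementation warnings by setting an extra
    # `verbose` attribute on the loaded config (assumption: it survives config.to_dict()
    # and is forwarded to MPTBlock, which now accepts `verbose`).
    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = 'replit/replit-code-v1-3b'  # placeholder: substitute this repo's id
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    config.verbose = 1
    model = AutoModelForCausalLM.from_pretrained(repo_id, config=config, trust_remote_code=True)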
configuration_mpt.py CHANGED

@@ -2,7 +2,7 @@
 from typing import Dict, Optional, Union
 from transformers import PretrainedConfig
 attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
-init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
+init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}

 class MPTConfig(PretrainedConfig):
     model_type = 'mpt'
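configuration_mpt.py expands `init_config_defaults` with the remaining initializer knobs (`init_div_is_residual`, `emb_init_std`, `emb_init_uniform_lim`, `init_std`, `init_gain`). Below is a hedged sketch of supplying an explicit `init_config` when building a config from scratch; the model dimensions are illustrative, the inline comments reflect the usual MPT/llm-foundry meaning of these knobs, and the import assumes `configuration_mpt.py` is importable from the repo directory.

    # Illustrative only: constructing MPTConfig with an explicit init_config that
    # spells out the newly added default keys. Dimensions are placeholders.
    from configuration_mpt import MPTConfig

    config = MPTConfig(
        d_model=2560,
        n_heads=32,
        n_layers=32,
        init_config={
            'name': 'kaiming_normal_',
            'fan_mode': 'fan_in',
            'init_nonlinearity': 'relu',
            'init_div_is_residual': True,   # rescale residual-branch weights at init
            'emb_init_std': None,           # optional std for the embedding init
            'emb_init_uniform_lim': None,   # optional uniform limit for the embedding init
            'init_std': None,               # explicit std for init schemes that need one
            'init_gain': 0.0,               # gain used by xavier-style inits
        },
    )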
modeling_mpt.py CHANGED

@@ -46,6 +46,7 @@ class MPTModel(MPTPreTrainedModel):
         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
         self.norm_f = norm_class(config.d_model, device=config.init_device)
         if config.init_device != 'meta':
+            print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
             self.apply(self.param_init_fn)
         self.is_causal = not self.prefix_lm
         self._attn_bias_initialized = False
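modeling_mpt.py adds a single informational `print` whenever the model is materialized on a real device, pointing at `config.init_device="meta"` for fast initialization under Composer + FSDP. The sketch below shows what that knob looks like in practice; the repo id is again a placeholder, and `'meta'` is only useful when a framework such as Composer + FSDP materializes the weights afterwards.

    # Sketch: controlling where parameters are first created via config.init_device.
    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = 'replit/replit-code-v1-3b'  # placeholder: substitute this repo's id
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

    config.init_device = 'cuda:0'  # create parameters directly on the GPU (the message still prints for any non-'meta' device)
    # config.init_device = 'meta'  # defer allocation entirely; intended for Composer + FSDP

    model = AutoModelForCausalLM.from_pretrained(repo_id, config=config, trust_remote_code=True)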