Alex Birch
Add support for `device_map='auto'` in `AutoModelForCausalLM#from_pretrained()`. Support gradient checkpointing (probably). Add lots of type hints so that I could understand what's going on. Split long method signatures/calls across multiple lines (for easier comparison between the checkpointed and non-checkpointed variants, and because these lines got even longer once type hints were added). Make `MPTForCausalLM#forward` accept additional kwargs, since `PeftModelForCausalLM#forward` tries to pass it the argument `inputs_embeds=None`, which it previously rejected.
9f0a20b
unverified
"""GPT Blocks used for the GPT Model.""" | |
from typing import Dict, Optional, Tuple, NamedTuple, Union | |
import torch | |
import torch.nn as nn | |
from .attention import ATTN_CLASS_REGISTRY, Attn, PastKeyValue | |
from .norm import NORM_CLASS_REGISTRY | |
class MPTBlockOutput(NamedTuple):
    """Result returned by ``MPTBlock.forward``.

    As a NamedTuple it can be unpacked positionally as
    ``(hidden_states, past_key_value)`` or read by field name.
    """

    # Hidden states produced by the block.
    hidden_states: torch.Tensor
    # Key/value cache state handed back by the attention layer, or None.
    past_key_value: Optional[Union[PastKeyValue, Tuple]]
class MPTMLP(nn.Module):
    """Two-layer feed-forward network: Linear -> exact GELU -> Linear.

    Input and output widths are ``d_model``; the hidden width is
    ``expansion_ratio * d_model``.
    """

    def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
        super().__init__()
        hidden_dim = expansion_ratio * d_model
        # Sub-modules are registered in this order: up_proj, act, down_proj.
        self.up_proj = nn.Linear(d_model, hidden_dim, device=device)
        self.act = nn.GELU(approximate='none')
        self.down_proj = nn.Linear(hidden_dim, d_model, device=device)
        # Marker attribute on the output projection; presumably read by
        # weight-initialisation code elsewhere (not visible here) — confirm.
        self.down_proj._is_residual = True

    def forward(self, x):
        """Project ``x`` up, apply GELU, then project back down to ``d_model``."""
        hidden = self.up_proj(x)
        hidden = self.act(hidden)
        return self.down_proj(hidden)
class MPTBlock(nn.Module):
    """One MPT transformer block.

    The block is pre-norm: ``x`` is normalised before the attention and the
    feed-forward sub-layers. Each sub-layer's output, after dropout, is added
    back onto the un-normalised input as a residual.
    """

    attn: Attn

    # Attention settings used when ``attn_config`` is omitted. Any keys the
    # caller supplies override these. The dict lives here and is copied before
    # use, rather than being a dict default argument. Python would create such
    # a default once and share it across every call.
    _DEFAULT_ATTN_CONFIG = {
        'attn_type': 'multihead_attention',
        'attn_pdrop': 0.0,
        'attn_impl': 'triton',
        'qk_ln': False,
        'clip_qkv': None,
        'softmax_scale': None,
        'prefix_lm': False,
        'attn_uses_sequence_id': False,
        'alibi': False,
        'alibi_bias_max': 8,
    }

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = 'low_precision_layernorm',
        device: Optional[str] = None,
        **kwargs,
    ):
        """Build the norms, attention layer, feed-forward network and dropouts.

        Args:
            d_model: Feature width of the block's input and output. It is
                passed to both norms, the attention layer and the MLP.
            n_heads: Passed to the attention layer's constructor.
            expansion_ratio: The MLP hidden width is
                ``expansion_ratio * d_model``.
            attn_config: Attention settings. Missing keys fall back to
                ``_DEFAULT_ATTN_CONFIG``, and ``None`` means use all defaults.
                This block reads ``attn_type``, which selects the class from
                ``ATTN_CLASS_REGISTRY``. It forwards ``attn_impl``,
                ``clip_qkv``, ``qk_ln``, ``softmax_scale`` and ``attn_pdrop``
                to that class. Other keys are not used here.
            resid_pdrop: Dropout probability applied to each residual branch.
            norm_type: Lower-cased, then looked up in ``NORM_CLASS_REGISTRY``.
            device: Passed to every sub-module constructor.
            **kwargs: Accepted and discarded.

        Raises:
            KeyError: If ``norm_type`` or the resolved ``attn_type`` is
                missing from its registry (assuming the registries are plain
                dicts; confirm).
        """
        del kwargs
        super().__init__()
        # Merge into a fresh dict, so neither the caller's dict nor the shared
        # defaults are ever mutated.
        attn_config = {**self._DEFAULT_ATTN_CONFIG, **(attn_config or {})}
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
        # Registration order (norm_1, attn, norm_2, ffn, dropouts) is kept as
        # before, so parameter/state-dict ordering is unchanged.
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=attn_config['attn_impl'],
            clip_qkv=attn_config['clip_qkv'],
            qk_ln=attn_config['qk_ln'],
            softmax_scale=attn_config['softmax_scale'],
            attn_pdrop=attn_config['attn_pdrop'],
            d_model=d_model,
            n_heads=n_heads,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Union[PastKeyValue, Tuple, None] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> MPTBlockOutput:
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            x: Input hidden states. Assumed to be ``(batch, seq_len, d_model)``;
                this is not checked here.
            past_key_value: Cached key/value state passed to ``self.attn``.
                May be ``None``.
            attn_bias: Optional tensor forwarded unchanged to ``self.attn``.
            attention_mask: Optional mask forwarded unchanged to ``self.attn``.
            is_causal: Forwarded unchanged to ``self.attn``.

        Returns:
            ``MPTBlockOutput(hidden_states, past_key_value)``, where
            ``past_key_value`` is the third value returned by ``self.attn``.
        """
        a = self.norm_1(x)
        # The attention layer returns three values. The middle one (presumably
        # the attention weights) is not used by this block.
        (b, _, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return MPTBlockOutput(x, past_key_value)