File size: 6,110 Bytes
a34651d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import warnings
from typing import Optional, Tuple
from transformers.models.llama.modeling_llama import (
LlamaConfig,
LlamaModel,
LlamaForCausalLM,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
LlamaMLP,
LlamaDecoderLayer,
)
from mybitnet.bitnet import BitLinear
import torch
from torch import nn
class BitLlamaConfig(LlamaConfig):
    """Llama configuration extended with a bit-width for the BitLinear layers.

    Args:
        bits: Value passed as ``bits=`` to every ``BitLinear`` that the BitLlama
            modules build from this config. Defaults to 8.
        **kwargs: Forwarded unchanged to :class:`LlamaConfig`.
    """

    model_type = "bit_llama"

    def __init__(self, bits: int = 8, **kwargs):
        # Let LlamaConfig consume all of the standard fields first, then attach ours.
        super().__init__(**kwargs)
        self.bits = bits
class BitLlamaMLP(LlamaMLP):
    """LlamaMLP whose gate/up/down projections are replaced by BitLinear layers.

    The parent constructor builds the standard projections first. They are then
    overwritten, in gate -> up -> down order, by bias-free ``BitLinear`` modules
    that use ``config.bits``.
    """

    def __init__(self, config):
        super().__init__(config)
        d_model = self.hidden_size
        d_ff = self.intermediate_size
        n_bits = config.bits

        # Only gate_proj uses flg_before_linear=False. NOTE(review): presumably
        # because its output feeds the MLP activation function; confirm the
        # exact meaning of this flag in mybitnet.bitnet.BitLinear.
        self.gate_proj = BitLinear(d_model, d_ff, bias=False, bits=n_bits, flg_before_linear=False)
        self.up_proj = BitLinear(d_model, d_ff, bias=False, bits=n_bits, flg_before_linear=True)
        self.down_proj = BitLinear(d_ff, d_model, bias=False, bits=n_bits, flg_before_linear=True)
class BitLlamaAttention(LlamaAttention):
    """Eager Llama attention whose q/k/v/o projections are BitLinear layers.

    Args:
        config: Model configuration. ``config.bits`` sets the BitLinear bit-width.
        layer_idx: Position of this layer in the decoder stack, forwarded to
            ``LlamaAttention.__init__`` just as the FlashAttention2 and SDPA
            variants in this file do.
    """

    def __init__(self, config: BitLlamaConfig, layer_idx: Optional[int] = None):
        # Forward layer_idx; previously it was dropped here, so the eager path lost
        # its layer index. NOTE(review): transformers presumably uses it to
        # address the per-layer KV cache; confirm for the pinned version.
        super().__init__(config, layer_idx)
        q_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = BitLinear(self.hidden_size, q_width, bias=False, bits=config.bits, flg_before_linear=True)
        self.k_proj = BitLinear(self.hidden_size, kv_width, bias=False, bits=config.bits, flg_before_linear=True)
        self.v_proj = BitLinear(self.hidden_size, kv_width, bias=False, bits=config.bits, flg_before_linear=True)
        # o_proj consumes the concatenated head outputs, so its in-features mirror
        # q_proj's out-features (equal to hidden_size when num_heads * head_dim == hidden_size).
        self.o_proj = BitLinear(q_width, self.hidden_size, bias=False, bits=config.bits, flg_before_linear=True)
class BitLlamaFlashAttention2(LlamaFlashAttention2):
    """FlashAttention-2 Llama attention whose q/k/v/o projections are BitLinear layers.

    Args:
        config: Model configuration. ``config.bits`` sets the BitLinear bit-width.
        layer_idx: Layer position, forwarded to the parent constructor.
    """

    def __init__(self, config: BitLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        width = self.hidden_size
        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim

        def _bit(in_features, out_features):
            # All attention projections are bias-free with flg_before_linear=True.
            return BitLinear(in_features, out_features, bias=False, bits=config.bits, flg_before_linear=True)

        # Built in q, k, v, o order, matching the original construction sequence.
        self.q_proj = _bit(width, query_width)
        self.k_proj = _bit(width, kv_width)
        self.v_proj = _bit(width, kv_width)
        self.o_proj = _bit(width, width)
class BitLlamaSdpaAttention(LlamaSdpaAttention):
    """SDPA-backed Llama attention whose q/k/v/o projections are BitLinear layers.

    Args:
        config: Model configuration. ``config.bits`` sets the BitLinear bit-width.
        layer_idx: Layer position, forwarded to the parent constructor.
    """

    def __init__(self, config: BitLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        head_total = self.num_heads * self.head_dim
        kv_total = self.num_key_value_heads * self.head_dim
        # (attribute, in_features, out_features), in the original construction order.
        layout = (
            ("q_proj", self.hidden_size, head_total),
            ("k_proj", self.hidden_size, kv_total),
            ("v_proj", self.hidden_size, kv_total),
            ("o_proj", self.hidden_size, self.hidden_size),
        )
        for attr, fan_in, fan_out in layout:
            setattr(
                self,
                attr,
                BitLinear(fan_in, fan_out, bias=False, bits=config.bits, flg_before_linear=True),
            )
# Lookup from config._attn_implementation to the BitLinear-backed attention class;
# BitLlamaDecoderLayer.__init__ indexes this with the config value.
BITLLAMA_ATTENTION_CLASSES = dict(
    eager=BitLlamaAttention,
    flash_attention_2=BitLlamaFlashAttention2,
    sdpa=BitLlamaSdpaAttention,
)
class BitLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer using BitLinear attention/MLP and no RMSNorm layers.

    The residual structure follows LlamaDecoderLayer, but the input and
    post-attention layer norms are deleted and never applied.
    NOTE(review): presumably BitLinear normalises its own input (as in BitNet);
    confirm in mybitnet.bitnet before relying on this.
    """

    def __init__(self, config: BitLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the modules LlamaDecoderLayer built with BitLinear-based ones.
        attn_cls = BITLLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)
        self.mlp = BitLlamaMLP(config)
        # Remove the parent's norms so their parameters are not registered on this layer.
        del self.input_layernorm
        del self.post_attention_layernorm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply self-attention and the MLP, each wrapped in a residual connection.

        refers: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L693
        (adapted with the RMSNorm calls removed).

        Args:
            hidden_states: Layer input. NOTE(review): presumably
                (batch, seq_len, hidden_size); confirm against the caller.
            attention_mask: Forwarded unchanged to ``self_attn``.
            position_ids: Forwarded unchanged to ``self_attn``.
            past_key_value: Forwarded unchanged to ``self_attn``.
            output_attentions: If truthy, the attention weights returned by
                ``self_attn`` are appended to the output tuple.
            use_cache: If truthy, the key/value cache returned by ``self_attn``
                is appended to the output tuple.
            cache_position: Forwarded unchanged to ``self_attn``.
            **kwargs: Forwarded to ``self_attn``. A ``padding_mask`` entry
                triggers a deprecation warning but is still forwarded.

        Returns:
            ``(hidden_states,)``, extended with the attention weights when
            ``output_attentions`` and then the cache when ``use_cache``.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure to use `attention_mask` instead."
            )

        # Self-attention sub-block with residual connection (no pre-norm).
        residual = hidden_states
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block with residual connection (no pre-norm).
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
class BitLlamaModel(LlamaModel):
    """LlamaModel whose decoder stack is made of BitLlamaDecoderLayer instances."""

    config_class = BitLlamaConfig

    def __init__(self, config: BitLlamaConfig):
        super().__init__(config)
        # Rebuild the decoder stack: one BitLlamaDecoderLayer per configured layer,
        # each given its own index. NOTE(review): these layers are created after
        # LlamaModel.__init__ has returned, so any weight initialisation done there
        # does not reach them; confirm that is intended.
        bit_layers = [
            BitLlamaDecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(bit_layers)
class BitLlamaForCausalLM(LlamaForCausalLM):
    """Causal language model built on BitLlamaModel with a BitLinear output head."""

    config_class = BitLlamaConfig

    def __init__(self, config: BitLlamaConfig):
        super().__init__(config)
        # Replace the backbone first, then the LM head (same order as before).
        # NOTE(review): both are swapped in after LlamaForCausalLM.__init__ has run,
        # so any weight tying or initialisation it performed applies to the
        # discarded modules. Check this if config.tie_word_embeddings is set.
        backbone = BitLlamaModel(config)
        self.model = backbone
        self.lm_head = BitLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            bits=config.bits,
            flg_before_linear=True,
        )
|