|
|
|
|
|
"""Implementation of the paper: |
|
|
|
LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model |
|
https://arxiv.org/abs/2304.15010 |
|
|
|
Port for LitGPT |
|
""" |
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from typing_extensions import Self |
|
|
|
import litgpt |
|
from litgpt.adapter import GPT as BaseModel |
|
from litgpt.adapter import Block as BaseBlock |
|
from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention |
|
from litgpt.adapter import Config as BaseConfig |
|
from litgpt.model import KVCache, apply_rope
from litgpt.smoe import AdapterV2SMoE
from litgpt.utils import map_old_state_dict_weights

from transformers import PreTrainedModel
|
|
|
@dataclass |
|
class Config(BaseConfig): |
|
@property |
|
def mlp_class(self) -> Type: |
|
return getattr(litgpt.adapter_v2, self.mlp_class_name) |
|
|
|
@dataclass
class ConfigSMOE(BaseConfig):
    """Adapter V2 config with extra fields for the sparse mixture-of-experts (SMoE) layers."""

    use_smoe: bool = False
    num_experts: int = 4
    top_k: int = 1
    alpha: int = 0
    model_type: str = "gpt"

    def __init__(self, *args, **kwargs):
        # Defining __init__ explicitly keeps @dataclass from generating one, so the SMoE fields
        # above remain class-level defaults and are set through `load_extra` rather than the
        # constructor.
        super().__init__(*args, **kwargs)

    @property
    def mlp_class(self) -> Type:
        return getattr(litgpt.adapter_v2, self.mlp_class_name)

    def load_extra(self, extra_config: Dict[str, Any]) -> None:
        for k, v in extra_config.items():
            setattr(self, k, v)

|
def adapter_filter(key: str, value: Any) -> bool:
    """Return True for parameters that belong to the adapter and should be trained/saved."""
    adapter_substrings = (
        # regular adapter v1 parameters
        "adapter_wte",
        "gating_factor",
        # adapter v2: new bias and scale used in Linear
        "adapter_scale",
        "adapter_bias",
        # adapter v2: Norm parameters are now trainable
        "norm_1",
        "norm_2",
        "ln_f",
        # MoE / SMoE router gate
        "gate",
    )
    return any(s in key for s in adapter_substrings)
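
# Usage sketch (an assumption, mirroring the upstream LitGPT finetuning scripts): `adapter_filter`
# is typically passed as a state-dict filter when saving, so only the adapter weights end up in
# the checkpoint, e.g. `fabric.save(path, {"model": model}, filter={"model": adapter_filter})`.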
|
|
|
|
|
class AdapterV2Linear(torch.nn.Module): |
|
def __init__(self, in_features: int, out_features: int, **kwargs) -> None: |
|
super().__init__() |
|
self.linear = torch.nn.Linear(in_features, out_features, **kwargs) |
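        # Note: `adapter_bias` and `adapter_scale` are created frozen; `mark_only_adapter_v2_as_trainable`
        # re-enables them, since "adapter_bias"/"adapter_scale" match the substrings in `adapter_filter`.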
|
self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False) |
|
self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
return self.adapter_scale * (self.linear(x) + self.adapter_bias) |
|
|
|
def reset_parameters(self) -> None: |
|
nn.init.zeros_(self.adapter_bias) |
|
nn.init.ones_(self.adapter_scale) |
|
|
|
|
|
class GPT(BaseModel, PreTrainedModel):
    """Adapter V2 GPT; when ``config.use_smoe`` is set, the adapter linear layers are replaced
    with sparse mixture-of-experts (SMoE) adapter layers."""

    config_class = ConfigSMOE

    def __init__(self, config: ConfigSMOE) -> None:
        # Build the modules directly (bypassing the parent __init__s) so the adapter layers
        # replace the plain `nn.Linear` layers of the base model.
        nn.Module.__init__(self)
|
|
|
assert config.padded_vocab_size is not None |
|
self.config = config |
|
if config.use_smoe: |
|
print("🐙 Run AdapterV2SMoE") |
|
self.lm_head = AdapterV2SMoE( |
|
in_features=config.n_embd, |
|
out_features=config.padded_vocab_size, |
|
num_experts=config.num_experts, |
|
top_k=config.top_k, |
|
bias=config.lm_head_bias |
|
) |
|
self.transformer = nn.ModuleDict( |
|
dict( |
|
wte=nn.Embedding(config.padded_vocab_size, config.n_embd), |
|
h=nn.ModuleList(BlockSMoE(config, i) for i in range(config.n_layer)), |
|
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), |
|
) |
|
) |
|
else: |
|
print("🐙 Run AdapterV2Linear") |
|
self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) |
|
self.transformer = nn.ModuleDict( |
|
dict( |
|
wte=nn.Embedding(config.padded_vocab_size, config.n_embd), |
|
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), |
|
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), |
|
) |
|
) |
|
self.max_seq_length = self.config.block_size |
|
self.mask_cache: Optional[torch.Tensor] = None |
|
|
|
def forward( |
|
self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 |
|
    ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor]]]:
|
T = idx.size(1) |
|
if self.max_seq_length < T: |
|
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") |
|
|
|
if input_pos is not None: |
|
cos = self.cos.index_select(0, input_pos) |
|
sin = self.sin.index_select(0, input_pos) |
|
if self.mask_cache is None: |
|
raise TypeError("You need to call `gpt.set_kv_cache()`") |
|
mask = self.mask_cache.index_select(2, input_pos) |
|
else: |
|
cos = self.cos[:T] |
|
sin = self.sin[:T] |
|
mask = None |
|
|
|
x = self.transformer.wte(idx) |
|
if self.config.scale_embeddings: |
|
x = x * (self.config.n_embd**0.5) |
|
for block in self.transformer.h: |
|
x = block(x, cos, sin, mask, input_pos) |
|
x = self.transformer.ln_f(x) |
|
        if self.config.use_smoe:
            # The SMoE head returns (logits, router_logits); both are passed back so the caller
            # can use the router outputs, e.g. for an auxiliary load-balancing loss.
            if lm_head_chunk_size > 0:
                # chunk the lm head logits to reduce the peak memory used by autograd
                outputs = []
                routers = []
                for x_i in x.split(lm_head_chunk_size, dim=1):
                    output, router = self.lm_head(x_i)
                    outputs.append(output)
                    routers.append(router)
                return outputs, routers
            output, router = self.lm_head(x)
            return output, router
        else:
            if lm_head_chunk_size > 0:
                # chunk the lm head logits to reduce the peak memory used by autograd
                return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
            return self.lm_head(x)
|
|
|
@classmethod |
|
def from_name(cls, name: str, **kwargs: Any) -> Self: |
|
return cls(Config.from_name(name, **kwargs)) |
|
|
|
def _init_weights(self, module: nn.Module) -> None: |
|
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" |
|
super()._init_weights(module) |
|
if isinstance(module, AdapterV2Linear): |
|
module.reset_parameters() |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
class Block(BaseBlock): |
|
"""The implementation is identical to `litgpt.model.Block` with the exception that |
|
we replace the attention layer where adaption is implemented.""" |
|
|
|
def __init__(self, config: Config, block_idx: int) -> None: |
|
|
|
nn.Module.__init__(self) |
|
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) |
|
if config.use_smoe: |
|
self.attn = CausalSelfAttentionSMoE(config, block_idx) |
|
else: |
|
self.attn = CausalSelfAttention(config, block_idx) |
|
if not config.shared_attention_norm: |
|
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) |
|
self.mlp = config.mlp_class(config) |
|
|
|
self.config = config |
|
|
|
class BlockSMoE(Block): |
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
cos: torch.Tensor, |
|
sin: torch.Tensor, |
|
mask: Optional[torch.Tensor] = None, |
|
input_pos: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
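        # Mirrors the parent block's forward, except that the SMoE attention returns
        # (output, router_logits); the router logits are discarded at the block level.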
|
x_normed = self.norm_1(x) |
|
attention_output, _ = self.attn(x_normed, cos, sin, mask, input_pos) |
|
if self.config.parallel_residual: |
|
x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x) |
|
x = self.mlp(x_normed) + attention_output + x |
|
else: |
|
x = attention_output + x |
|
x = self.mlp(self.norm_2(x)) + x |
|
return x |
|
|
|
|
|
class CausalSelfAttention(BaseCausalSelfAttention): |
|
"""A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class""" |
|
|
|
def __init__(self, config: Config, block_idx: int) -> None: |
|
|
|
nn.Module.__init__(self) |
|
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size |
|
|
|
if config.use_smoe: |
|
self.attn = AdapterV2SMoE( |
|
in_features=config.n_embd, |
|
out_features=shape, |
|
num_experts=config.num_experts, |
|
top_k=config.top_k, |
|
bias=config.bias |
|
) |
|
|
|
|
|
self.proj = AdapterV2SMoE( |
|
in_features=config.head_size * config.n_head, |
|
out_features=config.n_embd, |
|
num_experts=config.num_experts, |
|
top_k=config.top_k, |
|
bias=config.bias |
|
) |
|
|
|
else: |
|
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) |
|
|
|
|
|
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) |
|
|
|
self.kv_cache: Optional[KVCache] = None |
|
|
|
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx
|
|
|
self.config = config |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = { |
|
"attn.weight": "attn.linear.weight", |
|
"attn.bias": "attn.linear.bias", |
|
"proj.weight": "proj.linear.weight", |
|
"proj.bias": "proj.linear.bias", |
|
} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
|
|
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: |
|
state_dict[key] = state_dict[key].permute(0, 2, 1, 3) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
class CausalSelfAttentionSMoE(CausalSelfAttention): |
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
cos: torch.Tensor, |
|
sin: torch.Tensor, |
|
mask: Optional[torch.Tensor] = None, |
|
input_pos: Optional[torch.Tensor] = None, |
|
    ) -> Tuple[torch.Tensor, torch.Tensor]:
|
B, T, C = x.size() |
|
|
|
|
|
        qkv, _ = self.attn(x)  # the SMoE QKV projection also returns router logits, which are discarded here
|
|
|
|
|
        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)
|
|
|
|
|
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) |
|
|
|
|
|
|
|
|
|
        # expand k and v for grouped-/multi-query attention; skipped for multi-query inference
        # so the kv cache keeps its compact single key/value-head layout
        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
|
|
|
q = q.reshape(B, -1, T, self.config.head_size) |
|
k = k.reshape(B, -1, T, self.config.head_size) |
|
v = v.reshape(B, -1, T, self.config.head_size) |
|
|
|
q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) |
|
k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) |
|
q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) |
|
k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) |
|
|
|
if input_pos is not None: |
|
if not isinstance(self.kv_cache, KVCache): |
|
raise TypeError("You need to call `gpt.set_kv_cache()`") |
|
k, v = self.kv_cache(input_pos, k, v) |
|
|
|
y = self.scaled_dot_product_attention(q, k, v, mask) |
|
|
|
        # re-assemble all head outputs side by side
        y = y.reshape(B, T, self.config.head_size * self.config.n_head)

        # output projection; `self.proj` is an SMoE layer here, so this returns (output, router_logits)
        return self.proj(y)
|
|
|
class GptNeoxMLP(litgpt.model.GptNeoxMLP): |
|
def __init__(self, config: Config) -> None: |
|
nn.Module.__init__(self) |
|
if config.use_smoe: |
|
self.fc = AdapterV2SMoE( |
|
in_features=config.n_embd, |
|
out_features=config.intermediate_size, |
|
num_experts=config.num_experts, |
|
top_k=config.top_k, |
|
bias=config.bias |
|
) |
|
|
|
|
|
self.proj = AdapterV2SMoE( |
|
in_features=config.intermediate_size, |
|
out_features=config.n_embd, |
|
num_experts=config.num_experts, |
|
top_k=config.top_k, |
|
bias=config.bias |
|
) |
|
else: |
|
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) |
|
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) |
|
|
|
self.config = config |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = { |
|
"fc.weight": "fc.linear.weight", |
|
"fc.bias": "fc.linear.bias", |
|
"proj.weight": "proj.linear.weight", |
|
"proj.bias": "proj.linear.bias", |
|
} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
class LLaMAMLP(litgpt.model.LLaMAMLP): |
|
def __init__(self, config: Config) -> None: |
|
nn.Module.__init__(self) |
|
self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) |
|
self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) |
|
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) |
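        # Note: unlike `GptNeoxMLP` above, this MLP keeps plain `AdapterV2Linear` layers and has
        # no SMoE branch in this port.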
|
|
|
self.config = config |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = { |
|
"fc_1.weight": "fc_1.linear.weight", |
|
"fc_1.bias": "fc_1.linear.bias", |
|
"fc_2.weight": "fc_2.linear.weight", |
|
"fc_2.bias": "fc_2.linear.bias", |
|
"proj.weight": "proj.linear.weight", |
|
"proj.bias": "proj.linear.bias", |
|
} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
class GemmaMLP(LLaMAMLP): |
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x_fc_1 = self.fc_1(x) |
|
x_fc_2 = self.fc_2(x) |
|
x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 |
|
return self.proj(x) |
|
|
|
|
|
class LLaMAMoE(litgpt.model.LLaMAMoE): |
|
def __init__(self, config: Config) -> None: |
|
nn.Module.__init__(self) |
|
self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False) |
|
self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert)) |
|
|
|
self.config = config |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = {"gate.weight": "gate.linear.weight"} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
def mark_only_adapter_v2_as_trainable(model: GPT) -> None: |
|
"""Sets requires_grad=False for all non-adapter weights""" |
|
for name, param in model.named_parameters(): |
|
param.requires_grad = adapter_filter(name, param) |
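

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original port. It assumes the LitGPT version this
    # port targets and that the "pythia-14m" preset is available, and only checks that the dense
    # (non-SMoE) adapter model builds and runs a forward pass on random tokens.
    cfg = ConfigSMOE.from_name("pythia-14m")
    cfg.load_extra({"use_smoe": False})

    model = GPT(cfg)
    mark_only_adapter_v2_as_trainable(model)

    tokens = torch.randint(0, cfg.padded_vocab_size, (1, 8))
    with torch.no_grad():
        logits = model(tokens)
    print(logits.shape)  # expected: (1, 8, cfg.padded_vocab_size)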