# Repository: RECASTMLP-llama3.1-f8t4 — file: modeling_recastmlp_llama.py
# Uploaded by appledora in commit 7a1d06b ("Upload 6 files", verified); file size 26.9 kB.
# filename: recastmlp_llama_model.py
from typing import Optional, Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers import AutoConfig
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging

from .configuration_recastmlp_llama import RECASTMLP_llama
logger = logging.get_logger(__name__)
class MLPTemplateBank(nn.Module):
    """Bank of weight templates shared by several Llama-style MLP layers.

    Holds ``num_templates`` matrices for each of the gate, up and down
    projections. Layers that share a bank only own mixing coefficients; the
    actual projection weights are produced by :meth:`forward`.
    """

    def __init__(self, config, num_templates=None):
        """
        Initialize template bank for MLP layers.

        Args:
            config: Model config providing ``hidden_size`` and
                ``intermediate_size`` (and ``num_templates`` when the
                argument below is omitted).
            num_templates: Number of templates in the bank. When ``None``,
                ``config.num_templates`` is used. (Previously this argument
                was accepted but silently ignored.)
        """
        super().__init__()
        self.num_templates = (
            config.num_templates if num_templates is None else num_templates
        )
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Templates are stacked on dim 0; each slice uses the
        # (out_features, in_features) layout that F.linear expects.
        self.gate_templates = nn.Parameter(
            torch.empty(self.num_templates, self.intermediate_size, self.hidden_size)
        )
        self.up_templates = nn.Parameter(
            torch.empty(self.num_templates, self.intermediate_size, self.hidden_size)
        )
        self.down_templates = nn.Parameter(
            torch.empty(self.num_templates, self.hidden_size, self.intermediate_size)
        )

        # Initialise each template on its own 2-D slice: kaiming_normal_ on the
        # whole 3-D tensor would compute a different fan-in.
        for i in range(self.num_templates):
            nn.init.kaiming_normal_(self.gate_templates[i])
            nn.init.kaiming_normal_(self.up_templates[i])
            nn.init.kaiming_normal_(self.down_templates[i])

        # Shape of a per-template scalar coefficient that broadcasts over a template stack.
        self.coefficient_shape = (self.num_templates, 1, 1)

    def forward(self, gate_coeffs, up_coeffs, down_coeffs):
        """Generate weights from coefficients.

        Each coefficient tensor is multiplied elementwise with its template
        stack and summed over dim 0. With coefficients of shape
        ``coefficient_shape`` this is a weighted sum of the templates.

        Returns:
            ``(gate_weights, up_weights, down_weights)`` shaped
            ``(intermediate, hidden)``, ``(intermediate, hidden)`` and
            ``(hidden, intermediate)`` respectively.
        """
        gate_weights = (self.gate_templates * gate_coeffs).sum(0)
        up_weights = (self.up_templates * up_coeffs).sum(0)
        down_weights = (self.down_templates * down_coeffs).sum(0)
        return gate_weights, up_weights, down_weights

    def __repr__(self):
        return f"MLPTemplateBank(num_templates={self.num_templates}, hidden_size={self.hidden_size}, intermediate_size={self.intermediate_size})"
class SharedLlamaMLP(nn.Module):
    """Gated (SiLU) MLP whose projection weights come from a template bank.

    The layer owns ``config.num_cf`` coefficient sets per projection. On every
    forward pass, each set is mixed by ``bank`` into full weight matrices, and
    the matrices from all sets are averaged before being applied.
    """

    def __init__(self, config, bank):
        """
        Args:
            config: Config providing ``hidden_size``, ``intermediate_size``,
                ``num_cf`` and ``mlp_bias``.
            bank: Callable taking ``(gate_coeffs, up_coeffs, down_coeffs)``
                and returning ``(gate_w, up_w, down_w)``. It must expose
                ``coefficient_shape``. When it is an ``nn.Module`` it is
                registered as a submodule of this layer.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.bank = bank
        num_cf = config.num_cf

        # Coefficients for the template bank, one entry per coefficient set.
        self.gate_coefficients = nn.ParameterList(
            [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
        )
        self.up_coefficients = nn.ParameterList(
            [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
        )
        self.down_coefficients = nn.ParameterList(
            [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
        )
        # Initialise in the same order as before (all gate, then up, then down)
        # so that seeded runs reproduce identical parameters.
        for coefficient_list in (
            self.gate_coefficients,
            self.up_coefficients,
            self.down_coefficients,
        ):
            for cf in coefficient_list:
                nn.init.orthogonal_(cf)

        # Optional biases, created only when config.mlp_bias is truthy.
        self.gate_bias = (
            nn.Parameter(torch.zeros(self.intermediate_size))
            if config.mlp_bias
            else None
        )
        self.up_bias = (
            nn.Parameter(torch.zeros(self.intermediate_size))
            if config.mlp_bias
            else None
        )
        self.down_bias = (
            nn.Parameter(torch.zeros(self.hidden_size)) if config.mlp_bias else None
        )

        # The activation is fixed to SiLU; config.hidden_act is not consulted.
        self.act_fn = F.silu

    def forward(self, x):
        """Apply ``down(act(gate(x)) * up(x))`` using bank-generated weights."""
        # One set of full weight matrices per coefficient set.
        # NOTE(review): MLPTemplateBank is linear in its coefficients, so
        # averaging the coefficients first and calling the bank once would give
        # the same weights up to rounding, with far less memory. It is kept
        # as-is here to stay bit-identical.
        gate_weights, up_weights, down_weights = [], [], []
        for gate_cf, up_cf, down_cf in zip(
            self.gate_coefficients, self.up_coefficients, self.down_coefficients
        ):
            gate, up, down = self.bank(gate_cf, up_cf, down_cf)
            gate_weights.append(gate)
            up_weights.append(up)
            down_weights.append(down)

        gate_weight = torch.stack(gate_weights).mean(0)
        up_weight = torch.stack(up_weights).mean(0)
        down_weight = torch.stack(down_weights).mean(0)

        gate_output = F.linear(x, gate_weight, self.gate_bias)
        up_output = F.linear(x, up_weight, self.up_bias)
        hidden_states = self.act_fn(gate_output) * up_output
        return F.linear(hidden_states, down_weight, self.down_bias)

    def __repr__(self):
        return (
            f"SharedLlamaMLP(hidden_size={self.hidden_size}, "
            f"intermediate_size={self.intermediate_size}, "
            f"gate_coefficients={len(self.gate_coefficients)}, "
            f"up_coefficients={len(self.up_coefficients)}, "
            f"down_coefficients={len(self.down_coefficients)})"
        )
def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: Optional[Union[int, torch.Tensor]] = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Cross-entropy with optional external normalisation.

    Args:
        source: Logits, as accepted by ``F.cross_entropy``.
        target: Class indices; positions equal to ``ignore_index`` are skipped.
        num_items_in_batch: When given, the summed loss is divided by this
            count instead of taking the mean. It may be an int or a tensor.
        ignore_index: Target value excluded from the loss.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        Scalar loss tensor.
    """
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        # A tensor count may live on another device (e.g. a multi-GPU training
        # loop); move it next to the loss before dividing.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRotaryEmbedding,
LlamaRMSNorm,
apply_rotary_pos_emb,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
class RECASTMLP_llamaModel(PreTrainedModel):
    """Llama decoder stack whose MLP weights come from shared template banks.

    Embeddings, attention and norms are the stock ``transformers`` Llama
    modules. Each decoder layer's MLP is replaced by a ``SharedLlamaMLP`` that
    draws its weights from one of ``config.num_groups`` ``MLPTemplateBank``
    instances; consecutive layers share a bank.
    """

    config_class = RECASTMLP_llama
    # Must name the attribute under which head models store this base model
    # (the causal-LM wrapper uses `self.model`); transformers uses it to map
    # checkpoint keys between base and head models.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )

        # NOTE(review): RoPE parameters are taken from a hard-coded upstream
        # Hub config rather than from `config`. Building the model therefore
        # needs Hub access (and authorization for that repo).
        # LlamaRotaryEmbedding(config=config) would remove the dependency once
        # it is confirmed that `config` carries identical rope settings.
        original_config = AutoConfig.from_pretrained(
            "meta-llama/Llama-3.1-8b", trust_remote_code=True
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            config=original_config,
        )

        # Create template banks first, one per group. A plain list suffices
        # for registration: each bank becomes a submodule of every
        # SharedLlamaMLP that uses it.
        self.banks = [
            MLPTemplateBank(config, config.num_templates)
            for _ in range(config.num_groups)
        ]
        # max(1, ...) guards against num_groups > num_hidden_layers.
        layers_per_group = max(1, config.num_hidden_layers // config.num_groups)

        # Create standard Llama decoder layers, then swap in the shared MLP.
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            decoder_layer = LlamaDecoderLayer(config, layer_idx)
            # Clamp so that leftover layers join the last group when
            # num_hidden_layers is not divisible by num_groups (the unclamped
            # index would run past the end of self.banks).
            group_idx = min(layer_idx // layers_per_group, config.num_groups - 1)
            decoder_layer.mlp = SharedLlamaMLP(config, bank=self.banks[group_idx])
            self.layers.append(decoder_layer)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack and return the final normed hidden states.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``cache_position`` and ``position_ids`` default to consecutive
        positions following the tokens already held in ``past_key_values``.
        When ``use_cache`` is set and no cache is supplied, a fresh
        ``DynamicCache`` is created and returned.

        Raises:
            ValueError: if neither or both of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Without a cache object the attention layers have nowhere to store
        # keys/values, so use_cache=True would silently cache nothing.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # The causal-mask builder indexes with cache_position, so it must
        # never be None.
        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Rotary embeddings are computed once and shared by all decoder layers.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
        hidden_states = inputs_embeds

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional here, so their order must follow
                # LlamaDecoderLayer.forward: (hidden_states, attention_mask,
                # position_ids, past_key_value, output_attentions, use_cache,
                # cache_position, position_embeddings). Re-check this if
                # transformers is upgraded.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # The cache object handed to the layers is updated in place by the
        # Llama attention modules, so it is the cache to return. This avoids
        # relying on the layers' output tuple layout.
        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load from a local ``.pt`` training checkpoint or via the standard Hub path.

        A ``.pt`` file must be a dict holding ``"model_state_dict"`` (plus
        optional ``"epoch"`` / ``"loss"``). It is loaded non-strictly, and
        missing or unexpected keys are logged. Other ``kwargs`` are ignored on
        this path.
        """
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            logger.info("Loading from local checkpoint")
            config = kwargs.get("config", None)
            if config is None:
                # NOTE(review): AutoConfig is pointed at the .pt file itself;
                # pass `config=` explicitly unless this is known to resolve.
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )
            model = cls(config)
            # NOTE(review): torch.load unpickles arbitrary objects unless
            # weights_only=True; only load trusted checkpoints here.
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]
            logger.info(
                f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
            )
            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )
            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")
            return model

        logger.info("Loading from hub")
        return super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the attention mask passed to every decoder layer.

        Returns ``None`` when the backend can rely on implicit causal masking:
        flash-attention without padding, or SDPA when
        ``AttentionMaskConverter._ignore_causal_mask_sdpa`` allows it.
        Otherwise it returns a 4-D additive mask from
        ``_prepare_4d_causal_attention_mask_with_cache_position``.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, rely on its `is_causal` argument instead of
        # `attn_mask`, so it can dispatch to Flash Attention 2. This is not
        # compatible with a static cache, as SDPA would fail to infer the mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output_attentions is True, the SDPA forward falls back to eager.
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # If the provided mask is 2-D, a 4-D causal mask is generated here.
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows (e.g. leading rows
            # under left padding); required by the memory-efficient SDPA path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """Return a ``(batch, 1, sequence_length, target_length)`` additive mask.

        Allowed positions hold 0 and blocked positions hold the dtype minimum.
        A key column is blocked when its index exceeds the query's
        ``cache_position`` entry, or when ``attention_mask`` is 0 there. A 4-D
        ``attention_mask`` is returned unchanged. ``cache_position`` must not
        be ``None`` on the 2-D / no-mask path.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # Assume the mask is already in inverted (additive) form and needs
            # no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        return causal_mask
class RECASTMLP_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language-model head (untied-bias linear) on top of ``RECASTMLP_llamaModel``."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = RECASTMLP_llama
    # Must match the attribute holding the base model (`self.model`) so that
    # transformers can map checkpoint keys between base and head models.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.model = RECASTMLP_llamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def loss_function(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        """Next-token cross-entropy: logits at position t are scored against labels[t + 1].

        When ``num_items_in_batch`` is ``None`` the loss is the mean over
        non-ignored tokens; otherwise the summed loss is divided by it.
        """
        # Upcast to float to avoid precision issues in the loss.
        logits = logits.float()
        # Shift so that tokens < n predict n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens.
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Support model parallelism: labels follow the logits' device.
        shift_labels = shift_labels.to(shift_logits.device)
        loss = fixed_cross_entropy(
            shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language-modeling loss. Indices should be in
                `[0, ..., config.vocab_size]` or -100 (masked tokens).
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
            **kwargs:
                Forwarded to the base model and the loss. An optional
                `num_items_in_batch` (number of non-ignored label tokens) is
                consumed here and used only to normalise the loss; when absent,
                the loss is the mean over non-ignored tokens.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Callers such as transformers' Trainer may pass the label-token count
        # via **kwargs. Pop it so that it is not forwarded into the decoder
        # layers and does not collide with the explicit keyword below.
        # (Previously the batch size was used here, which divided a summed
        # token loss by the number of sequences.)
        num_items_in_batch = kwargs.pop("num_items_in_batch", None)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # A slice of `-0:` keeps every position, so 0 means "all logits".
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble model inputs for one generation step.

        Once ``past_key_values`` is non-empty, only the last token and its
        position are fed. Position ids are derived from ``attention_mask``
        when not supplied, so left-padded batches get correct positions.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # `inputs_embeds` are used only on the first step, i.e. while the
        # cache is absent or still empty (same truthiness test as above).
        if inputs_embeds is not None and not past_key_values:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load from a local ``.pt`` training checkpoint or via the standard Hub path.

        A ``.pt`` file must be a dict holding ``"model_state_dict"`` (plus
        optional ``"epoch"`` / ``"loss"``). It is loaded non-strictly, and
        missing or unexpected keys are logged. Other ``kwargs`` are ignored on
        this path.
        """
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            logger.info("Loading from local checkpoint")
            config = kwargs.get("config", None)
            if config is None:
                # NOTE(review): AutoConfig is pointed at the .pt file itself;
                # pass `config=` explicitly unless this is known to resolve.
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )
            model = cls(config)
            # NOTE(review): torch.load unpickles arbitrary objects unless
            # weights_only=True; only load trusted checkpoints here.
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]
            logger.info(
                f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
            )
            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )
            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")
            return model

        logger.info("Loading from hub")
        return super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )