import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM

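# This script wraps each decoder layer of a pretrained TinyLlama checkpoint with a
# modified layer that adds adaptive RMSNorm, depthwise token mixing, and a
# squeeze-and-excitation block, then saves the result and runs a quick forward pass.
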
class AdaptiveRMSNorm(nn.Module):
    """
    Adaptive RMSNorm layer where the scaling parameter adapts based on input.
    """

    def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        # Standard learnable RMSNorm gain.
        self.weight = nn.Parameter(torch.ones(normalized_shape))

        # Projection that turns a pooled summary vector into an input-conditioned gain.
        self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)

    def forward(self, x, adapt_input):
        # adapt_input: (batch, adaptive_dim) -> per-sample gain, broadcast over the sequence.
        gamma = self.fc_gamma(adapt_input).unsqueeze(1)

        # RMS normalization: scale by 1 / sqrt(mean(x^2) + eps), as in standard RMSNorm.
        norm_x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        return self.weight * norm_x * gamma


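# Note on AdaptiveRMSNorm: fc_gamma is randomly initialized, so the adaptive gain is
# not 1 at initialization. The wrapped pretrained layers will therefore behave
# differently from the original model until these new parameters are trained.
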
class TokenMixing(nn.Module):
    """
    Token Mixing layer that performs depthwise convolution across the sequence dimension.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Depthwise conv: groups == channels, so each hidden dimension is mixed
        # across a 3-token window independently of the others.
        self.token_mixing = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            padding=1,
            groups=hidden_size
        )

    def forward(self, x):
        # (batch, seq_len, hidden) -> (batch, hidden, seq_len) for Conv1d, then back.
        x = x.transpose(1, 2)
        x = self.token_mixing(x)
        x = x.transpose(1, 2)
        return x


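# Note on TokenMixing: with kernel_size=3 and symmetric padding, each position also
# sees the next token, so this mixing is not causal. Strictly autoregressive use
# would require left-only (causal) padding.
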
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block that adaptively recalibrates channel-wise features.
    """

    def __init__(self, hidden_size, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size // reduction, hidden_size, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Squeeze: average over the sequence dimension -> (batch, hidden).
        y = x.mean(dim=1)
        # Excite: bottleneck MLP produces per-channel gates in (0, 1).
        y = self.fc(y)
        y = y.unsqueeze(1)
        return x * y


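# Note on SEBlock: because the squeeze step averages over the sequence dimension,
# the resulting channel gates are shared by every position in a sequence.
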
class ModifiedLlamaDecoderLayer(nn.Module):
    """
    Modified Llama Decoder Layer with AdaptiveRMSNorm, TokenMixing, and SEBlock.
    """

    def __init__(self, original_layer, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.adaptive_dim = config.hidden_size  # adaptive gain is conditioned on pooled hidden states

        # Reuse the pretrained attention and MLP sub-modules unchanged.
        self.self_attn = original_layer.self_attn
        self.mlp = original_layer.mlp

        # Replace the stock RMSNorm layers with adaptive variants.
        self.input_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)

        # New sequence-mixing and channel-recalibration modules.
        self.token_mixing = TokenMixing(self.hidden_size)
        self.se_block = SEBlock(self.hidden_size, reduction=16)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        # Pool the incoming hidden states into a per-sample summary vector that
        # conditions both adaptive norm layers.
        adapt_input = hidden_states.mean(dim=1)

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states, adapt_input)

        # Self-attention (pretrained weights, reused as-is).
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        # LlamaAttention is assumed to return (attn_output, attn_weights, present_key_value).
        attn_output = attn_outputs[0]
        attn_weights = attn_outputs[1] if output_attentions else None
        present_key_value = attn_outputs[2] if use_cache else None

        hidden_states = residual + attn_output

        # Token mixing across the sequence, applied as a residual branch.
        token_mixed = self.token_mixing(hidden_states)
        hidden_states = hidden_states + token_mixed

        hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)

        # MLP with squeeze-and-excitation recalibration, applied as a residual branch.
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.se_block(hidden_states)
        hidden_states = residual + hidden_states

        # Match the output ordering of the stock LlamaDecoderLayer:
        # (hidden_states, optional attn_weights, optional present_key_value).
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        return outputs


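# Load the base config and pretrained weights. The attention and MLP weights are
# reused as-is below; the AdaptiveRMSNorm, TokenMixing, and SEBlock parameters are
# new and randomly initialized, so the modified model will need fine-tuning before
# its outputs are meaningful.
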
config = AutoConfig.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')

pretrained_model = LlamaForCausalLM.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')

# Swap every decoder layer for the modified version.
for i in range(config.num_hidden_layers):
    original_layer = pretrained_model.model.layers[i]
    pretrained_model.model.layers[i] = ModifiedLlamaDecoderLayer(original_layer, config)

modified_model = pretrained_model

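# Optional sanity check (illustrative): confirm that the new modules added
# parameters on top of the pretrained ones.
total_params = sum(p.numel() for p in modified_model.parameters())
print(f"Total parameters after modification: {total_params:,}")
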
output_dir = "./saved_model"
# Note: save_pretrained stores the modified state dict, but the saved config still
# describes a stock Llama architecture, so reloading with LlamaForCausalLM alone
# will not rebuild the custom layers.
modified_model.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World', legacy=False)
tokenizer.save_pretrained(output_dir)

print(f"Model and tokenizer saved to {output_dir}")

input_text = "Hello, how are you?" |
|
input_ids = tokenizer.encode(input_text, return_tensors='pt') |
|
|
|
|
|
outputs = modified_model(input_ids=input_ids) |
|
logits = outputs.logits |
|
|
|
print("Logits shape:", logits.shape) |
|
|
|
|
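
# Optional, illustrative smoke test: greedy generation with the modified model.
# Output quality will be poor until the newly added (randomly initialized)
# modules are fine-tuned.
with torch.no_grad():
    generated = modified_model.generate(
        input_ids,
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(generated[0], skip_special_tokens=True))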