import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM


# Custom Modules
class AdaptiveRMSNorm(nn.Module):
    """
    Adaptive RMSNorm layer where the scaling parameter adapts based on the input.
    """

    def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        # Standard RMSNorm weight parameter
        self.weight = nn.Parameter(torch.ones(normalized_shape))

        # Adaptive scaling parameter, predicted from a per-sequence summary vector
        self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)

    def forward(self, x, adapt_input):
        # Compute adaptive scaling factor gamma
        gamma = self.fc_gamma(adapt_input).unsqueeze(1)  # Shape: [batch_size, 1, hidden_size]

        # RMSNorm: normalize by the root mean square of the features (as in LlamaRMSNorm)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)

        # Apply the learned weight and the adaptive scaling
        return self.weight * norm_x * gamma


class TokenMixing(nn.Module):
    """
    Token Mixing layer that performs a depthwise convolution across the sequence dimension.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.token_mixing = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            padding=1,
            groups=hidden_size,  # Depthwise convolution
        )

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_size]
        x = x.transpose(1, 2)  # Shape: [batch_size, hidden_size, seq_length]
        x = self.token_mixing(x)
        x = x.transpose(1, 2)  # Shape back to [batch_size, seq_length, hidden_size]
        return x


class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block that adaptively recalibrates channel-wise features.
    """

    def __init__(self, hidden_size, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size // reduction, hidden_size, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_size]
        y = x.mean(dim=1)   # Global average pooling over the sequence length
        y = self.fc(y)      # Squeeze and excitation
        y = y.unsqueeze(1)  # Shape: [batch_size, 1, hidden_size]
        return x * y        # Scale the original input


# Modified Decoder Layer
class ModifiedLlamaDecoderLayer(nn.Module):
    """
    Modified Llama decoder layer with AdaptiveRMSNorm, TokenMixing, and SEBlock.
    """

    def __init__(self, original_layer, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.adaptive_dim = config.hidden_size  # Using hidden_size for adapt_input

        # Reuse the original attention and MLP layers (pre-trained weights are kept)
        self.self_attn = original_layer.self_attn
        self.mlp = original_layer.mlp

        # Replace the RMSNorm layers with AdaptiveRMSNorm
        self.input_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)

        # Add the Token Mixing layer
        self.token_mixing = TokenMixing(self.hidden_size)

        # Add the SE block
        self.se_block = SEBlock(self.hidden_size, reduction=16)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,  # Capture additional arguments (e.g. position_embeddings, cache_position)
    ):
        # Compute the adaptation input: a per-sequence summary of the hidden states
        adapt_input = hidden_states.mean(dim=1)  # Shape: [batch_size, hidden_size]

        residual = hidden_states

        # Input layer normalization with adaptive RMSNorm
        hidden_states = self.input_layernorm(hidden_states, adapt_input)

        # Self-attention
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,  # Pass additional arguments to self_attn
        )
        attn_output = attn_outputs[0]

        # Depending on the transformers version, self_attn returns either
        # (attn_output, attn_weights) or (attn_output, attn_weights, present_key_value)
        attn_weights = attn_outputs[1] if output_attentions else None
        present_key_value = attn_outputs[2] if (use_cache and len(attn_outputs) > 2) else None

        hidden_states = residual + attn_output

        # Token mixing (residual connection around the depthwise convolution)
        token_mixed = self.token_mixing(hidden_states)
        hidden_states = hidden_states + token_mixed

        # Post-attention layer normalization with adaptive RMSNorm
        hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)

        # MLP
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)

        # SE block
        hidden_states = self.se_block(hidden_states)

        hidden_states = residual + hidden_states

        # Follow the Hugging Face convention: (hidden_states, attn_weights?, present_key_value?)
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        return outputs


# Load the configuration and weights of the pre-trained model
config = AutoConfig.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')
pretrained_model = LlamaForCausalLM.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')

# Replace the decoder layers with modified layers
for i in range(config.num_hidden_layers):
    original_layer = pretrained_model.model.layers[i]
    pretrained_model.model.layers[i] = ModifiedLlamaDecoderLayer(original_layer, config)

# The modified model is now ready.
# Note: the new modules (AdaptiveRMSNorm, TokenMixing, SEBlock) are randomly initialized,
# so the model should be fine-tuned before being used for anything beyond smoke tests.
modified_model = pretrained_model

# Save the model and tokenizer
output_dir = "./saved_model"
modified_model.save_pretrained(output_dir)

tokenizer = AutoTokenizer.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World', legacy=False)
tokenizer.save_pretrained(output_dir)

print(f"Model and tokenizer saved to {output_dir}")

# Example Usage
input_text = "Hello, how are you?"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Forward pass (no gradients needed for this sanity check)
with torch.no_grad():
    outputs = modified_model(input_ids=input_ids)
logits = outputs.logits

print("Logits shape:", logits.shape)  # Should be [batch_size, seq_length, vocab_size]