import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig from transformers.modeling_outputs import BaseModelOutputWithPast # Custom Modules class AdaptiveRMSNorm(nn.Module): """ Adaptive RMSNorm layer where the scaling parameter adapts based on input. """ def __init__(self, normalized_shape, adaptive_dim, eps=1e-6): super(AdaptiveRMSNorm, self).__init__() self.normalized_shape = normalized_shape self.eps = eps # Standard RMSNorm weight parameter self.weight = nn.Parameter(torch.ones(normalized_shape)) # Adaptive scaling parameter self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape) def forward(self, x, adapt_input): # Compute adaptive scaling factor gamma gamma = self.fc_gamma(adapt_input).unsqueeze(1) # Shape: [batch_size, 1, hidden_size] # Compute RMSNorm norm_x = x / x.norm(dim=-1, keepdim=True).clamp(min=self.eps) # Apply adaptive scaling return self.weight * norm_x * gamma class TokenMixing(nn.Module): """ Token Mixing layer that performs depthwise convolution across the sequence dimension. """ def __init__(self, hidden_size): super(TokenMixing, self).__init__() self.token_mixing = nn.Conv1d( in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1, groups=hidden_size # Depthwise convolution ) def forward(self, x): # x shape: [batch_size, seq_length, hidden_size] x = x.transpose(1, 2) # Shape: [batch_size, hidden_size, seq_length] x = self.token_mixing(x) x = x.transpose(1, 2) # Shape back to [batch_size, seq_length, hidden_size] return x class SEBlock(nn.Module): """ Squeeze-and-Excitation block that adaptively recalibrates channel-wise features. """ def __init__(self, hidden_size, reduction=16): super(SEBlock, self).__init__() self.fc = nn.Sequential( nn.Linear(hidden_size, hidden_size // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(hidden_size // reduction, hidden_size, bias=False), nn.Sigmoid() ) def forward(self, x): # x shape: [batch_size, seq_length, hidden_size] y = x.mean(dim=1) # Global average pooling over sequence length y = self.fc(y) # Squeeze and Excitation y = y.unsqueeze(1) # Shape: [batch_size, 1, hidden_size] return x * y # Scale the original input class DifferentialSelfAttention(nn.Module): """ Self-Attention layer with Differential Attention Mechanism. Includes support for past_key_value and attention_mask handling. """ def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size # e.g., 1024 self.num_heads = config.num_attention_heads # e.g., 4 self.head_dim = self.hidden_size // self.num_heads # e.g., 256 assert self.head_dim * self.num_heads == self.hidden_size, \ "hidden_size must be divisible by num_attention_heads" self.scaling = self.head_dim ** -0.5 # Linear layers for Q, K, V projections # Adjust k_proj and v_proj to match the pre-trained model's dimensions self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) # [1024, 1024] self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // 8) # [1024, 256] self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // 8) # [1024, 256] self.o_proj = nn.Linear(self.hidden_size, self.hidden_size) # [1024, 1024] # Learnable parameters for lambda computation self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1) self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1) self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1) self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1) self.lambda_init = nn.Parameter(torch.tensor(0.5)) # Initial value as per the paper # Layer normalization self.sub_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, output_attentions=False, **kwargs, ): batch_size, seq_length, _ = hidden_states.size() # Linear projections query_states = self.q_proj(hidden_states) * self.scaling # Shape: [batch_size, seq_length, hidden_size] key_states = self.k_proj(hidden_states) # Shape: [batch_size, seq_length, hidden_size // 4] value_states = self.v_proj(hidden_states) # Shape: [batch_size, seq_length, hidden_size // 4] # Reshape and split into multiple heads # Query states have shape: [batch_size, num_heads, seq_length, head_dim] query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) # Key and value states have shape: [batch_size, num_heads, seq_length, key_head_dim] key_head_dim = key_states.size(-1) // self.num_heads # Should be 256 // num_heads key_states = key_states.view(batch_size, seq_length, self.num_heads, key_head_dim).transpose(1, 2) value_states = value_states.view(batch_size, seq_length, self.num_heads, key_head_dim).transpose(1, 2) # Handle past key values for caching if past_key_value is not None: # past_key_value[0] and [1] have shape (batch_size, num_heads, seq_len_prev, key_head_dim) key_states = torch.cat([past_key_value[0], key_states], dim=2) # Concat on seq_length dimension value_states = torch.cat([past_key_value[1], value_states], dim=2) if use_cache: present_key_value = (key_states, value_states) else: present_key_value = None # Update sequence length after concatenation kv_seq_length = key_states.size(2) # Split Q and K into two groups for differential attention q1, q2 = torch.chunk(query_states, 2, dim=-1) # Each has shape: [batch_size, num_heads, seq_length, head_dim/2] k1, k2 = torch.chunk(key_states, 2, dim=-1) # Adjusted for key_states # Compute attention scores attn_scores1 = torch.matmul(q1, k1.transpose(-2, -1)) # [batch_size, num_heads, seq_length, kv_seq_length] attn_scores2 = torch.matmul(q2, k2.transpose(-2, -1)) # Apply attention mask if provided if attention_mask is not None: # attention_mask should be of shape [batch_size, 1, seq_length, kv_seq_length] if attention_mask.dim() == 2: attention_mask = attention_mask[:, None, None, :] # Expand to [batch_size, 1, 1, kv_seq_length] elif attention_mask.dim() == 3: attention_mask = attention_mask[:, None, :, :] attention_mask = attention_mask.to(dtype=attn_scores1.dtype) # Ensure dtype matches attn_scores1 += attention_mask attn_scores2 += attention_mask # Compute attention probabilities attn_probs1 = nn.functional.softmax(attn_scores1, dim=-1, dtype=torch.float32).to(attn_scores1.dtype) attn_probs2 = nn.functional.softmax(attn_scores2, dim=-1, dtype=torch.float32).to(attn_scores2.dtype) # Compute lambda as per the DIFF Transformer paper lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1)) lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2)) lambda_full = lambda_1 - lambda_2 + self.lambda_init # Compute differential attention attn_probs = attn_probs1 - lambda_full * attn_probs2 # Compute attention output attn_output = torch.matmul(attn_probs, value_states) # [batch_size, num_heads, seq_length, key_head_dim] # Reshape and project output attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size) attn_output = self.o_proj(attn_output) # Apply layer normalization attn_output = self.sub_layer_norm(attn_output) if output_attentions: # Return attention probabilities if required attn_probs_return = attn_probs else: attn_probs_return = None return attn_output, present_key_value, attn_probs_return # Modified Decoder Layer class ModifiedLlamaDecoderLayer(nn.Module): """ Modified Llama Decoder Layer incorporating DifferentialSelfAttention, AdaptiveRMSNorm, TokenMixing, and SEBlock. """ def __init__(self, original_layer, config): super().__init__() self.hidden_size = config.hidden_size self.adaptive_dim = config.hidden_size # Using hidden_size for adapt_input # Replace the self-attention layer with DifferentialSelfAttention self.self_attn = DifferentialSelfAttention(config) # Copy the original MLP layer self.mlp = original_layer.mlp # Replace RMSNorm layers with AdaptiveRMSNorm self.input_layernorm = AdaptiveRMSNorm( self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps ) self.post_attention_layernorm = AdaptiveRMSNorm( self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps ) # Add Token Mixing Layer self.token_mixing = TokenMixing(self.hidden_size) # Add SE Block self.se_block = SEBlock(self.hidden_size, reduction=16) def forward( self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, output_attentions=False, **kwargs, ): # Compute adaptation input for AdaptiveRMSNorm adapt_input = hidden_states.mean(dim=1) # Shape: [batch_size, hidden_size] residual = hidden_states # Input layer normalization with adaptive RMSNorm hidden_states = self.input_layernorm(hidden_states, adapt_input) # Self-attention with differential attention mechanism attn_output, present_key_value, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, **kwargs, ) hidden_states = residual + attn_output # Token Mixing token_mixed = self.token_mixing(hidden_states) hidden_states = hidden_states + token_mixed # Post-attention layer normalization with adaptive RMSNorm hidden_states = self.post_attention_layernorm(hidden_states, adapt_input) # MLP residual = hidden_states hidden_states = self.mlp(hidden_states) # SE Block hidden_states = self.se_block(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if use_cache: outputs += (present_key_value,) if output_attentions: outputs += (attn_weights,) return outputs # Modified Model class ModifiedLlamaModel(LlamaModel): def __init__(self, config): super().__init__(config) # Replace the decoder layers with modified layers self.layers = nn.ModuleList([ ModifiedLlamaDecoderLayer(layer, config) for layer in self.layers ]) def forward( self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, # Capture any additional keyword arguments ): # Ensure default values are set output_attentions = output_attentions if output_attentions is not None else self.config.use_cache output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Process inputs if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.") elif input_ids is not None: input_shape = input_ids.size() batch_size, seq_length = input_shape elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] batch_size, seq_length = input_shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") # Initialize past_key_values if not provided if past_key_values is None: past_key_values = [None] * len(self.layers) # Embed tokens if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds # Attention mask processing if attention_mask is not None: if attention_mask.dim() == 2: attention_mask = attention_mask[:, None, None, :] elif attention_mask.dim() == 3: attention_mask = attention_mask[:, None, :, :] attention_mask = attention_mask.to(dtype=hidden_states.dtype) attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min # Main loop over layers next_decoder_cache = [] if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for idx, (decoder_layer, layer_past) in enumerate(zip(self.layers, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # Forward pass through the layer layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=layer_past, use_cache=use_cache, output_attentions=output_attentions, **kwargs, # Pass any additional keyword arguments ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache.append(layer_outputs[1]) if output_attentions: all_attentions = all_attentions + (layer_outputs[-1],) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: outputs = (hidden_states,) if use_cache: outputs += (next_decoder_cache,) if output_hidden_states: outputs += (all_hidden_states,) if output_attentions: outputs += (all_attentions,) return outputs return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_decoder_cache if use_cache else None, hidden_states=all_hidden_states if output_hidden_states else None, attentions=all_attentions if output_attentions else None, ) # Load the pre-trained model # Load the configuration from the pre-trained model config = AutoConfig.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World') # Initialize the modified model modified_model = LlamaForCausalLM(config) modified_model.model = ModifiedLlamaModel(config) # Load the pre-trained weights pretrained_model = LlamaForCausalLM.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World') modified_model.load_state_dict(pretrained_model.state_dict(), strict=False) # Save the model and tokenizer output_dir = "./BSC-LT-salamandra-2b-instruct-saved_model" modified_model.save_pretrained(output_dir) tokenizer = AutoTokenizer.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World', legacy=False) tokenizer.save_pretrained(output_dir) print(f"Model and tokenizer saved to {output_dir}") # Example Usage import time def chat_with_model(prompt_text, stop_token, model, tokenizer): # Encode the prompt text device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(device) start_time = time.time() encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt").to(device) # Generate response output_sequences = model.generate( input_ids=encoded_prompt, max_new_tokens=512, temperature=0.2, repetition_penalty=1.2, top_k=30, top_p=0.9, do_sample=True, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, use_cache=True, # Ensure use_cache is True for generation ) # Decode the generated sequence generated_sequence = output_sequences[0].tolist() text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) num_tokens = output_sequences.shape[-1] response_text = text[len(prompt_text):].strip() end_time = time.time() total_time = end_time - start_time print(f"Total time: {total_time:.3f} seconds") tokens_per_second = num_tokens / total_time print(f"Tokens per second: {tokens_per_second:.3f}") return response_text # Example usage input_text = "Hello, how are you?" stop_token = tokenizer.eos_token_id # Assuming EOS token as the stop token response = chat_with_model(input_text, stop_token, modified_model, tokenizer) print("Model response:", response)