import tensorflow as tf
from tensorflow.keras import layers, activations, initializers


class MiniSunConfig:
    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate


@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embedding layers for tokens and positions
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings,
                                                   config.hidden_size)

        # Transformer decoder blocks
        self.decoder_blocks = [self._build_decoder_block()
                               for _ in range(config.num_hidden_layers)]

        # Final normalization and language-model head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size,
                                    kernel_initializer=initializers.he_normal())

    def _build_decoder_block(self):
        # Decoder block: multi-head self-attention followed by a feed-forward
        # network. Each sub-layer gets its own LayerNormalization; sharing a
        # single norm across both residual connections would tie their
        # scale/shift parameters together.
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),  # post-attention norm
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(self.config.hidden_size,
                         kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),  # post-FFN norm
            layers.Dropout(self.config.dropout_rate)
        ]

    def call(self, inputs, attention_mask=None, training=False):
        input_ids = inputs['input_ids']
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)

        # Token and position embeddings (position embeddings broadcast
        # across the batch dimension)
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Keras MultiHeadAttention expects a mask broadcastable to
        # [batch_size, query_len, key_len], so expand a [batch_size, seq_len]
        # padding mask to [batch_size, 1, seq_len] and let it broadcast over
        # the query axis.
        if attention_mask is not None:
            attention_mask = tf.expand_dims(attention_mask, axis=1)

        # Apply decoder blocks
        hidden_states = embeddings
        for mha, attn_norm, ffn1, ffn2, ffn_norm, dropout in self.decoder_blocks:
            # Causal self-attention: use_causal_mask prevents each position
            # from attending to later positions, as autoregressive
            # generation requires.
            attn_output = mha(hidden_states, hidden_states,
                              attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = attn_norm(attn_output + hidden_states)  # Add & Norm

            # Feed-forward sub-layer
            ffn_output = ffn1(hidden_states)
            ffn_output = ffn2(ffn_output)
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = ffn_norm(ffn_output + hidden_states)  # Add & Norm

        # Final layer normalization
        hidden_states = self.layer_norm(hidden_states)

        # LM head produces next-token logits of shape
        # [batch_size, seq_len, vocab_size]
        logits = self.lm_head(hidden_states)
        return logits

    def get_config(self):
        # Serialize the model configuration
        return {'config': self.config.__dict__}

    @classmethod
    def from_config(cls, config):
        # Rebuild the model from a serialized configuration
        return cls(MiniSunConfig(**config['config']))

    def train_step(self, data):
        # Standard supervised step: forward pass, loss, gradients, update
        inputs, labels = data
        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            loss = self.compiled_loss(labels, logits, regularization_losses=self.losses)

        # Compute gradients and update weights
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}


def create_model(config):
    model = MiniSunModel(config)

    # AdamW applies decoupled weight decay alongside the Adam update
    # (requires TF >= 2.11; older versions ship it as
    # tf.keras.optimizers.experimental.AdamW)
    optimizer = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate,
                                          weight_decay=config.weight_decay)

    # Token-level cross-entropy computed directly on the raw logits
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    return model


# Configuration
config = MiniSunConfig()

# Initialize model (He initialization is set on the Dense/attention kernels above)
model = create_model(config)
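
# --- Usage sketch (illustrative addition, not part of the original script) ---
# A minimal smoke test of the pipeline above. The batch size, sequence
# length, and random token ids/labels are placeholder values; in real
# language-model training the labels would be the input ids shifted left by
# one position, not independent random draws.
import numpy as np

batch_size, seq_len = 4, 16
dummy_input_ids = np.random.randint(0, config.vocab_size, size=(batch_size, seq_len))
dummy_labels = np.random.randint(0, config.vocab_size, size=(batch_size, seq_len))

# Forward pass sanity check: logits are [batch_size, seq_len, vocab_size]
logits = model({'input_ids': dummy_input_ids}, training=False)
print(logits.shape)  # (4, 16, 30522)

# One epoch over the dummy data exercises the custom train_step
dataset = tf.data.Dataset.from_tensor_slices(
    ({'input_ids': dummy_input_ids}, dummy_labels)).batch(2)
model.fit(dataset, epochs=1)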
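
# --- Greedy decoding sketch (illustrative addition) ---
# Shows how the LM head supports autoregressive generation under the causal
# mask: feed the sequence in, take the argmax of the last position's logits,
# append, repeat. 'greedy_generate', 'prompt_ids', and 'max_new_tokens' are
# hypothetical names, and an untrained model will emit arbitrary tokens.
# This recomputes the full forward pass each step; a practical version would
# cache keys/values.
def greedy_generate(model, prompt_ids, max_new_tokens=20):
    ids = tf.constant([prompt_ids], dtype=tf.int32)  # shape [1, prompt_len]
    for _ in range(max_new_tokens):
        logits = model({'input_ids': ids}, training=False)
        next_id = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)
        ids = tf.concat([ids, next_id[:, None]], axis=-1)
    return ids[0].numpy().tolist()

print(greedy_generate(model, prompt_ids=[101, 2023, 2003], max_new_tokens=5))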