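"""MiniSun: a minimal decoder-only transformer language model in TensorFlow/Keras.

The model stacks learned token and position embeddings, a configurable number of
decoder blocks (multi-head self-attention plus an ELU feed-forward network, each
followed by a residual connection and layer normalization), a final layer norm,
and a dense LM head projecting hidden states to vocabulary logits.
"""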
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers

class MiniSunConfig:
    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512, 
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8, 
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate

@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        
        # Embedding layers for token and position
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
        
        # Transformer decoder blocks
        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]
        
        # Final normalization and head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

    def _build_decoder_block(self):
        # Decoder block: multi-head self-attention and a feed-forward network,
        # each followed by its own Add & Norm step (separate LayerNormalization
        # layers, so the two sub-layers do not share normalization weights).
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      # per-head dimension, so total attention width matches hidden_size
                                      key_dim=self.config.hidden_size // self.config.num_attention_heads,
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),  # post-attention norm
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),  # post-feed-forward norm
            layers.Dropout(self.config.dropout_rate)
        ]

    def call(self, inputs, attention_mask=None, training=False):
        input_ids = inputs['input_ids']
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)
        
        # Token and position embeddings
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)
        
        # Expand the padding mask to [batch_size, 1, 1, seq_len]; it broadcasts over heads and query positions
        if attention_mask is not None:
            attention_mask = tf.expand_dims(attention_mask, axis=1)
            attention_mask = tf.expand_dims(attention_mask, axis=1)
        
        # Apply decoder blocks
        hidden_states = embeddings
        for mha, attn_norm, ffn1, ffn2, ffn_norm, dropout in self.decoder_blocks:
            # Causal self-attention: each position attends only to itself and earlier
            # positions, combined with any padding mask supplied above
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = attn_norm(attn_output + hidden_states)  # Add & Norm
            
            # Feed-forward sub-layer
            ffn_output = ffn1(hidden_states)
            ffn_output = ffn2(ffn_output)
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = ffn_norm(ffn_output + hidden_states)  # Add & Norm
        
        # Final layer normalization
        hidden_states = self.layer_norm(hidden_states)
        
        # LM Head for token generation
        logits = self.lm_head(hidden_states)
        return logits

    def get_config(self):
        # Return the configuration of the model
        return {
            'config': self.config.__dict__
        }
    
    @classmethod
    def from_config(cls, config):
        # Create an instance of the model from the config
        return cls(MiniSunConfig(**config['config']))

    def train_step(self, data):
        # Unpack the data
        inputs, labels = data

        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            loss = self.compiled_loss(labels, logits, regularization_losses=self.losses)
        
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Apply the gradient update
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

def create_model(config):
    model = MiniSunModel(config)

    # Optimizer with weight decay
    optimizer = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate, weight_decay=config.weight_decay)

    # Compile with token-level cross-entropy computed from raw logits
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    return model

# Configuration
config = MiniSunConfig()

# Build and compile the model (dense and attention kernels use He initialization)
model = create_model(config)
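
# --- Usage sketch (illustrative only; the batch size, sequence length, and the
# shifted-label scheme below are assumptions, not part of the original setup) ---
# Runs a forward pass and a single fit() step on random token ids to confirm the
# shapes line up: logits come out as (batch, seq_len, vocab_size), and labels are
# the inputs shifted left by one position, the usual next-token objective.
import numpy as np

batch_size, seq_len = 2, 16
dummy_ids = np.random.randint(0, config.vocab_size, size=(batch_size, seq_len), dtype=np.int32)
dummy_mask = np.ones((batch_size, seq_len), dtype=np.int32)

logits = model({'input_ids': dummy_ids}, attention_mask=dummy_mask)
print(logits.shape)  # (2, 16, 30522)

# Next-token labels: shift inputs left by one and pad the final position
labels = np.concatenate([dummy_ids[:, 1:], np.zeros((batch_size, 1), dtype=np.int32)], axis=1)
model.fit({'input_ids': dummy_ids}, labels, epochs=1, verbose=0)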