import tensorflow as tf
from tensorflow.keras import layers, activations, initializers, regularizers
import numpy as np

# Define RMSNorm (root-mean-square layer normalization, without a learnable gain)
class RMSNorm(tf.keras.layers.Layer):
    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        # Normalize by the root mean square over the last axis; epsilon goes
        # inside the square root for numerical stability.
        mean_square = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
        return inputs * tf.math.rsqrt(mean_square + self.epsilon)

    def get_config(self):
        return {**super().get_config(), 'epsilon': self.epsilon}
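# Illustrative note (not used by the model): after RMSNorm, the mean squared
# activation along the last axis is approximately 1, e.g.
#   x = tf.random.normal([2, 8, 512])
#   tf.reduce_mean(tf.square(RMSNorm()(x)), axis=-1)  # ~= 1.0 everywhere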

class MiniSunConfig:
    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4, total_steps=2500,
                 warmup_ratio=0.5, restart_period=500, l1_reg=0.0, l2_reg=0.01):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.total_steps = total_steps
        self.warmup_ratio = warmup_ratio
        self.restart_period = restart_period
        self.l1_reg = l1_reg  # L1 regularization strength
        self.l2_reg = l2_reg  # L2 regularization strength

@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    def __init__(self, config):
        super(MiniSunModel, self).__init__()
        self.config = config

        # Embedding layers: token embeddings plus learned absolute positional embeddings
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
        
        # Initialize an empty list for decoder blocks
        self.decoder_blocks = []

        # Final normalization and head
        self.layer_norm = RMSNorm(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal(),
                                    kernel_regularizer=regularizers.l2(config.l2_reg))

        # Dropout reserved for stochastic depth (layer drop); not applied in call() yet
        self.layer_dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def build(self, input_shape):
        # Create transformer decoder blocks based on the model configuration
        self.decoder_blocks = [self._build_decoder_block() for _ in range(self.config.num_hidden_layers)]
        super(MiniSunModel, self).build(input_shape)

    def _build_decoder_block(self):
        # Decoder block: multi-head self-attention followed by a two-layer
        # feed-forward network, with RMSNorm and L1/L2 regularization.
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      key_dim=self.config.hidden_size // self.config.num_attention_heads,  # per-head dimension
                                      kernel_initializer=initializers.he_normal(),
                                      kernel_regularizer=regularizers.l2(self.config.l2_reg)),
            RMSNorm(epsilon=1e-6),  # RMSNorm instead of LayerNormalization
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal(),
                         kernel_regularizer=regularizers.l1_l2(self.config.l1_reg, self.config.l2_reg)),
            layers.Dropout(self.config.dropout_rate)
        ]

    def call(self, inputs, attention_mask=None, training=False):
        input_ids = inputs['input_ids']
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)

        # Token embeddings plus learned absolute position embeddings
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Expand the padding mask to [batch_size, 1, seq_len] so it broadcasts
        # over the query dimension of the attention scores.
        if attention_mask is not None:
            attention_mask = tf.cast(attention_mask[:, tf.newaxis, :], dtype=tf.bool)

        # Apply the decoder blocks (post-norm residual connections)
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            # Causal self-attention: use_causal_mask (TF 2.10+) blocks attention to future tokens
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & RMSNorm

            # Feed-forward sub-layer
            ffn_output = ffn1(hidden_states)
            ffn_output = ffn2(ffn_output)
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & RMSNorm

        # Final normalization
        hidden_states = self.layer_norm(hidden_states)

        # LM head projects back to the vocabulary
        logits = self.lm_head(hidden_states)

        # Softmax probabilities (e.g. for confidence inspection at inference time)
        softmax_output = tf.nn.softmax(logits, axis=-1)

        return logits, softmax_output

    def get_config(self):
        return {'config': self.config.__dict__}

    @classmethod
    def from_config(cls, config):
        return cls(MiniSunConfig(**config['config']))

    def compute_lm_loss(self, labels, logits):
        # Named compute_lm_loss to avoid clashing with keras.Model.compute_loss.
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        # sparse_categorical_crossentropy does not support label smoothing, so
        # one-hot encode the labels and use the categorical variant instead.
        one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size)
        loss = tf.keras.losses.categorical_crossentropy(
            one_hot_labels, logits, from_logits=True, label_smoothing=0.1)
        return tf.reduce_mean(loss)

    def train_step(self, data):
        inputs, labels = data
        attention_mask = inputs['attention_mask']

        with tf.GradientTape() as tape:
            logits, _ = self(inputs, attention_mask=attention_mask, training=True)
            loss = self.compute_lm_loss(labels, logits)
            # Scale the loss for loss-scaled training; assumes the
            # LossScaleOptimizer built in create_model.
            scaled_loss = self.optimizer.get_scaled_loss(loss)

        scaled_gradients = tape.gradient(scaled_loss, self.trainable_variables)
        gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

        # Gradient clipping for stability
        gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update the compiled metrics (e.g. accuracy) from labels and logits
        self.compiled_metrics.update_state(labels, logits)

        # Return loss and metrics
        results = {m.name: m.result() for m in self.metrics}
        results['loss'] = loss

        return results


def create_model(config):
    # Build and compile the model under the current distribution strategy.
    # The LossScaleOptimizer enables dynamic loss scaling; for full mixed
    # precision, also set tf.keras.mixed_precision.set_global_policy('mixed_float16').
    strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model = MiniSunModel(config)

        # AdamW with weight decay, wrapped for loss scaling
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.AdamW(learning_rate=config.learning_rate, weight_decay=config.weight_decay)
        )
        model.compile(optimizer=optimizer,
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
    return model

def cosine_annealing_with_warmup(step, config):
    """Linear warm-up followed by cosine annealing.

    Note: tf.keras.callbacks.LearningRateScheduler calls its schedule once per
    epoch, so `step` is the epoch index when used with that callback.
    """
    warmup_steps = int(config.total_steps * config.warmup_ratio)
    if step < warmup_steps:
        return config.learning_rate * (step / warmup_steps)
    cos_step = step - warmup_steps
    total_cos_steps = config.total_steps - warmup_steps
    return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))

def cosine_annealing_with_restarts(step, config):
    """Warm-up plus cosine annealing, restarted every `config.restart_period` steps.

    The warm-up length is derived from total_steps, so it should be kept smaller
    than restart_period or the cosine phase is never reached within a cycle.
    """
    warmup_steps = int(config.total_steps * config.warmup_ratio)
    effective_step = step % config.restart_period

    if effective_step < warmup_steps:
        return config.learning_rate * (effective_step / warmup_steps)
    cos_step = effective_step - warmup_steps
    total_cos_steps = config.restart_period - warmup_steps
    return 0.5 * config.learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))
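
# Optional sketch (an addition, not part of the original script): the Keras
# LearningRateScheduler callback only fires once per epoch, so a small custom
# callback like this one can apply the step-based schedules above per batch.
# It assumes the optimizer exposes a settable `learning_rate` variable, as the
# AdamW/LossScaleOptimizer pair built in create_model does.
class PerBatchLRScheduler(tf.keras.callbacks.Callback):
    def __init__(self, schedule_fn, config):
        super().__init__()
        self.schedule_fn = schedule_fn
        self.config = config
        self.global_step = 0

    def on_train_batch_begin(self, batch, logs=None):
        lr = float(self.schedule_fn(self.global_step, self.config))
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)
        self.global_step += 1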

# Configuration
config = MiniSunConfig(l1_reg=1e-5, l2_reg=3e-4)

# Initialize model with improvements
model = create_model(config)

# Create LearningRateScheduler callbacks
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_warmup(step, config))
lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config))
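
# Minimal usage sketch (illustrative only): trains on random token ids, just to
# show how the model, custom train_step and LR callbacks above fit together.
if __name__ == "__main__":
    num_samples, seq_len, batch_size = 8, 16, 2
    dummy_inputs = {
        'input_ids': tf.random.uniform((num_samples, seq_len), maxval=config.vocab_size, dtype=tf.int32),
        'attention_mask': tf.ones((num_samples, seq_len), dtype=tf.int32),
    }
    # Targets are random ids here; in practice they would be the next tokens.
    dummy_labels = tf.random.uniform((num_samples, seq_len), maxval=config.vocab_size, dtype=tf.int32)

    dataset = tf.data.Dataset.from_tensor_slices((dummy_inputs, dummy_labels)).batch(batch_size)

    # lr_scheduler adjusts the learning rate once per epoch; it assumes the
    # optimizer exposes a settable learning rate, as the one in create_model does.
    model.fit(dataset, epochs=1, callbacks=[lr_scheduler])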