finnstrom3693 committed
Commit
d3801be
1 Parent(s): f8a4b2b

Create modeling.py

Files changed (1)
  1. modeling.py +127 -0
modeling.py ADDED
@@ -0,0 +1,127 @@
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers


class MiniSunConfig:
    def __init__(self, vocab_size=30522, max_position_embeddings=1024, hidden_size=512,
                 num_attention_heads=8, intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate

@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    def __init__(self, config):
        super(MiniSunModel, self).__init__()
        self.config = config

        # Embedding layers for token and position
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)

        # Transformer decoder blocks
        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]

        # Final normalization and head
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

    def _build_decoder_block(self):
        # Decoder block consisting of multi-head attention and feed-forward layers
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads,
                                      key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
            layers.Dropout(self.config.dropout_rate)
        ]

    def call(self, inputs, attention_mask=None, training=False):
        input_ids = inputs['input_ids']
        position_ids = tf.range(start=0, limit=tf.shape(input_ids)[-1], delta=1)

        # Token and position embeddings
        embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # Expand the padding mask to [batch_size, 1, 1, seq_len] so it broadcasts
        # across attention heads and query positions
        if attention_mask is not None:
            attention_mask = tf.expand_dims(attention_mask, axis=1)
            attention_mask = tf.expand_dims(attention_mask, axis=1)

        # Apply decoder blocks
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            # Causal self-attention: use_causal_mask keeps each position from
            # attending to future tokens, as a decoder-only LM requires
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & Norm

            # Feed-forward layers
            ffn_output = ffn1(hidden_states)
            ffn_output = ffn2(ffn_output)
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & Norm

        # Final layer normalization
        hidden_states = self.layer_norm(hidden_states)

        # LM head for token generation
        logits = self.lm_head(hidden_states)
        return logits

    def get_config(self):
        # Return the configuration of the model
        return {
            'config': self.config.__dict__
        }

    @classmethod
    def from_config(cls, config):
        # Create an instance of the model from the config
        return cls(MiniSunConfig(**config['config']))

    def train_step(self, data):
        # Unpack the data
        inputs, labels = data

        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            loss = self.compiled_loss(labels, logits, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights via the compiled optimizer
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

def create_model(config):
    model = MiniSunModel(config)

    # AdamW: Adam with decoupled weight decay
    optimizer = tf.keras.optimizers.AdamW(learning_rate=config.learning_rate,
                                          weight_decay=config.weight_decay)

    # Compile with sparse categorical cross-entropy over the vocabulary logits
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    return model


# Configuration
config = MiniSunConfig()

# Initialize model (weights use the He initializers set on the layers above)
model = create_model(config)
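
As a quick sanity check, the sketch below (not part of the committed file) shows how the model might be exercised on a toy batch; the random input_ids and the left-shifted labels are placeholder assumptions standing in for a real tokenized dataset.

import numpy as np

# Hypothetical toy batch of token ids (random, purely for illustration)
batch_size, seq_len = 4, 128
input_ids = np.random.randint(0, config.vocab_size, size=(batch_size, seq_len))

# Next-token labels: the input shifted left by one, with the last position repeated
labels = np.concatenate([input_ids[:, 1:], input_ids[:, -1:]], axis=-1)

# Forward pass: logits have shape [batch_size, seq_len, vocab_size]
logits = model({'input_ids': input_ids}, training=False)
print(logits.shape)  # (4, 128, 30522)

# A single short fit call exercises the custom train_step above
model.fit({'input_ids': input_ids}, labels, epochs=1, batch_size=batch_size)

Note that train_step calls the model without an attention_mask, so padded batches would need the mask passed through a custom training path rather than plain fit.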