Update Llama3.py
Llama3.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c)
+# Copyright (c) NoteDance, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import tensorflow as tf
 from tensorflow.keras.layers import Embedding,Dense
@@ -25,10 +25,15 @@ class ModelArgs:
     max_seq_len: int = 2048


-class RMSNorm:
+class RMSNorm(tf.keras.layers.Layer):
     def __init__(self, dim: int, eps: float = 1e-6):
         self.eps = eps
-        self.weight =
+        self.weight = self.add_weight(
+            name='weight',
+            shape=(self.dim,),
+            initializer=tf.keras.initializers.Ones(),
+            trainable=True
+        )

     def _norm(self, x):
         return x * tf.math.rsqrt(tf.reduce_mean(tf.pow(x, 2), -1, keepdims=True) + self.eps)
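
Note: below is a minimal runnable sketch of the pattern this hunk adopts, i.e. RMSNorm as a tf.keras.layers.Layer whose scale is created with add_weight. The class name RMSNormSketch, the call method, and the explicit super().__init__() and stored dim (which shape=(self.dim,) relies on but the hunk does not show) are illustrative assumptions, not code from this file.

import tensorflow as tf

class RMSNormSketch(tf.keras.layers.Layer):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature scale, initialized to ones (as in the hunk above).
        self.weight = self.add_weight(
            name='weight',
            shape=(dim,),
            initializer=tf.keras.initializers.Ones(),
            trainable=True,
        )

    def call(self, x):
        # Normalize by the root mean square over the last axis, then rescale.
        return self.weight * x * tf.math.rsqrt(
            tf.reduce_mean(tf.pow(x, 2), -1, keepdims=True) + self.eps
        )

# Example: RMSNormSketch(dim=8)(tf.random.normal([2, 4, 8])) has shape (2, 4, 8).
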
@@ -89,7 +94,7 @@ def repeat_kv(x, n_rep: int):
     return tf.reshape(tf.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]), (bs, slen, n_kv_heads * n_rep, head_dim))


-class Attention:
+class Attention(tf.keras.layers.Layer):
     def __init__(self, args: ModelArgs):
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         model_parallel_size = 1
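
Note: the repeat_kv context line above packs the grouped-query expansion into one expression. A self-contained sketch of what it computes, with the shape unpacking assumed from context rather than shown in the hunk:

import tensorflow as tf

def repeat_kv_sketch(x, n_rep: int):
    # x: (batch, seq_len, n_kv_heads, head_dim); repeat each KV head n_rep times
    # so key/value heads line up with the query heads.
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return tf.reshape(
        tf.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]),
        (bs, slen, n_kv_heads * n_rep, head_dim),
    )

# Example: keys of shape (2, 5, 4, 8) repeated 2x become (2, 5, 8, 8).
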
@@ -115,22 +120,29 @@ class Attention:
             use_bias=False,
         )

-        self.cache_k =
-
+        self.cache_k = self.add_weight(
+            name='cache_k',
+            shape=(
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
-        )
-
-
-
+            ),
+            initializer=tf.keras.initializers.Zeros(),
+            trainable=False
+        )
+
+        self.cache_v = self.add_weight(
+            name='cache_v',
+            shape=(
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
-        )
-
+            ),
+            initializer=tf.keras.initializers.Zeros(),
+            trainable=False
+        )

     def __call__(
         self,
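
Note: cache_k and cache_v above become non-trainable layer weights initialized to zeros. A hedged sketch of how such caches are typically filled during decoding, since add_weight returns a tf.Variable and variables support sliced assignment; the helper write_cache, the names xk, bsz and start_pos, and the toy shapes are illustrative, not from this file.

import tensorflow as tf

# Toy cache with shape (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
cache_k = tf.Variable(tf.zeros((2, 16, 4, 8)), trainable=False)

def write_cache(cache, x, start_pos):
    # Copy the new key/value entries for positions [start_pos, start_pos + seqlen)
    # into the cache, then return everything cached so far for this batch.
    bsz, seqlen = x.shape[0], x.shape[1]
    cache[:bsz, start_pos:start_pos + seqlen].assign(x)
    return cache[:bsz, :start_pos + seqlen]

xk = tf.random.normal((2, 3, 4, 8))
keys = write_cache(cache_k, xk, start_pos=0)  # shape (2, 3, 4, 8)
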