NoteDance committed on
Commit 875bbcb
1 Parent(s): 65d3bdd

Update Llama3.py

Files changed (1)
  1. Llama3.py +24 -12
Llama3.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) NoteDance, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import tensorflow as tf
 from tensorflow.keras.layers import Embedding,Dense
@@ -25,10 +25,15 @@ class ModelArgs:
     max_seq_len: int = 2048
 
 
-class RMSNorm:
+class RMSNorm(tf.keras.layers.Layer):
     def __init__(self, dim: int, eps: float = 1e-6):
         self.eps = eps
-        self.weight = tf.Variable(tf.ones((dim)))
+        self.weight = self.add_weight(
+            name='weight',
+            shape=(self.dim,),
+            initializer=tf.keras.initializers.Ones(),
+            trainable=True
+        )
 
     def _norm(self, x):
         return x * tf.math.rsqrt(tf.reduce_mean(tf.pow(x, 2), -1, keepdims=True) + self.eps)
@@ -89,7 +94,7 @@ def repeat_kv(x, n_rep: int):
     return tf.reshape(tf.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]), (bs, slen, n_kv_heads * n_rep, head_dim))
 
 
-class Attention:
+class Attention(tf.keras.layers.Layer):
     def __init__(self, args: ModelArgs):
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         model_parallel_size = 1
@@ -115,22 +120,29 @@ class Attention:
             use_bias=False,
         )
 
-        self.cache_k = tf.Variable(tf.zeros(
-            (
+        self.cache_k = self.add_weight(
+            name='cache_k',
+            shape=(
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
-            )
-        ), trainable=False)
-        self.cache_v = tf.Variable(tf.zeros(
-            (
+            ),
+            initializer=tf.keras.initializers.Zeros(),
+            trainable=False
+        )
+
+        self.cache_v = self.add_weight(
+            name='cache_v',
+            shape=(
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
-            )
-        ), trainable=False)
+            ),
+            initializer=tf.keras.initializers.Zeros(),
+            trainable=False
+        )
 
     def __call__(
         self,
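Review note on the new RMSNorm: as committed, __init__ reads self.dim without ever assigning it, and the subclass never calls super().__init__(), which a Keras Layer requires before add_weight, so constructing the layer will raise. A minimal corrected sketch of the intended pattern (hypothetical, not part of this commit; the __call__ is assumed from the reference Llama 3 code):

import tensorflow as tf

class RMSNorm(tf.keras.layers.Layer):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()  # must run before add_weight on a Keras Layer
        self.eps = eps
        self.dim = dim      # the commit reads self.dim without setting it
        self.weight = self.add_weight(
            name='weight',
            shape=(self.dim,),
            initializer=tf.keras.initializers.Ones(),
            trainable=True,
        )

    def _norm(self, x):
        # Root-mean-square normalization over the last axis.
        return x * tf.math.rsqrt(tf.reduce_mean(tf.pow(x, 2), -1, keepdims=True) + self.eps)

    def __call__(self, x):
        return self._norm(x) * self.weight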
 
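For readers skimming the third hunk's context: repeat_kv is the grouped-query-attention helper that tiles each key/value head across its query heads, and only its return line appears in the diff. A self-contained sketch, with the surrounding function body assumed from the reference implementation:

import tensorflow as tf

def repeat_kv(x, n_rep: int):
    # Expand (bs, slen, n_kv_heads, head_dim) to
    # (bs, slen, n_kv_heads * n_rep, head_dim) so each query head
    # has a matching key/value head.
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return tf.reshape(
        tf.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]),
        (bs, slen, n_kv_heads * n_rep, head_dim),
    )

x = tf.random.normal((2, 5, 4, 16))  # 4 KV heads
print(repeat_kv(x, 2).shape)         # (2, 5, 8, 16): 8 query heads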
 
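On the cache change in the last hunk: switching from a bare tf.Variable to a non-trainable add_weight keeps the KV cache registered on the layer (so checkpointing tracks it) while excluding it from gradients and optimizer state. In tf.keras, add_weight returns a tf.Variable, so the usual sliced in-place update during incremental decoding still works; a toy sketch with illustrative sizes (start_pos and the shapes below are assumptions, not from the commit):

import tensorflow as tf

# Illustrative stand-ins for the ModelArgs fields used in the diff.
max_batch_size, max_seq_len, n_local_kv_heads, head_dim = 2, 8, 4, 16

# Equivalent of the committed cache: zero-initialized, non-trainable.
cache_k = tf.Variable(
    tf.zeros((max_batch_size, max_seq_len, n_local_kv_heads, head_dim)),
    trainable=False,
)

# One decoding step: write the new keys at start_pos, then read the prefix.
bsz, seqlen, start_pos = 1, 3, 2
xk = tf.random.normal((bsz, seqlen, n_local_kv_heads, head_dim))
cache_k[:bsz, start_pos:start_pos + seqlen].assign(xk)
keys = cache_k[:bsz, :start_pos + seqlen]  # shape (1, 5, 4, 16)

Because the cache is allocated up front at (max_batch_size, max_seq_len, ...), memory for the longest supported sequence is reserved per layer at construction time.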