NoteDance commited on
Commit
7beda09
1 Parent(s): 011cf8c

Update ViT.py

Browse files
Files changed (1) hide show
  1. ViT.py +12 -2
ViT.py CHANGED
@@ -101,8 +101,18 @@ class ViT(Model):
101
  self.to_patch_embedding.add(Dense(dim))
102
  self.to_patch_embedding.add(LayerNormalization())
103
 
104
- self.pos_embedding = tf.Variable(tf.random.normal((1, num_patches + 1, dim)))
105
- self.cls_token = tf.Variable(tf.random.normal(((1, 1, dim))))
 
 
 
 
 
 
 
 
 
 
106
  self.dropout = Dropout(emb_dropout)
107
 
108
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, drop_rate)
 
101
  self.to_patch_embedding.add(Dense(dim))
102
  self.to_patch_embedding.add(LayerNormalization())
103
 
104
+ self.pos_embedding = self.add_weight(
105
+ name='pos_embedding',
106
+ shape=(1, self.num_patches + 1, self.dim),
107
+ initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
108
+ trainable=True
109
+ )
110
+ self.cls_token = self.add_weight(
111
+ name='cls_token',
112
+ shape=(1, 1, self.dim),
113
+ initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
114
+ trainable=True
115
+ )
116
  self.dropout = Dropout(emb_dropout)
117
 
118
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, drop_rate)