Update ViT.py
Browse files
ViT.py
CHANGED
@@ -101,8 +101,18 @@ class ViT(Model):
|
|
101 |
self.to_patch_embedding.add(Dense(dim))
|
102 |
self.to_patch_embedding.add(LayerNormalization())
|
103 |
|
104 |
-
self.pos_embedding =
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
self.dropout = Dropout(emb_dropout)
|
107 |
|
108 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, drop_rate)
|
|
|
101 |
self.to_patch_embedding.add(Dense(dim))
|
102 |
self.to_patch_embedding.add(LayerNormalization())
|
103 |
|
104 |
+
self.pos_embedding = self.add_weight(
|
105 |
+
name='pos_embedding',
|
106 |
+
shape=(1, self.num_patches + 1, self.dim),
|
107 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
|
108 |
+
trainable=True
|
109 |
+
)
|
110 |
+
self.cls_token = self.add_weight(
|
111 |
+
name='cls_token',
|
112 |
+
shape=(1, 1, self.dim),
|
113 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
|
114 |
+
trainable=True
|
115 |
+
)
|
116 |
self.dropout = Dropout(emb_dropout)
|
117 |
|
118 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, drop_rate)
|