Spaces:
Sleeping
Sleeping
embedding fix
Browse files- models/transformers.py +4 -1
models/transformers.py
CHANGED
@@ -195,7 +195,10 @@ class Transformer_decoder(nn.Module):
|
|
195 |
self.decoder_norm = nn.LayerNorm(512)
|
196 |
self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
|
197 |
self.decoder_norm_parallel = nn.LayerNorm(512)
|
198 |
-
|
|
|
|
|
|
|
199 |
self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
|
200 |
|
201 |
def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
|
|
|
195 |
self.decoder_norm = nn.LayerNorm(512)
|
196 |
self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
|
197 |
self.decoder_norm_parallel = nn.LayerNorm(512)
|
198 |
+
if opts.ref_nshot == 52:
|
199 |
+
self.cls_embedding = nn.Embedding(92,512)
|
200 |
+
else:
|
201 |
+
self.cls_embedding = nn.Embedding(52,512)
|
202 |
self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
|
203 |
|
204 |
def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
|