microhum committed on
Commit 3970cec · 1 Parent(s): 448a707

embedding fix

Files changed (1)
  1. models/transformers.py +4 -1
models/transformers.py CHANGED
@@ -195,7 +195,10 @@ class Transformer_decoder(nn.Module):
         self.decoder_norm = nn.LayerNorm(512)
         self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
         self.decoder_norm_parallel = nn.LayerNorm(512)
-        self.cls_embedding = nn.Embedding(52,512)
+        if opts.ref_nshot == 52:
+            self.cls_embedding = nn.Embedding(92,512)
+        else:
+            self.cls_embedding = nn.Embedding(52,512)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))
 
     def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
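For context: the change sizes the class-embedding table from the ref_nshot option. With a 52-shot reference set the decoder now allocates 92 class embeddings instead of 52, which is presumably the bug behind the "embedding fix" message, since a 52-entry nn.Embedding raises an IndexError for any class id >= 52. Below is a minimal runnable sketch of the patched constructor logic; opts.ref_nshot, cls_embedding, and cls_token come from the diff, while the stripped-down module and the SimpleNamespace opts are hypothetical stand-ins for illustration only.

from types import SimpleNamespace

import torch
import torch.nn as nn

class ClsEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for the embedding setup in Transformer_decoder.__init__."""
    def __init__(self, opts):
        super().__init__()
        # After the fix: 92 class embeddings when ref_nshot == 52,
        # otherwise the original 52-entry table.
        if opts.ref_nshot == 52:
            self.cls_embedding = nn.Embedding(92, 512)
        else:
            self.cls_embedding = nn.Embedding(52, 512)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))

print(ClsEmbeddingSketch(SimpleNamespace(ref_nshot=52)).cls_embedding.num_embeddings)  # 92
print(ClsEmbeddingSketch(SimpleNamespace(ref_nshot=1)).cls_embedding.num_embeddings)   # 52 (any other value)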