qilowoq committed on
Commit
165de6b
Parent: ea06087

Upload AbLang

Files changed (1)
  1. model.py +16 -6
model.py CHANGED
@@ -35,13 +35,23 @@ class AbLang(PreTrainedModel):
         self.AbEmbeddings = AbEmbeddings(config)
         self.EncoderBlocks = EncoderBlocks(config)
 
-    def forward(self, inputs):
-        src = self.AbEmbeddings(inputs['input_ids'])
-        outputs = self.EncoderBlocks(src, attention_mask=1-inputs['attention_mask'], output_attentions=False)
-        return apply_cls_embeddings(inputs, outputs)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        output_attentions=None,
+        output_hidden_states=None,
+    ):
+        src = self.AbEmbeddings(input_ids)
+        outputs = self.EncoderBlocks(src,
+                                     attention_mask=1-attention_mask,
+                                     output_attentions=output_attentions,
+                                     output_hidden_states=output_hidden_states)
+        return apply_cls_embeddings(attention_mask, outputs)
 
-def apply_cls_embeddings(inputs, outputs):
-    mask = inputs['attention_mask'].float()
+def apply_cls_embeddings(attention_mask, outputs):
+    mask = attention_mask.float()
     d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()}  # dict of sep tokens
     # make sep token invisible
     for i in d:
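What the change buys: the old forward took the tokenizer output as one dict and hard-coded output_attentions=False; the new signature exposes the standard Hugging Face keyword arguments, so the model can be called as model(**encoded) on a tokenizer's output and can return attentions and hidden states on request (token_type_ids is accepted but unused, presumably for interface compatibility). Note that EncoderBlocks is fed 1-attention_mask, i.e. an inverted mask in which 1 marks padding positions.

The dict comprehension kept as context above is terse, so here is a minimal, runnable sketch of what it computes (illustrative only; the toy mask below is not from the repo). torch.nonzero on a 2-D mask returns (row, column) index pairs in row-major order; because later pairs overwrite earlier ones, the comprehension maps each sequence to the column of its last attended token, i.e. its [SEP] position:

    import torch

    # Toy attention mask for a batch of two sequences (1 = real token, 0 = padding).
    attention_mask = torch.tensor([
        [1, 1, 1, 1, 0, 0],  # sequence 0: last attended index is 3
        [1, 1, 1, 1, 1, 1],  # sequence 1: last attended index is 5
    ])

    mask = attention_mask.float()
    # torch.nonzero yields (row, col) pairs row-major, so each key keeps its last column.
    d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()}
    # d now maps 0 -> 3 and 1 -> 5: the per-sequence [SEP] positions.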