jwieting commited on
Commit
77a3be3
1 Parent(s): bcab452

Update modeling_paragram_sp.py

Browse files
Files changed (1) hide show
  1. modeling_paragram_sp.py +2 -2
modeling_paragram_sp.py CHANGED
@@ -12,7 +12,7 @@ class ParagramSPModel(BertPreTrainedModel):
12
 
13
  def filter_input_ids(self, input_ids):
14
  output = []
15
- len = input_ids.shape[1]
16
  for i in range(input_ids.shape[0]):
17
  ids = input_ids[i]
18
  filtered_ids = []
@@ -21,7 +21,7 @@ class ParagramSPModel(BertPreTrainedModel):
21
  filtered_ids.append(j)
22
  if len(filtered_ids) == 0:
23
  filtered_ids = [0]
24
- output.append(filtered_ids + [config.pad_token_id] * (len - len(filtered_ids)))
25
  return torch.tensor(output)
26
 
27
  def forward(self, input_ids, attention_mask):
 
12
 
13
  def filter_input_ids(self, input_ids):
14
  output = []
15
+ length = input_ids.shape[1]
16
  for i in range(input_ids.shape[0]):
17
  ids = input_ids[i]
18
  filtered_ids = []
 
21
  filtered_ids.append(j)
22
  if len(filtered_ids) == 0:
23
  filtered_ids = [0]
24
+ output.append(filtered_ids + [config.pad_token_id] * (length - length(filtered_ids)))
25
  return torch.tensor(output)
26
 
27
  def forward(self, input_ids, attention_mask):