Update modeling_paragram_sp.py
Browse files- modeling_paragram_sp.py +2 -2
modeling_paragram_sp.py
CHANGED
@@ -12,7 +12,7 @@ class ParagramSPModel(BertPreTrainedModel):
|
|
12 |
|
13 |
def filter_input_ids(self, input_ids):
|
14 |
output = []
|
15 |
-
|
16 |
for i in range(input_ids.shape[0]):
|
17 |
ids = input_ids[i]
|
18 |
filtered_ids = []
|
@@ -21,7 +21,7 @@ class ParagramSPModel(BertPreTrainedModel):
|
|
21 |
filtered_ids.append(j)
|
22 |
if len(filtered_ids) == 0:
|
23 |
filtered_ids = [0]
|
24 |
-
output.append(filtered_ids + [config.pad_token_id] * (
|
25 |
return torch.tensor(output)
|
26 |
|
27 |
def forward(self, input_ids, attention_mask):
|
|
|
12 |
|
13 |
def filter_input_ids(self, input_ids):
|
14 |
output = []
|
15 |
+
length = input_ids.shape[1]
|
16 |
for i in range(input_ids.shape[0]):
|
17 |
ids = input_ids[i]
|
18 |
filtered_ids = []
|
|
|
21 |
filtered_ids.append(j)
|
22 |
if len(filtered_ids) == 0:
|
23 |
filtered_ids = [0]
|
24 |
+
output.append(filtered_ids + [config.pad_token_id] * (length - length(filtered_ids)))
|
25 |
return torch.tensor(output)
|
26 |
|
27 |
def forward(self, input_ids, attention_mask):
|