fix attention mask collation (#1603)
Browse files
src/axolotl/utils/collators.py
CHANGED
@@ -229,9 +229,8 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|
229 |
if feature == "attention_mask":
|
230 |
if self.multipack_attn:
|
231 |
arrays = [
|
232 |
-
(i + 1) * np.array(item
|
233 |
for i, item in enumerate(features[feature])
|
234 |
-
if feature in item
|
235 |
]
|
236 |
else:
|
237 |
arrays = [(1) * np.array(item) for item in features[feature]]
|
|
|
229 |
if feature == "attention_mask":
|
230 |
if self.multipack_attn:
|
231 |
arrays = [
|
232 |
+
(i + 1) * np.array(item)
|
233 |
for i, item in enumerate(features[feature])
|
|
|
234 |
]
|
235 |
else:
|
236 |
arrays = [(1) * np.array(item) for item in features[feature]]
|