kuaizhirui
commited on
Commit
•
cae273b
1
Parent(s):
f9d4d8d
fix batch infer
Browse files解决左pad之后 batch infer总是输出unk的问题 或者 和单条样本推理结果不一致的问题,本质上精度不一致的问题,由expanded_attn_mask和combined_attention_mask相加导致的,因此先换成torch.finfo(dtype).min的一半
- modeling_baichuan.py +6 -3
modeling_baichuan.py
CHANGED
@@ -358,9 +358,12 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
358 |
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
359 |
inputs_embeds.device
|
360 |
)
|
361 |
-
combined_attention_mask
|
362 |
-
|
363 |
-
|
|
|
|
|
|
|
364 |
|
365 |
return combined_attention_mask
|
366 |
|
|
|
358 |
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
359 |
inputs_embeds.device
|
360 |
)
|
361 |
+
if combined_attention_mask is None:
|
362 |
+
combined_attention_mask = expanded_attn_mask
|
363 |
+
else:
|
364 |
+
expanded_attn_mask = torch.where(expanded_attn_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, expanded_attn_mask)
|
365 |
+
combined_attention_mask = torch.where(combined_attention_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, combined_attention_mask)
|
366 |
+
combined_attention_mask = expanded_attn_mask + combined_attention_mask
|
367 |
|
368 |
return combined_attention_mask
|
369 |
|