Daniel Hesslow committed
Commit 854fab6 · Parent(s): 9562025
Update modelling_RW.py

Files changed: modelling_RW.py  +2 -2
modelling_RW.py CHANGED

@@ -21,7 +21,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
-from configuration_RW import RWConfig
+from .configuration_RW import RWConfig

 logger = logging.get_logger(__name__)

@@ -303,7 +303,7 @@ class Attention(nn.Module):
                 attention_scores = attention_scores.to(torch.float32)
             # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
             attention_probs = F.softmax(
-                (attention_scores + alibi) * self.inv_norm_factor + attention_mask_float,
+                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
                 dim=-1,
                 dtype=hidden_states.dtype,
             )
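The first change makes the import package-relative (from .configuration_RW import RWConfig), which is generally what is needed when the modelling file is loaded as part of the model repository's module (for example via trust_remote_code) rather than from the current working directory.

The second change reshapes alibi before adding it to attention_scores so the two tensors broadcast correctly. Below is a minimal sketch of the shapes involved, assuming alibi is built as (batch_size * num_heads, 1, kv_length) in the usual BLOOM/ALiBi layout while attention_scores has shape (batch_size, num_heads, q_length, kv_length); the sizes used are illustrative and not taken from the file:

import torch

# Assumed, illustrative shapes (not from the diff itself).
batch_size, num_heads, q_length, kv_length = 2, 4, 5, 5

attention_scores = torch.randn(batch_size, num_heads, q_length, kv_length)
alibi = torch.randn(batch_size * num_heads, 1, kv_length)

# Under these shapes, adding alibi directly raises a broadcasting error:
# (2, 4, 5, 5) does not broadcast with (8, 1, 5).
# The committed fix reshapes alibi so its head dimension lines up and it
# broadcasts over the query dimension instead:
biased = attention_scores + alibi.view(batch_size, num_heads, 1, -1)
print(biased.shape)  # torch.Size([2, 4, 5, 5])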