Praise2112 committed on
Upload modelling.py
modelling.py CHANGED (+5 -35)
@@ -5,8 +5,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers import ModernBertModel, ModernBertPreTrainedModel, ModernBertConfig
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
-from transformers.models.modernbert.modeling_modernbert import
-    ModernBertPredictionHead
+from transformers.models.modernbert.modeling_modernbert import ModernBertPredictionHead
 
 
 class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
@@ -26,10 +25,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    @torch.compile(dynamic=True)
-    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
-        return self.head(output)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -46,6 +41,7 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -60,20 +56,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         self._maybe_set_compile()
 
-        # Get sequence length and batch size if not provided
-        # if batch_size is None or seq_len is None:
-        #     batch_size, seq_len = input_ids.shape[:2]
-
-        # # Handle Flash Attention 2 unpadding
-        # if self.config._attn_implementation == "flash_attention_2":
-        #     if indices is None and cu_seqlens is None and max_seqlen is None:
-        #         if attention_mask is None:
-        #             attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-        #         with torch.no_grad():
-        #             input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
-        #                 inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids
-        #             )
-
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -90,24 +72,12 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         )
 
         sequence_output = outputs[0]
-        sequence_output = (
-            self.drop(self.compiled_head(sequence_output))
-            if self.config.reference_compile
-            else self.drop(self.head(sequence_output))
-        )
-        # sequence_output = self.drop(self.head(sequence_output))
+        sequence_output = self.drop(self.head(sequence_output))
 
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-
-        # # Handle Flash Attention 2 padding
-        # if self.config._attn_implementation == "flash_attention_2":
-        #     start_logits = _pad_modernbert_output(inputs=start_logits, indices=indices, batch=batch_size,
-        #                                           seqlen=seq_len)
-        #     end_logits = _pad_modernbert_output(inputs=end_logits, indices=indices, batch=batch_size,
-        #                                         seqlen=seq_len)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
 
         total_loss = None
         if start_positions is not None and end_positions is not None: