Praise2112 committed on
Commit 7f80132 (verified)
1 Parent(s): 10bdc30

Upload modelling.py

Files changed (1)
  1. modelling.py +5 -35
modelling.py CHANGED
@@ -5,8 +5,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers import ModernBertModel, ModernBertPreTrainedModel, ModernBertConfig
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
-from transformers.models.modernbert.modeling_modernbert import _pad_modernbert_output, _unpad_modernbert_input, \
-    ModernBertPredictionHead
+from transformers.models.modernbert.modeling_modernbert import ModernBertPredictionHead
 
 
 class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
@@ -26,10 +25,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    @torch.compile(dynamic=True)
-    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
-        return self.head(output)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -46,6 +41,7 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -60,20 +56,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         self._maybe_set_compile()
 
-        # Get sequence length and batch size if not provided
-        # if batch_size is None or seq_len is None:
-        #     batch_size, seq_len = input_ids.shape[:2]
-
-        # # Handle Flash Attention 2 unpadding
-        # if self.config._attn_implementation == "flash_attention_2":
-        #     if indices is None and cu_seqlens is None and max_seqlen is None:
-        #         if attention_mask is None:
-        #             attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-        #         with torch.no_grad():
-        #             input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
-        #                 inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids
-        #             )
-
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -90,24 +72,12 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         )
 
         sequence_output = outputs[0]
-        sequence_output = (
-            self.drop(self.compiled_head(sequence_output))
-            if self.config.reference_compile
-            else self.drop(self.head(sequence_output))
-        )
-        # sequence_output = self.drop(self.head(sequence_output))
+        sequence_output = self.drop(self.head(sequence_output))
 
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-
-        # # Handle Flash Attention 2 padding
-        # if self.config._attn_implementation == "flash_attention_2":
-        #     start_logits = _pad_modernbert_output(inputs=start_logits, indices=indices, batch=batch_size,
-        #                                           seqlen=seq_len)
-        #     end_logits = _pad_modernbert_output(inputs=end_logits, indices=indices, batch=batch_size,
-        #                                         seqlen=seq_len)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
 
         total_loss = None
         if start_positions is not None and end_positions is not None:
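
For reference, a minimal usage sketch of the class this file defines. This is an illustration only, not part of the commit: it assumes modelling.py is importable from the working directory, and the checkpoint id "your-org/modernbert-qa" and the tokenizer id are placeholders for a checkpoint actually fine-tuned with this class.

# Minimal usage sketch (assumptions: placeholder checkpoint and tokenizer ids,
# modelling.py downloaded locally and importable).
import torch
from transformers import AutoTokenizer

from modelling import ModernBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")  # illustrative tokenizer
model = ModernBertForQuestionAnswering.from_pretrained("your-org/modernbert-qa")  # placeholder checkpoint
model.eval()

question = "What head does the model add?"
context = "modelling.py defines an extractive question answering head on top of ModernBERT."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # returns QuestionAnsweringModelOutput with start/end logits

# Decode the most likely answer span from the start/end logits.
start = int(outputs.start_logits.argmax(dim=-1))
end = int(outputs.end_logits.argmax(dim=-1))
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)

Note that squeeze(-1).contiguous() on the start/end logits follows the pattern used by the other *ForQuestionAnswering heads in transformers, so the logits fed to CrossEntropyLoss are memory-contiguous.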