Ar4ikov commited on
Commit
becd3e3
·
1 Parent(s): 88ab3bb

Update audio_text_multimodal.py

Browse files
Files changed (1) hide show
  1. audio_text_multimodal.py +15 -0
audio_text_multimodal.py CHANGED
@@ -12,6 +12,14 @@ from transformers import (
12
  Wav2Vec2Model
13
  )
14
 
 
 
 
 
 
 
 
 
15
 
16
  class MultiModalConfig(PretrainedConfig):
17
  """Base class for multimodal configs"""
@@ -191,6 +199,8 @@ class WavLMBertForSequenceClassification(AudioTextFusionModelForSequenceClassifi
191
  """
192
  def __init__(self, config):
193
  super().__init__(config)
 
 
194
  self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
195
  self.text_config = BertConfig.from_dict(self.config.BertModel)
196
  self.audio_model = WavLMModel(self.audio_config)
@@ -215,6 +225,11 @@ class WavLMBertForSequenceClassification(AudioTextFusionModelForSequenceClassifi
215
  (cls_dim * 2) // self.config.kernel_size, self.config.num_labels
216
  )
217
  self.init_weights()
 
 
 
 
 
218
 
219
  def forward(
220
  self,
 
12
  Wav2Vec2Model
13
  )
14
 
15
+ from transformers.models.wavlm.modeling_wavlm import (
16
+ WavLMEncoder,
17
+ WavLMEncoderStableLayerNorm,
18
+ WavLMFeatureEncoder
19
+ )
20
+
21
+ from transformers.models.bert.modeling_bert import BertEncoder
22
+
23
 
24
  class MultiModalConfig(PretrainedConfig):
25
  """Base class for multimodal configs"""
 
199
  """
200
  def __init__(self, config):
201
  super().__init__(config)
202
+ self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True)
203
+
204
  self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
205
  self.text_config = BertConfig.from_dict(self.config.BertModel)
206
  self.audio_model = WavLMModel(self.audio_config)
 
225
  (cls_dim * 2) // self.config.kernel_size, self.config.num_labels
226
  )
227
  self.init_weights()
228
+
229
+ @staticmethod
230
+ def _set_gradient_checkpointing(module, value=False):
231
+ if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder, BertEncoder)):
232
+ module.gradient_checkpointing = value
233
 
234
  def forward(
235
  self,