Update audio_text_multimodal.py
Browse files
- audio_text_multimodal.py +15 -0
audio_text_multimodal.py
CHANGED
@@ -12,6 +12,14 @@ from transformers import (
|
|
12 |
Wav2Vec2Model
|
13 |
)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
class MultiModalConfig(PretrainedConfig):
|
17 |
"""Base class for multimodal configs"""
|
@@ -191,6 +199,8 @@ class WavLMBertForSequenceClassification(AudioTextFusionModelForSequenceClassifi
|
|
191 |
"""
|
192 |
def __init__(self, config):
|
193 |
super().__init__(config)
|
|
|
|
|
194 |
self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
|
195 |
self.text_config = BertConfig.from_dict(self.config.BertModel)
|
196 |
self.audio_model = WavLMModel(self.audio_config)
|
@@ -215,6 +225,11 @@ class WavLMBertForSequenceClassification(AudioTextFusionModelForSequenceClassifi
|
|
215 |
(cls_dim * 2) // self.config.kernel_size, self.config.num_labels
|
216 |
)
|
217 |
self.init_weights()
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
def forward(
|
220 |
self,
|
|
|
12 |
Wav2Vec2Model
|
13 |
)
|
14 |
|
15 |
+
from transformers.models.wavlm.modeling_wavlm import (
|
16 |
+
WavLMEncoder,
|
17 |
+
WavLMEncoderStableLayerNorm,
|
18 |
+
WavLMFeatureEncoder
|
19 |
+
)
|
20 |
+
|
21 |
+
from transformers.models.bert.modeling_bert import BertEncoder
|
22 |
+
|
23 |
|
24 |
class MultiModalConfig(PretrainedConfig):
|
25 |
"""Base class for multimodal configs"""
|
|
|
199 |
"""
|
200 |
def __init__(self, config):
|
201 |
super().__init__(config)
|
202 |
+
self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True)
|
203 |
+
|
204 |
self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
|
205 |
self.text_config = BertConfig.from_dict(self.config.BertModel)
|
206 |
self.audio_model = WavLMModel(self.audio_config)
|
|
|
225 |
(cls_dim * 2) // self.config.kernel_size, self.config.num_labels
|
226 |
)
|
227 |
self.init_weights()
|
228 |
+
|
229 |
+
@staticmethod
def _set_gradient_checkpointing(module, value=False):
    """Toggle the ``gradient_checkpointing`` flag on a supported sub-module.

    The flag is written only when ``module`` is an instance of one of the
    WavLM encoder / feature-encoder classes or the BERT encoder class; any
    other module is left untouched.

    Args:
        module: The sub-module to inspect.
        value: The value to store in ``module.gradient_checkpointing``.

    NOTE(review): presumably invoked once per sub-module by the transformers
    gradient-checkpointing enable/disable machinery -- confirm against the
    installed transformers version.
    """
    checkpointable_types = (
        WavLMEncoder,
        WavLMEncoderStableLayerNorm,
        WavLMFeatureEncoder,
        BertEncoder,
    )
    # Guard clause: modules outside the supported set are ignored.
    if not isinstance(module, checkpointable_types):
        return
    module.gradient_checkpointing = value
|
233 |
|
234 |
def forward(
|
235 |
self,
|