Update wav2vec2speechclassification.py
Browse files
wav2vec2speechclassification.py
CHANGED
@@ -26,6 +26,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
|
26 |
class Wav2Vec2ClassificationHead(nn.Module):
|
27 |
"""Head for wav2vec classification task."""
|
28 |
config_class = Wav2Vec2Config
|
|
|
29 |
|
30 |
def __init__(self, config):
|
31 |
super().__init__()
|
@@ -45,6 +46,7 @@ class Wav2Vec2ClassificationHead(nn.Module):
|
|
45 |
|
46 |
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
|
47 |
config_class = Wav2Vec2Config
|
|
|
48 |
|
49 |
def __init__(self, config):
|
50 |
super().__init__(config)
|
|
|
26 |
class Wav2Vec2ClassificationHead(nn.Module):
|
27 |
"""Head for wav2vec classification task."""
|
28 |
config_class = Wav2Vec2Config
|
29 |
+
model_type = "wav2vec2"
|
30 |
|
31 |
def __init__(self, config):
|
32 |
super().__init__()
|
|
|
46 |
|
47 |
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
|
48 |
config_class = Wav2Vec2Config
|
49 |
+
model_type = "wav2vec2"
|
50 |
|
51 |
def __init__(self, config):
|
52 |
super().__init__(config)
|