SeaBenSea commited on
Commit
5c5f218
1 Parent(s): 6fa69eb

Upload 13 files

HuBERT-SER/.gitignore ADDED
@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.idea
*.tmp.py
.DS_Store
.DS_store
HuBERT-SER/LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
HuBERT-SER/README.md ADDED
@@ -0,0 +1,67 @@
# HuBERT-SER

This repository contains the models and scripts you need to bring HuBERT into your speech emotion recognition research. The sections below show how to fine-tune the model on your own dataset and how to run inference with a fine-tuned checkpoint.

### Training - CMD

```pwsh
python "HuBERT-SER\run_wav2vec_clf.py" --pooling_mode="mean" --model_name_or_path="facebook/hubert-large-ll60k" --model_mode="hubert" --output_dir="path\to\output" --cache_dir="path\to\cache" --train_file="dataset\train.csv" --validation_file="dataset\eval.csv" --test_file="dataset\test.csv" --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --gradient_accumulation_steps=2 --learning_rate=1e-4 --num_train_epochs=9.0 --evaluation_strategy='steps' --save_steps=100 --eval_steps=100 --logging_steps=100 --save_total_limit=2 --do_eval --do_train --freeze_feature_extractor
```
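The `--train_file`, `--validation_file`, and `--test_file` arguments point to plain CSV (or JSON) files. By default `run_wav2vec_clf.py` reads the audio path from a `path` column and the label from an `emotion` column, and treats the CSV as tab-delimited; all three defaults can be changed with `--input_column`, `--target_column`, and `--delimiter`. As a minimal sketch (the file names below are placeholders, not files shipped with this repository), a tab-separated `train.csv` could look like this:

```csv
path	emotion
dataset/audio/clip_0001.wav	anger
dataset/audio/clip_0002.wav	happiness
dataset/audio/clip_0003.wav	sadness
```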
### Prediction

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoConfig, Wav2Vec2FeatureExtractor
from src.models import Wav2Vec2ForSpeechClassification, HubertForSpeechClassification

model_name_or_path = "path/to/your-pretrained-model"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AutoConfig.from_pretrained(model_name_or_path)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
sampling_rate = feature_extractor.sampling_rate

model = HubertForSpeechClassification.from_pretrained(model_name_or_path).to(device)


def speech_file_to_array_fn(path, sampling_rate):
    speech_array, _sampling_rate = torchaudio.load(path)
    resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
    speech = resampler(speech_array).squeeze().numpy()
    return speech


def predict(path, sampling_rate):
    speech = speech_file_to_array_fn(path, sampling_rate)
    inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}

    with torch.no_grad():
        logits = model(**inputs).logits

    scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
    outputs = [{"Emotion": config.id2label[i], "Score": f"{score * 100:.1f}%"} for i, score in enumerate(scores)]
    return outputs


path = "./dataset/disgust.wav"
outputs = predict(path, sampling_rate)
print(outputs)
```

Output:

```bash
[
    {'Emotion': 'anger', 'Score': '0.0%'},
    {'Emotion': 'disgust', 'Score': '99.2%'},
    {'Emotion': 'fear', 'Score': '0.1%'},
    {'Emotion': 'happiness', 'Score': '0.3%'},
    {'Emotion': 'sadness', 'Score': '0.5%'}
]
```
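The same snippet should also work for checkpoints fine-tuned with `--model_mode="wav2vec"`: the two classification models in `src/models.py` expose the same interface, so the only change is the class used to load the checkpoint.

```python
# For a wav2vec 2.0 based checkpoint instead of a HuBERT one:
model = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path).to(device)
```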
HuBERT-SER/requirements.txt ADDED
@@ -0,0 +1,6 @@
torch
numpy
git+https://github.com/huggingface/datasets.git
git+https://github.com/huggingface/transformers.git
torchaudio
librosa
HuBERT-SER/run_wav2vec_clf.py ADDED
@@ -0,0 +1,491 @@
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from datasets import load_dataset, load_metric
import numpy as np
import torch
import torchaudio

import transformers
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    EvalPrediction,
    AutoConfig,
    Wav2Vec2Processor,
    Wav2Vec2FeatureExtractor,
    is_apex_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

from src.models import Wav2Vec2ForSpeechClassification, HubertForSpeechClassification
from src.collator import DataCollatorCTCWithPadding
from src.trainer import CTCTrainer

logger = logging.getLogger(__name__)
MODEL_MODES = ["wav2vec", "hubert"]
POOLING_MODES = ["mean", "sum", "max"]
DELIMITERS = {"tab": "\t", "comma": ",", "pipe": "|"}


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    model_mode: str = field(
        default="wav2vec",
        metadata={
            "help": "Specifies the base model and must be from the following: " + ", ".join(MODEL_MODES)
        },
    )
    pooling_mode: str = field(
        default="mean",
        metadata={
            "help": "Specifies the reduction to apply to the output of Wav2Vec2 model and must be from the following: "
            + ", ".join(POOLING_MODES)
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained feature_extractor name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    freeze_feature_extractor: Optional[bool] = field(
        default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to test on (a csv or JSON file)."},
    )
    input_column: Optional[str] = field(
        default="path",
        metadata={"help": "The name of the column in the datasets containing the audio path."},
    )
    target_column: Optional[str] = field(
        default="emotion",
        metadata={"help": "The name of the column in the datasets containing the labels."},
    )
    delimiter: Optional[str] = field(
        default="tab",
        metadata={
            "help": "Specifies the character delimiting individual cells in the CSV data and must be from the following: "
            + ", ".join(DELIMITERS.keys())
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    min_duration_in_seconds: Optional[float] = field(
        default=None,
        metadata={"help": "Filters out examples shorter than specified. Defaults to no filtering."},
    )
    max_duration_in_seconds: Optional[float] = field(
        default=None,
        metadata={"help": "Filters out examples longer than specified. Defaults to no filtering."},
    )

    def __post_init__(self):
        if self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            extension = self.train_file.split(".")[-1]
            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            extension = self.validation_file.split(".")[-1]
            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    logger.info(f"last_checkpoint: {last_checkpoint}")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Loading a dataset from your local files.
    # CSV/JSON training and evaluation files are needed.
    data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

    # Get the test dataset: you can provide your own CSV/JSON test file (see below)
    # when you use `--do_predict`.
    if training_args.do_predict:
        if data_args.test_file is not None:
            train_extension = data_args.train_file.split(".")[-1]
            test_extension = data_args.test_file.split(".")[-1]
            assert (
                test_extension == train_extension
            ), "`test_file` should have the same extension (csv or json) as `train_file`."
            data_files["test"] = data_args.test_file
        else:
            raise ValueError("Need a test file for `do_predict`.")

    for key in data_files.keys():
        logger.info(f"load a local file for {key}: {data_files[key]}")

    if data_args.train_file.endswith(".csv"):
        # Loading a dataset from local csv files
        datasets = load_dataset(
            "csv",
            data_files=data_files,
            delimiter=DELIMITERS.get(data_args.delimiter, "\t"),
            cache_dir=model_args.cache_dir,
        )
    else:
        # Loading a dataset from local json files
        datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)

    input_column_name = data_args.input_column
    output_column_name = data_args.target_column

    # Trying to have good defaults here, don't hesitate to tweak to your needs.
    is_regression = datasets["train"].features[output_column_name].dtype in ["float32", "float64"]
    if is_regression:
        num_labels = 1
        label_list = []
        logger.info("*** A regression problem ***")
    else:
        # A useful fast method:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
        label_list = datasets["train"].unique(output_column_name)
        label_list.sort()  # Let's sort it for determinism
        num_labels = len(label_list)

        logger.info(f"*** A classification problem with {num_labels} classes ***")

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        label2id={label: i for i, label in enumerate(label_list)},
        id2label={i: label for i, label in enumerate(label_list)},
        finetuning_task="wav2vec2_clf",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    setattr(config, 'pooling_mode', model_args.pooling_mode)

    # tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_args.model_name_or_path)
    # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_args.model_name_or_path)
    # processor = Wav2Vec2Processor.from_pretrained(
    #     model_args.processor_name if model_args.processor_name else model_args.model_name_or_path,
    #     cache_dir=model_args.cache_dir,
    #     revision=model_args.model_revision,
    #     use_auth_token=True if model_args.use_auth_token else None,
    # )
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    target_sampling_rate = feature_extractor.sampling_rate

    if model_args.model_mode == "wav2vec":
        model = Wav2Vec2ForSpeechClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_mode == "hubert":
        model = HubertForSpeechClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError("--model_mode does not exist in predefined modes: " + ",".join(MODEL_MODES))

    if model_args.freeze_feature_extractor:
        model.freeze_feature_extractor()

    # NOTE: Duration controller for the future `min_duration_in_seconds` `max_duration_in_seconds`
    # data_args.min_duration_in_seconds, data_args.max_duration_in_seconds

    def speech_file_to_array_fn(path):
        speech_array, sampling_rate = torchaudio.load(path)
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sampling_rate)
        speech = resampler(speech_array).squeeze().numpy()
        return speech

    def label_to_id(label, label_list):
        if len(label_list) > 0:
            return label_list.index(label) if label in label_list else -1

        return label

    def preprocess_function(examples):
        speech_list = [speech_file_to_array_fn(path) for path in examples[input_column_name]]
        target_list = [label_to_id(label, label_list) for label in examples[output_column_name]]

        result = feature_extractor(speech_list, sampling_rate=target_sampling_rate)
        result["labels"] = list(target_list)

        return result

    if training_args.do_train:
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")

        train_dataset = datasets["train"]

        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        logger.info(f"Split sizes: {len(train_dataset)} train")

    if training_args.do_eval:
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")

        eval_dataset = datasets["validation"]

        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        logger.info(f"Split sizes: {len(eval_dataset)} validation")

    if training_args.do_predict:
        if "test" not in datasets:
            raise ValueError("--do_predict requires a test dataset")

        predict_dataset = datasets["test"]

        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))

        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        logger.info(f"Split sizes: {len(predict_dataset)} test.")

    # Metric
    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)

        if is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator
    data_collator = DataCollatorCTCWithPadding(feature_extractor=feature_extractor, padding=True)

    # Initialize our Trainer
    trainer = CTCTrainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=feature_extractor,
    )

    # Training
    if training_args.do_train:
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None

        logger.info(f"*** Training from: {checkpoint} ***")
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        # save the feature_extractor and the tokenizer
        if is_main_process(training_args.local_rank):
            feature_extractor.save_pretrained(training_args.output_dir)

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Final test metrics
    if training_args.do_predict:
        logger.info("*** Test ***")

        # Drop the label column so the model only receives inputs at prediction time.
        predict_dataset = predict_dataset.remove_columns(output_column_name)
        predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
        predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

        output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w", encoding="utf-8") as writer:
                logger.info("***** Predict results *****")
                writer.write("index\tprediction\n")
                for index, item in enumerate(predictions):
                    if is_regression:
                        writer.write(f"{index}\t{item:3.3f}\n")
                    else:
                        item = label_list[item]
                        writer.write(f"{index}\t{item}\n")

    # NOTE: Pushing to hub for future
    # training_args.push_to_hub

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
HuBERT-SER/src/__init__.py ADDED
File without changes
HuBERT-SER/src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (169 Bytes).
 
HuBERT-SER/src/__pycache__/modeling_outputs.cpython-311.pyc ADDED
Binary file (1.05 kB).
 
HuBERT-SER/src/__pycache__/models.cpython-311.pyc ADDED
Binary file (10.4 kB).
 
HuBERT-SER/src/collator.py ADDED
@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import torch

import transformers
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor


@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`)
            The feature_extractor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [feature["labels"] for feature in features]

        d_type = torch.long if isinstance(label_features[0], int) else torch.float

        batch = self.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch["labels"] = torch.tensor(label_features, dtype=d_type)

        return batch
HuBERT-SER/src/modeling_outputs.py ADDED
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers.file_utils import ModelOutput


@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
HuBERT-SER/src/models.py ADDED
@@ -0,0 +1,222 @@
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model,
)
from transformers.models.hubert.modeling_hubert import (
    HubertPreTrainedModel,
    HubertModel,
)

from src.modeling_outputs import SpeechClassifierOutput


class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(
        self,
        hidden_states,
        mode="mean",
    ):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']"
            )

        return outputs

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class HubertClassificationHead(nn.Module):
    """Head for hubert classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class HubertForSpeechClassification(HubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.hubert = HubertModel(config)
        self.classifier = HubertClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        self.hubert.feature_extractor._freeze_parameters()

    def merged_strategy(
        self,
        hidden_states,
        mode="mean",
    ):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']"
            )

        return outputs

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
HuBERT-SER/src/trainer.py ADDED
@@ -0,0 +1,65 @@
from typing import Any, Dict, Union

import torch
from packaging import version
from torch import nn
from transformers import Trainer, is_apex_available, TrainingArguments

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast, GradScaler
else:
    _is_native_amp_available = False


class CTCTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_amp = _is_native_amp_available and self.args.fp16
        self.scaler = GradScaler() if self.use_amp else None

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()