Spaces:
Running
Running
alessandro trinca tornidor
committed on
Commit
·
70d4503
1
Parent(s):
823d44e
feat: support pytorch and torchaudio, update test, add requirements-dev.txt
Browse files
- .gitignore +1 -0
- aip_trainer/models/models.py +5 -16
- requirements-dev.txt +2 -0
- requirements.txt +2 -3
- tests/events/GetAccuracyFromRecordedAudio.json +0 -0
- tests/test_GetAccuracyFromRecordedAudio.py +2 -0
.gitignore
CHANGED
|
@@ -199,6 +199,7 @@ tmp
|
|
| 199 |
nohup.out
|
| 200 |
/tests/events.tar
|
| 201 |
function_dump_*.json
|
|
|
|
| 202 |
|
| 203 |
# onnx models
|
| 204 |
*.onnx
|
|
|
|
| 199 |
nohup.out
|
| 200 |
/tests/events.tar
|
| 201 |
function_dump_*.json
|
| 202 |
+
*.yml
|
| 203 |
|
| 204 |
# onnx models
|
| 205 |
*.onnx
|
aip_trainer/models/models.py
CHANGED
|
@@ -1,25 +1,14 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
# second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
|
| 8 |
-
def getASRModel(language: str) -> tuple[nn.Module,
|
| 9 |
-
|
| 10 |
-
|
| 11 |
if language == 'de':
|
| 12 |
-
|
| 13 |
-
model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
|
| 14 |
-
model='silero_stt',
|
| 15 |
-
language='de',
|
| 16 |
-
device=torch.device('cpu'))
|
| 17 |
-
|
| 18 |
elif language == 'en':
|
| 19 |
-
model, decoder,
|
| 20 |
-
model='silero_stt',
|
| 21 |
-
language='en',
|
| 22 |
-
device=torch.device('cpu'))
|
| 23 |
else:
|
| 24 |
raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
+
from silero import silero_stt
|
| 3 |
+
from silero.utils import Decoder
|
| 4 |
|
| 5 |
|
| 6 |
# second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
|
| 7 |
+
def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
|
|
|
|
|
|
|
| 8 |
if language == 'de':
|
| 9 |
+
model, decoder, _ = silero_stt(language='de', version="v4", jit_model="jit_large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
elif language == 'en':
|
| 11 |
+
model, decoder, _ = silero_stt(language='en')
|
|
|
|
|
|
|
|
|
|
| 12 |
else:
|
| 13 |
raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
|
| 14 |
|
requirements-dev.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytest
|
| 2 |
+
pytest-cov
|
requirements.txt
CHANGED
|
@@ -7,7 +7,6 @@ flask_cors
|
|
| 7 |
omegaconf
|
| 8 |
ortools==9.11.4210
|
| 9 |
pandas
|
| 10 |
-
numpy<2.0.0
|
| 11 |
pickle-mixin
|
| 12 |
python-dotenv
|
| 13 |
requests
|
|
@@ -15,6 +14,6 @@ sentencepiece
|
|
| 15 |
soundfile==0.12.1
|
| 16 |
sqlalchemy
|
| 17 |
structlog
|
| 18 |
-
torch
|
| 19 |
-
torchaudio
|
| 20 |
transformers
|
|
|
|
| 7 |
omegaconf
|
| 8 |
ortools==9.11.4210
|
| 9 |
pandas
|
|
|
|
| 10 |
pickle-mixin
|
| 11 |
python-dotenv
|
| 12 |
requests
|
|
|
|
| 14 |
soundfile==0.12.1
|
| 15 |
sqlalchemy
|
| 16 |
structlog
|
| 17 |
+
torch
|
| 18 |
+
torchaudio
|
| 19 |
transformers
|
tests/events/GetAccuracyFromRecordedAudio.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tests/test_GetAccuracyFromRecordedAudio.py
CHANGED
|
@@ -40,7 +40,9 @@ class TestGetAccuracyFromRecordedAudio(unittest.TestCase):
|
|
| 40 |
output["matched_transcripts"] = expected_output["matched_transcripts"]
|
| 41 |
output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
|
| 42 |
output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
|
|
|
|
| 43 |
output["ipa_transcript"] = expected_output["ipa_transcript"]
|
|
|
|
| 44 |
output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
|
| 45 |
self.assertEqual(expected_output, output)
|
| 46 |
|
|
|
|
| 40 |
output["matched_transcripts"] = expected_output["matched_transcripts"]
|
| 41 |
output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
|
| 42 |
output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
|
| 43 |
+
output["pair_accuracy_category"] = expected_output["pair_accuracy_category"]
|
| 44 |
output["ipa_transcript"] = expected_output["ipa_transcript"]
|
| 45 |
+
output["real_transcript"] = expected_output["real_transcript"]
|
| 46 |
output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
|
| 47 |
self.assertEqual(expected_output, output)
|
| 48 |
|