Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	
		alessandro trinca tornidor
		
	commited on
		
		
					Commit 
							
							·
						
						70d4503
	
1
								Parent(s):
							
							823d44e
								
feat: support pytorch and torchaudio, update test, add requirements-dev.txt
Browse files- .gitignore +1 -0
- aip_trainer/models/models.py +5 -16
- requirements-dev.txt +2 -0
- requirements.txt +2 -3
- tests/events/GetAccuracyFromRecordedAudio.json +0 -0
- tests/test_GetAccuracyFromRecordedAudio.py +2 -0
    	
        .gitignore
    CHANGED
    
    | @@ -199,6 +199,7 @@ tmp | |
| 199 | 
             
            nohup.out
         | 
| 200 | 
             
            /tests/events.tar
         | 
| 201 | 
             
            function_dump_*.json
         | 
|  | |
| 202 |  | 
| 203 | 
             
            # onnx models
         | 
| 204 | 
             
            *.onnx
         | 
|  | |
| 199 | 
             
            nohup.out
         | 
| 200 | 
             
            /tests/events.tar
         | 
| 201 | 
             
            function_dump_*.json
         | 
| 202 | 
            +
            *.yml
         | 
| 203 |  | 
| 204 | 
             
            # onnx models
         | 
| 205 | 
             
            *.onnx
         | 
    	
        aip_trainer/models/models.py
    CHANGED
    
    | @@ -1,25 +1,14 @@ | |
| 1 | 
            -
            from typing import Any
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            import torch
         | 
| 4 | 
             
            import torch.nn as nn
         | 
|  | |
|  | |
| 5 |  | 
| 6 |  | 
| 7 | 
             
            # second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
         | 
| 8 | 
            -
            def getASRModel(language: str) -> tuple[nn.Module,  | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
             
                if language == 'de':
         | 
| 12 | 
            -
             | 
| 13 | 
            -
                    model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
         | 
| 14 | 
            -
                                                           model='silero_stt',
         | 
| 15 | 
            -
                                                           language='de',
         | 
| 16 | 
            -
                                                           device=torch.device('cpu'))
         | 
| 17 | 
            -
             | 
| 18 | 
             
                elif language == 'en':
         | 
| 19 | 
            -
                    model, decoder,  | 
| 20 | 
            -
                                                           model='silero_stt',
         | 
| 21 | 
            -
                                                           language='en',
         | 
| 22 | 
            -
                                                           device=torch.device('cpu'))
         | 
| 23 | 
             
                else:
         | 
| 24 | 
             
                    raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
         | 
| 25 |  | 
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import torch.nn as nn
         | 
| 2 | 
            +
            from silero import silero_stt
         | 
| 3 | 
            +
            from silero.utils import Decoder
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
             
            # second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
         | 
| 7 | 
            +
            def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
         | 
|  | |
|  | |
| 8 | 
             
                if language == 'de':
         | 
| 9 | 
            +
                    model, decoder, _ = silero_stt(language='de', version="v4", jit_model="jit_large")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 10 | 
             
                elif language == 'en':
         | 
| 11 | 
            +
                    model, decoder, _ = silero_stt(language='en')
         | 
|  | |
|  | |
|  | |
| 12 | 
             
                else:
         | 
| 13 | 
             
                    raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
         | 
| 14 |  | 
    	
        requirements-dev.txt
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            pytest
         | 
| 2 | 
            +
            pytest-cov
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -7,7 +7,6 @@ flask_cors | |
| 7 | 
             
            omegaconf
         | 
| 8 | 
             
            ortools==9.11.4210
         | 
| 9 | 
             
            pandas
         | 
| 10 | 
            -
            numpy<2.0.0
         | 
| 11 | 
             
            pickle-mixin
         | 
| 12 | 
             
            python-dotenv
         | 
| 13 | 
             
            requests
         | 
| @@ -15,6 +14,6 @@ sentencepiece | |
| 15 | 
             
            soundfile==0.12.1
         | 
| 16 | 
             
            sqlalchemy
         | 
| 17 | 
             
            structlog
         | 
| 18 | 
            -
            torch | 
| 19 | 
            -
            torchaudio | 
| 20 | 
             
            transformers
         | 
|  | |
| 7 | 
             
            omegaconf
         | 
| 8 | 
             
            ortools==9.11.4210
         | 
| 9 | 
             
            pandas
         | 
|  | |
| 10 | 
             
            pickle-mixin
         | 
| 11 | 
             
            python-dotenv
         | 
| 12 | 
             
            requests
         | 
|  | |
| 14 | 
             
            soundfile==0.12.1
         | 
| 15 | 
             
            sqlalchemy
         | 
| 16 | 
             
            structlog
         | 
| 17 | 
            +
            torch
         | 
| 18 | 
            +
            torchaudio
         | 
| 19 | 
             
            transformers
         | 
    	
        tests/events/GetAccuracyFromRecordedAudio.json
    CHANGED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tests/test_GetAccuracyFromRecordedAudio.py
    CHANGED
    
    | @@ -40,7 +40,9 @@ class TestGetAccuracyFromRecordedAudio(unittest.TestCase): | |
| 40 | 
             
                        output["matched_transcripts"] = expected_output["matched_transcripts"]
         | 
| 41 | 
             
                        output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
         | 
| 42 | 
             
                        output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
         | 
|  | |
| 43 | 
             
                        output["ipa_transcript"] = expected_output["ipa_transcript"]
         | 
|  | |
| 44 | 
             
                        output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
         | 
| 45 | 
             
                        self.assertEqual(expected_output, output)
         | 
| 46 |  | 
|  | |
| 40 | 
             
                        output["matched_transcripts"] = expected_output["matched_transcripts"]
         | 
| 41 | 
             
                        output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
         | 
| 42 | 
             
                        output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
         | 
| 43 | 
            +
                        output["pair_accuracy_category"] = expected_output["pair_accuracy_category"]
         | 
| 44 | 
             
                        output["ipa_transcript"] = expected_output["ipa_transcript"]
         | 
| 45 | 
            +
                        output["real_transcript"] = expected_output["real_transcript"]
         | 
| 46 | 
             
                        output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
         | 
| 47 | 
             
                        self.assertEqual(expected_output, output)
         | 
| 48 |  |