Baghdad99 commited on
Commit
3369603
1 Parent(s): 17cfe18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTextToWaveform
 
3
 
4
  # Load your pretrained models
5
  asr_model = Wav2Vec2ForCTC.from_pretrained("Baghdad99/saad-speech-recognition-hausa-audio-to-text")
@@ -24,7 +25,7 @@ def translate_speech(speech):
24
  # Transcribe the speech to text
25
  inputs = asr_processor(audio_signal, return_tensors="pt", padding=True)
26
  logits = asr_model(inputs.input_values).logits
27
- predicted_ids = torch.argmax(logits, dim=-1)
28
  transcription = asr_processor.decode(predicted_ids[0])
29
 
30
  # Translate the text
@@ -41,4 +42,3 @@ def translate_speech(speech):
41
  # Define the Gradio interface
42
  iface = gr.Interface(fn=translate_speech, inputs=gr.inputs.Audio(source="microphone"), outputs="audio")
43
  iface.launch()
44
-
 
1
  import gradio as gr
2
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTextToWaveform
3
+ import torch # Add the import statement for torch
4
 
5
  # Load your pretrained models
6
  asr_model = Wav2Vec2ForCTC.from_pretrained("Baghdad99/saad-speech-recognition-hausa-audio-to-text")
 
25
  # Transcribe the speech to text
26
  inputs = asr_processor(audio_signal, return_tensors="pt", padding=True)
27
  logits = asr_model(inputs.input_values).logits
28
+ predicted_ids = torch.argmax(logits, dim=-1) # Add torch module to access argmax function
29
  transcription = asr_processor.decode(predicted_ids[0])
30
 
31
  # Translate the text
 
42
  # Define the Gradio interface
43
  iface = gr.Interface(fn=translate_speech, inputs=gr.inputs.Audio(source="microphone"), outputs="audio")
44
  iface.launch()