mskov commited on
Commit
babca6f
1 Parent(s): f5e59d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import evaluate
2
  from evaluate.utils import launch_gradio_widget
3
  import gradio as gr
@@ -25,11 +27,18 @@ emotion_dict = {
25
  def classify_toxicity(audio_file, text_input, classify_anxiety):
26
  # Transcribe the audio file using Whisper ASR
27
  if audio_file != None:
28
- whisper_module = evaluate.load("whisper")
29
- transcription_results = whisper_module.compute(uploaded=audio_file)
30
-
 
 
 
 
 
 
31
  # Extract the transcribed text
32
- transcribed_text = transcription_results["transcription"]
 
33
 
34
  #### Emotion classification ####
35
  emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
 
1
+ import os
2
+ os.system("pip install git+https://github.com/openai/whisper.git")
3
  import evaluate
4
  from evaluate.utils import launch_gradio_widget
5
  import gradio as gr
 
27
  def classify_toxicity(audio_file, text_input, classify_anxiety):
28
  # Transcribe the audio file using Whisper ASR
29
  if audio_file != None:
30
+ '''whisper_model = WhisperModel.from_pretrained("openai/whisper-base")
31
+ feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
32
+ transcription_results = whisper_model.compute(uploaded=audio_file)
33
+ '''
34
+ audio = whisper.load_audio(audio_file)
35
+ mel = whisper.log_mel_spectrogram(audio).to(model.device)
36
+ _, probs = model.detect_language(mel)
37
+ options = whisper.DecodingOptions(fp16 = False)
38
+ result = whisper.decode(model, mel, options)
39
  # Extract the transcribed text
40
+ # transcribed_text = transcription_results["transcription"]
41
+ transcribed_text = resut.text
42
 
43
  #### Emotion classification ####
44
  emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")