anzorq committed on
Commit
550d732
1 Parent(s): 6fd478d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -17,22 +17,23 @@ def transcribe_speech(audio):
17
  waveform, sr = torchaudio.load(audio)
18
 
19
  # Resample the audio if needed
20
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
21
- waveform = resampler(waveform)
 
22
 
23
  # Convert to mono if needed
24
  if waveform.dim() > 1:
25
- waveform = torch.mean(waveform, dim=0)
26
 
27
  # Normalize the audio
28
  waveform = waveform / torch.max(torch.abs(waveform))
29
 
30
  # Extract input features
31
- input_features = processor(waveform.unsqueeze(0), sampling_rate=16000).input_features
32
- input_features = torch.from_numpy(input_features).to(device)
33
-
34
- # Generate logits using the model
35
  with torch.no_grad():
 
 
 
 
36
  logits = model(input_features).logits
37
 
38
  # Decode the predicted ids to text
@@ -44,7 +45,7 @@ def transcribe_speech(audio):
44
  @spaces.GPU
45
  def transcribe_from_youtube(url):
46
  # Download audio from YouTube using yt-dlp
47
- audio_path = "downloaded_audio.wav"
48
  ydl_opts = {
49
  'format': 'bestaudio/best',
50
  'outtmpl': audio_path,
@@ -60,8 +61,17 @@ def transcribe_from_youtube(url):
60
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
61
  ydl.download([url])
62
 
 
 
 
 
63
  # Transcribe the downloaded audio
64
- return transcribe_speech(audio_path)
 
 
 
 
 
65
 
66
  with gr.Blocks() as demo:
67
  with gr.Tab("Microphone Input"):
 
17
  waveform, sr = torchaudio.load(audio)
18
 
19
  # Resample the audio if needed
20
+ if sr != 16000:
21
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
22
+ waveform = resampler(waveform)
23
 
24
  # Convert to mono if needed
25
  if waveform.dim() > 1:
26
+ waveform = torchaudio.transforms.DownmixMono()(waveform)
27
 
28
  # Normalize the audio
29
  waveform = waveform / torch.max(torch.abs(waveform))
30
 
31
  # Extract input features
 
 
 
 
32
  with torch.no_grad():
33
+ input_features = processor(waveform.unsqueeze(0), sampling_rate=16000).input_features
34
+ input_features = torch.from_numpy(input_features).to(device)
35
+
36
+ # Generate logits using the model
37
  logits = model(input_features).logits
38
 
39
  # Decode the predicted ids to text
 
45
  @spaces.GPU
46
  def transcribe_from_youtube(url):
47
  # Download audio from YouTube using yt-dlp
48
+ audio_path = f"downloaded_audio_{url.split('=')[-1]}.wav"
49
  ydl_opts = {
50
  'format': 'bestaudio/best',
51
  'outtmpl': audio_path,
 
61
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
62
  ydl.download([url])
63
 
64
+ # # Check if the file exists
65
+ # if not os.path.exists(audio_path):
66
+ # raise FileNotFoundError(f"Failed to find the audio file {audio_path}")
67
+
68
  # Transcribe the downloaded audio
69
+ transcription = transcribe_speech(audio_path)
70
+
71
+ # Optionally, clean up the downloaded file
72
+ os.remove(audio_path)
73
+
74
+ return transcription
75
 
76
  with gr.Blocks() as demo:
77
  with gr.Tab("Microphone Input"):