VMORnD committed on
Commit 1c86fc4
1 Parent(s): 1263980

Update app.py

Files changed (1)
  1. app.py +74 -69
app.py CHANGED
@@ -1,84 +1,89 @@
 #Importing all the necessary packages
-import nltk
-import librosa
-import IPython.display
-import torch
+# import nltk
+# import librosa
+# import IPython.display
+# import torch
 import gradio as gr
-from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
-nltk.download("punkt")
+# from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+# nltk.download("punkt")
 #Loading the model and the tokenizer
-model_name = "facebook/wav2vec2-base-960h"
-tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)  #model_name
-model = Wav2Vec2ForCTC.from_pretrained(model_name)
+# model_name = "facebook/wav2vec2-base-960h"
+# tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)  #model_name
+# model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
-def load_data(input_file):
-    """ Resample to ensure that the speech input is sampled at 16 kHz.
-    """
-    #read the file
-    speech, sample_rate = librosa.load(input_file)
-    #make it 1-D
-    if len(speech.shape) > 1:
-        speech = speech[:, 0] + speech[:, 1]
-    #Resample to 16 kHz, since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 kHz
-    if sample_rate != 16000:
-        speech = librosa.resample(speech, sample_rate, 16000)
-    #speeches = librosa.effects.split(speech)
-    return speech
-def correct_casing(input_sentence):
-    """ Correct the casing of the generated transcribed text.
-    """
-    sentences = nltk.sent_tokenize(input_sentence)
-    return ' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
+# def load_data(input_file):
+#     """ Resample to ensure that the speech input is sampled at 16 kHz.
+#     """
+#     #read the file
+#     speech, sample_rate = librosa.load(input_file)
+#     #make it 1-D
+#     if len(speech.shape) > 1:
+#         speech = speech[:, 0] + speech[:, 1]
+#     #Resample to 16 kHz, since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 kHz
+#     if sample_rate != 16000:
+#         speech = librosa.resample(speech, sample_rate, 16000)
+#     #speeches = librosa.effects.split(speech)
+#     return speech
+# def correct_casing(input_sentence):
+#     """ Correct the casing of the generated transcribed text.
+#     """
+#     sentences = nltk.sent_tokenize(input_sentence)
+#     return ' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
 
-def asr_transcript(input_file):
-    """ Generate a transcript for the provided audio input.
-    """
-    speech = load_data(input_file)
-    #Tokenize
-    input_values = tokenizer(speech, return_tensors="pt").input_values
-    #Take logits
-    logits = model(input_values).logits
-    #Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    #Get the words from the predicted word ids
-    transcription = tokenizer.decode(predicted_ids[0])
-    #Output is all upper case
-    transcription = correct_casing(transcription.lower())
-    return transcription
-def asr_transcript_long(input_file, tokenizer=tokenizer, model=model):
-    transcript = ""
-    # Get the file's native sample rate
-    sample_rate = librosa.get_samplerate(input_file)
+# def asr_transcript(input_file):
+#     """ Generate a transcript for the provided audio input.
+#     """
+#     speech = load_data(input_file)
+#     #Tokenize
+#     input_values = tokenizer(speech, return_tensors="pt").input_values
+#     #Take logits
+#     logits = model(input_values).logits
+#     #Take argmax
+#     predicted_ids = torch.argmax(logits, dim=-1)
+#     #Get the words from the predicted word ids
+#     transcription = tokenizer.decode(predicted_ids[0])
+#     #Output is all upper case
+#     transcription = correct_casing(transcription.lower())
+#     return transcription
+# def asr_transcript_long(input_file, tokenizer=tokenizer, model=model):
+#     transcript = ""
+#     # Get the file's native sample rate
+#     sample_rate = librosa.get_samplerate(input_file)
 
-    # Stream over 20-second chunks rather than loading the full file
-    stream = librosa.stream(
-        input_file,
-        block_length=20,  # number of one-second frames per chunk
-        frame_length=sample_rate,  # 16000
-        hop_length=sample_rate,  # 16000
-    )
+#     # Stream over 20-second chunks rather than loading the full file
+#     stream = librosa.stream(
+#         input_file,
+#         block_length=20,  # number of one-second frames per chunk
+#         frame_length=sample_rate,  # 16000
+#         hop_length=sample_rate,  # 16000
+#     )
 
-    for speech in stream:
-        if len(speech.shape) > 1:
-            speech = speech[:, 0] + speech[:, 1]
-        if sample_rate != 16000:
-            speech = librosa.resample(speech, sample_rate, 16000)
-        input_values = tokenizer(speech, return_tensors="pt").input_values
-        logits = model(input_values).logits
+#     for speech in stream:
+#         if len(speech.shape) > 1:
+#             speech = speech[:, 0] + speech[:, 1]
+#         if sample_rate != 16000:
+#             speech = librosa.resample(speech, sample_rate, 16000)
+#         input_values = tokenizer(speech, return_tensors="pt").input_values
+#         logits = model(input_values).logits
 
-        predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = tokenizer.decode(predicted_ids[0])
-        #transcript += transcription.lower()
-        transcript += correct_casing(transcription.lower())
-        #transcript += " "
+#         predicted_ids = torch.argmax(logits, dim=-1)
+#         transcription = tokenizer.decode(predicted_ids[0])
+#         #transcript += transcription.lower()
+#         transcript += correct_casing(transcription.lower())
+#         #transcript += " "
 
-    return transcript[:3800]
+#     return transcript[:3800]
+from transformers import pipeline
+p = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
+
+def asr_transcript_long(input_file):
+    return p(input_file, chunk_length_s=10, stride_length_s=(2, 2))['text']
+
 gr.Interface(asr_transcript_long,
     #inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Please record your voice"),
    inputs = gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload your audio file here"),
    outputs = gr.outputs.Textbox(type="str", label="Output Text"),
    title="English Automated Speech Summarization",
    description = "This tool transcribes your audio to text",
-    # examples = [["Batman1_dialogue.wav"], ["Batman2_dialogue.wav"], ["Batman3_dialogue.wav"], ["catwoman_dialogue.wav"]],
-    theme="grass").launch()
-
+    examples = [["sample 1.flac"], ["sample 2.flac"], ["sample 3.flac"], ["TheDiverAnUncannyTale.mp3"]],
+    theme="grass").launch()
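
The replacement inference path is the transformers automatic-speech-recognition pipeline, which handles audio decoding, resampling to 16 kHz, and long-file chunking internally, which is why the manual librosa streaming loop could be commented out. Below is a minimal sketch of the same call outside Gradio, assuming transformers, torch, and ffmpeg are available; "example.wav" is a placeholder path, not a file from this repo.

# Minimal sketch of the chunked pipeline call adopted in this commit.
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")

# chunk_length_s=10 windows long audio into 10-second chunks;
# stride_length_s=(2, 2) overlaps 2 seconds of context on each side
# so words at chunk boundaries are not clipped during CTC decoding.
result = asr("example.wav", chunk_length_s=10, stride_length_s=(2, 2))
print(result["text"])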