RasmusToivanen
committed on
Commit
•
0f4c2e3
1
Parent(s):
08774d8
add 94m model
Browse files
app.py
CHANGED
@@ -17,20 +17,24 @@ from fastapi import FastAPI, HTTPException, File
|
|
17 |
from transformers import pipeline
|
18 |
|
19 |
|
|
|
|
|
20 |
pipe_300m = pipeline(model="Finnish-NLP/wav2vec2-xlsr-300m-finnish-lm",chunk_length_s=20, stride_length_s=(3, 3))
|
|
|
21 |
pipe_1b = pipeline(model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2",chunk_length_s=20, stride_length_s=(3, 3))
|
22 |
|
23 |
|
24 |
|
25 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
26 |
-
model_checkpoint = 'Finnish-NLP/t5x-small-nl24-
|
27 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_auth_token=os.environ.get('hf_token'))
|
28 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(
|
29 |
|
30 |
|
31 |
# define speech-to-text function
|
32 |
def asr_transcript(audio, audio_microphone, model_params):
|
33 |
|
|
|
34 |
audio = audio_microphone if audio_microphone else audio
|
35 |
|
36 |
if audio == None and audio_microphone == None:
|
@@ -38,10 +42,12 @@ def asr_transcript(audio, audio_microphone, model_params):
|
|
38 |
text = ""
|
39 |
|
40 |
if audio:
|
41 |
-
if model_params == "
|
42 |
-
text = pipe_300m(audio.name)
|
43 |
-
elif model_params == "1 billion":
|
44 |
text = pipe_1b(audio.name)
|
|
|
|
|
|
|
|
|
45 |
|
46 |
input_ids = tokenizer(text['text'], return_tensors="pt").input_ids.to(device)
|
47 |
outputs = model.generate(input_ids, max_length=128)
|
@@ -52,9 +58,9 @@ def asr_transcript(audio, audio_microphone, model_params):
|
|
52 |
|
53 |
gradio_ui = gr.Interface(
|
54 |
fn=asr_transcript,
|
55 |
-
title="Finnish
|
56 |
description="Upload an audio clip, and let AI do the hard work of transcribing",
|
57 |
-
inputs=[gr.inputs.Audio(label="Upload Audio File", type="file", optional=True), gr.inputs.Audio(source="microphone", type="file", optional=True, label="Record"), gr.inputs.Dropdown(choices=["300 million", "1 billion"], type="value", default="1 billion", label="Select speech recognition model parameter amount", optional=False)],
|
58 |
outputs=[gr.outputs.Textbox(label="Recognized speech"),gr.outputs.Textbox(label="Recognized speech with case correction and punctuation")]
|
59 |
)
|
60 |
|
|
|
17 |
from transformers import pipeline
|
18 |
|
19 |
|
# ASR pipelines: long audio is transcribed in 20-second chunks with a
# 3-second stride on each side so chunk boundaries overlap when stitched.
pipe_300m = pipeline(
    model="Finnish-NLP/wav2vec2-xlsr-300m-finnish-lm",
    chunk_length_s=20,
    stride_length_s=(3, 3),
)
pipe_94m = pipeline(
    model="Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned",
    chunk_length_s=20,
    stride_length_s=(3, 3),
)
pipe_1b = pipeline(
    model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2",
    chunk_length_s=20,
    stride_length_s=(3, 3),
)

# Run the casing/punctuation correction model on GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seq2seq model that restores casing and punctuation in the raw transcript.
# Loaded with the token from the `hf_token` env var (may be None for a
# public checkpoint) — presumably the repo is gated; verify on the Hub.
model_checkpoint = 'Finnish-NLP/t5x-small-nl24-casing-punctuation-correction'
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint, use_auth_token=os.environ.get('hf_token')
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint,
    from_flax=False,
    torch_dtype=torch.float32,
    use_auth_token=os.environ.get('hf_token'),
).to(device)
34 |
# Speech-to-text handler wired to the Gradio interface below.
# NOTE(review): this file is a diff rendering; the tail of this function
# (decoding `outputs` and the return statement, file lines 54-57) is elided
# by the diff hunk, so only the visible portion is annotated here.
# define speech-to-text function
|
35 |
def asr_transcript(audio, audio_microphone, model_params):
|
36 |
|
37 |
+
|
38 |
# Prefer the microphone recording over an uploaded file when both exist.
audio = audio_microphone if audio_microphone else audio
|
39 |
|
40 |
# Guard against no input at all (the handling branch is hidden by the diff).
if audio == None and audio_microphone == None:
|
|
|
42 |
text = ""
|
43 |
|
44 |
if audio:
|
45 |
+
# Dispatch on the dropdown choice to the matching ASR pipeline; each
# pipeline result is indexed as text['text'] below.
if model_params == "1 billion multi":
|
|
|
|
46 |
text = pipe_1b(audio.name)
|
47 |
+
elif model_params == "94 million fi":
|
48 |
+
text = pipe_94m(audio.name)
|
49 |
+
elif model_params == "300 million multi":
|
50 |
+
text = pipe_300m(audio.name)
|
51 |
|
52 |
# Run the raw transcript through the casing/punctuation seq2seq model.
input_ids = tokenizer(text['text'], return_tensors="pt").input_ids.to(device)
|
53 |
outputs = model.generate(input_ids, max_length=128)
|
|
|
# Gradio front-end: file upload + microphone inputs, a model-size selector,
# and two text outputs (raw transcript; transcript with casing/punctuation).
gradio_ui = gr.Interface(
    fn=asr_transcript,
    title="Finnish automatic speech recognition",
    description="Upload an audio clip, and let AI do the hard work of transcribing",
    inputs=[
        gr.inputs.Audio(label="Upload Audio File", type="file", optional=True),
        gr.inputs.Audio(source="microphone", type="file", optional=True, label="Record from microphone"),
        gr.inputs.Dropdown(
            choices=["94 million fi", "300 million multi", "1 billion multi"],
            type="value",
            default="1 billion multi",
            label="Select speech recognition model parameter amount",
            optional=False,
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Recognized speech"),
        gr.outputs.Textbox(label="Recognized speech with case correction and punctuation"),
    ],
)