RasmusToivanen commited on
Commit
3dd368d
1 Parent(s): 0a1a14e
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -17,7 +17,8 @@ from fastapi import FastAPI, HTTPException, File
17
  from transformers import pipeline
18
 
19
 
20
- pipe = pipeline(model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2",chunk_length_s=20, stride_length_s=(4, 2))
 
21
 
22
 
23
 
@@ -28,26 +29,32 @@ model = AutoModelForSeq2SeqLM.from_pretrained('Finnish-NLP/case_correction_model
28
 
29
 
30
  # define speech-to-text function
31
- def asr_transcript(audio):
32
 
33
  text = ""
34
 
35
  if audio:
36
- text = pipe(audio.name)
 
 
 
37
 
38
  input_ids = tokenizer(text['text'], return_tensors="pt").input_ids.to(device)
39
  outputs = model.generate(input_ids, max_length=128)
40
  case_corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
- return {"text_asr": text['text'], "text_case_corrected": case_corrected_text}
42
  else:
43
  return "File not valid"
44
 
45
  gradio_ui = gr.Interface(
46
  fn=asr_transcript,
47
- title="Speech-to-Text with HuggingFace+Wav2Vec2",
48
  description="Upload an audio clip, and let AI do the hard work of transcribing",
49
- inputs=gr.inputs.Audio(label="Upload Audio File", type="file"),
50
- outputs=gr.outputs.Textbox(label="Auto-Transcript"),
51
  )
52
 
53
- gradio_ui.launch()
 
 
 
 
17
  from transformers import pipeline
18
 
19
 
20
+ pipe_300m = pipeline(model="Finnish-NLP/wav2vec2-xlsr-300m-finnish-lm",chunk_length_s=20, stride_length_s=(3, 3))
21
+ pipe_1b = pipeline(model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2",chunk_length_s=20, stride_length_s=(3, 3))
22
 
23
 
24
 
 
29
 
30
 
31
  # define speech-to-text function
32
+ def asr_transcript(audio, model_params):
33
 
34
  text = ""
35
 
36
  if audio:
37
+ if model_params == "300 million":
38
+ text = pipe_300m(audio.name)
39
+ elif model_params == "1 billion":
40
+ text = pipe_1b(audio.name)
41
 
42
  input_ids = tokenizer(text['text'], return_tensors="pt").input_ids.to(device)
43
  outputs = model.generate(input_ids, max_length=128)
44
  case_corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+ return text['text'], case_corrected_text
46
  else:
47
  return "File not valid"
48
 
49
  gradio_ui = gr.Interface(
50
  fn=asr_transcript,
51
+ title="Finnish Automatic Speech-Recognition",
52
  description="Upload an audio clip, and let AI do the hard work of transcribing",
53
+ inputs=[gr.inputs.Audio(label="Upload Audio File", type="file"), gr.inputs.Dropdown(choices=["300 million", "1 billion"], type="value", default="1 billion", label="Select speech recognition model parameter amount", optional=False)],
54
+ outputs=[gr.outputs.Textbox(label="Recognized speech"),gr.outputs.Textbox(label="Recognized speech with case correction and punctuation")]
55
  )
56
 
57
+ gradio_ui.launch()
58
+
59
+
60
+ os.environ.get('hf_token')