ASS / core.py
VMORnD's picture
Update core.py
ab76106
raw
history blame
3.66 kB
import transformers
from transformers import pipeline
import whisper
import datetime
import os
import gradio as gr
from pytube import YouTube
# Runs once at import time, before any model is loaded.
# NOTE(review): presumably migrates an older Hugging Face cache layout to the
# current one -- confirm against the transformers documentation.
transformers.utils.move_cache()
# ====================================
# Load speech recognition model
# The Whisper "base" checkpoint is loaded once here and shared by
# asr_transcript(). An earlier pipeline-based variant is kept for reference:
# speech_recognition_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
speech_recognition_model = whisper.load_model("base")
# ====================================
# Load text summarization model English
# Tokenizer and seq2seq model are loaded separately (rather than through the
# commented-out pipeline below) and used by text_summarize_en().
# text_summarization_pipeline_En = pipeline("summarization", model="facebook/bart-large-cnn")
tokenizer_En = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
text_summarization_model_En = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
# ====================================
# Load text summarization model Vietnamese
# Same tokenizer + seq2seq pairing as above; used by text_summarize_vi().
tokenizer_Vi = transformers.AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
text_summarization_model_Vi = transformers.AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")
def asr_transcript(input_file):
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        input_file: Path handed to ``whisper.load_audio``.

    Returns:
        A ``(text, lang, detail)`` tuple:

        * ``text``   -- the full transcript (``output['text']``).
        * ``lang``   -- ``"Vietnamese"`` when the detected language code is
          ``'vi'``; ``"English"`` for every other code.
        * ``detail`` -- one ``"<start>-<end> <segment text>\\n"`` line per
          segment, with offsets rounded to whole seconds.
    """
    waveform = whisper.load_audio(input_file)
    result = speech_recognition_model.transcribe(waveform)
    transcript = result['text']

    # Only Vietnamese gets its own label; everything else falls back to English.
    language = "Vietnamese" if result["language"] == 'vi' else "English"

    def _clock(seconds):
        # str() of a timedelta built from the offset rounded to whole seconds.
        return str(datetime.timedelta(seconds=round(seconds)))

    timeline = "".join(
        _clock(seg['start']) + "-" + _clock(seg['end']) + " " + seg['text'] + "\n"
        for seg in result['segments']
    )
    return transcript, language, timeline
def text_summarize_en(text_input):
    """Summarize text with the English tokenizer/model pair.

    Args:
        text_input: Text passed to ``tokenizer_En`` with truncation enabled.

    Returns:
        The decoded generated sequences concatenated into a single string.
    """
    tokens = tokenizer_En(text_input, truncation=True, return_tensors="pt")
    generated = text_summarization_model_En.generate(
        input_ids=tokens["input_ids"],
        attention_mask=tokens["attention_mask"],
        max_length=256,
        early_stopping=True,
    )
    # Decode every returned sequence and glue the pieces together unchanged.
    return "".join(
        tokenizer_En.decode(
            sequence,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for sequence in generated
    )
def text_summarize_vi(text_input):
    """Summarize text with the Vietnamese tokenizer/model pair.

    Args:
        text_input: Text passed to ``tokenizer_Vi`` with truncation enabled.

    Returns:
        The decoded generated sequences concatenated into a single string.
    """
    # NOTE(review): the published usage example for this checkpoint appears to
    # prepend a "vietnews: " task prefix to the input -- confirm whether that
    # is needed here; the text is currently passed through as-is.
    tokens = tokenizer_Vi(text_input, truncation=True, return_tensors="pt")
    generated = text_summarization_model_Vi.generate(
        input_ids=tokens["input_ids"],
        attention_mask=tokens["attention_mask"],
        max_length=256,
        early_stopping=True,
    )
    # Decode every returned sequence and glue the pieces together unchanged.
    return "".join(
        tokenizer_Vi.decode(
            sequence,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for sequence in generated
    )
def text_summarize(text_input, lang):
    """Route text to the summarizer that matches a language label.

    Args:
        text_input: Text to summarize.
        lang: ``'English'`` or ``'Vietnamese'``.

    Returns:
        The summary from the matching summarizer, or ``""`` when ``lang`` is
        neither of the two supported labels.
    """
    if lang == 'English':
        return text_summarize_en(text_input)
    if lang == 'Vietnamese':
        return text_summarize_vi(text_input)
    # Unsupported label: nothing to summarize with.
    return ""
def load_video_url(url):
    """Download a YouTube video into ``<cwd>/audio/temp.mp4``.

    Args:
        url: Video URL handed to ``YouTube``.

    Returns:
        Path of the downloaded file (``<cwd>/audio/temp.mp4``).

    Raises:
        gr.Error: ``"Connection Error"`` when the ``YouTube`` object cannot be
            created; ``"Download video error"`` when no stream can be selected
            or the download fails. The underlying exception is chained as
            ``__cause__``.
    """
    audio_dir = os.path.join(os.getcwd(), "audio")
    file_url = os.path.join(audio_dir, "temp.mp4")

    # ``except Exception`` (not a bare ``except``) so that KeyboardInterrupt
    # and SystemExit still propagate instead of being reported as UI errors.
    try:
        yt = YouTube(url)
    except Exception as err:
        print("Connection Error")
        raise gr.Error("Connection Error") from err

    try:
        stream = yt.streams.filter(progressive=False).get_highest_resolution()
        if stream is None:
            # Fail explicitly rather than via an AttributeError on ``None``.
            raise LookupError("no downloadable stream found")
        # Download the selected stream directly; looking it up again by itag
        # would return the same stream.
        stream.download(output_path=audio_dir, filename="temp.mp4")
    except Exception as err:
        print("Download video error")
        raise gr.Error("Download video error") from err
    return file_url