import datetime
import os

import gradio as gr
import transformers
import whisper
from pytube import YouTube
from transformers import pipeline

# One-time migration of the transformers download cache to its current layout
# (a no-op if the cache has already been migrated).
transformers.utils.move_cache()

# Whisper "base" model: used for both transcription and language identification.
speech_recognition_model = whisper.load_model("base")

# English summarization model: BART fine-tuned on CNN/DailyMail.
tokenizer_En = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
text_summarization_model_En = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# Vietnamese summarization model: ViT5 fine-tuned on vietnews.
tokenizer_Vi = transformers.AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
text_summarization_model_Vi = transformers.AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")


def asr_transcript(input_file):
    """Transcribe a media file; return the full text, detected language, and timestamped segments."""
    audio = whisper.load_audio(input_file)
    output = speech_recognition_model.transcribe(audio)
    text = output["text"]

    # Whisper reports an ISO 639-1 code; anything other than Vietnamese falls back to English.
    lang = "English"
    if output["language"] == "en":
        lang = "English"
    elif output["language"] == "vi":
        lang = "Vietnamese"

    # One line per segment, with HH:MM:SS start and end timestamps.
    detail = ""
    for segment in output["segments"]:
        start = str(datetime.timedelta(seconds=round(segment["start"])))
        end = str(datetime.timedelta(seconds=round(segment["end"])))
        detail += start + "-" + end + " " + segment["text"] + "\n"
    return text, lang, detail


def text_summarize_en(text_input):
    """Summarize English text with BART (input truncated to the model's maximum length)."""
    encoding = tokenizer_En(text_input, truncation=True, return_tensors="pt")
    input_ids, attention_masks = encoding["input_ids"], encoding["attention_mask"]
    outputs = text_summarization_model_En.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        max_length=256,
        num_beams=4,  # early_stopping only takes effect under beam search
        early_stopping=True,
    )
    text = ""
    for output in outputs:
        line = tokenizer_En.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        text = text + line
    return text


def text_summarize_vi(text_input):
    """Summarize Vietnamese text with ViT5 (input truncated to the model's maximum length)."""
    encoding = tokenizer_Vi(text_input, truncation=True, return_tensors="pt")
    input_ids, attention_masks = encoding["input_ids"], encoding["attention_mask"]
    outputs = text_summarization_model_Vi.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        max_length=256,
        num_beams=4,  # early_stopping only takes effect under beam search
        early_stopping=True,
    )
    text = ""
    for output in outputs:
        line = tokenizer_Vi.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        text = text + line
    return text


def text_summarize(text_input, lang):
    """Dispatch to the summarizer that matches the detected language."""
    if lang == "English":
        return text_summarize_en(text_input)
    elif lang == "Vietnamese":
        return text_summarize_vi(text_input)
    else:
        return ""


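# Illustrative end-to-end usage (an assumption for clarity, not part of the original
# app flow): transcribe a downloaded file, then summarize it in the detected language.
# "audio/temp.mp4" is the path that load_video_url() below writes to.
#
#   text, lang, detail = asr_transcript("audio/temp.mp4")
#   summary = text_summarize(text, lang)
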
def load_video_url(url):
    """Download the audio track of a YouTube video to audio/temp.mp4 and return its path."""
    current_dir = os.getcwd()

    try:
        yt = YouTube(url)
    except Exception:
        print("Connection Error")
        raise gr.Error("Connection Error")
    try:
        # Only the audio track is needed for transcription, so take the
        # highest-bitrate audio-only stream rather than a video stream.
        highest_audio = yt.streams.get_audio_only().itag
        file_url = os.path.join(current_dir, "audio", "temp.mp4")
        yt.streams.get_by_itag(highest_audio).download(output_path=os.path.join(current_dir, "audio"), filename="temp.mp4")
    except Exception:
        print("Download video error")
        raise gr.Error("Download video error")

    return file_url
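

# A minimal sketch of how the functions above might be wired into a Gradio UI.
# The original interface is not shown in this section, so the components, labels,
# and event wiring below are illustrative assumptions, not the app's actual layout.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        url_input = gr.Textbox(label="YouTube URL")
        load_btn = gr.Button("Load audio")
        audio_path = gr.Textbox(label="Downloaded file", interactive=False)
        transcribe_btn = gr.Button("Transcribe")
        transcript = gr.Textbox(label="Transcript", lines=6)
        language = gr.Textbox(label="Detected language")
        segments = gr.Textbox(label="Timestamped segments", lines=6)
        summarize_btn = gr.Button("Summarize")
        summary = gr.Textbox(label="Summary", lines=4)

        # Each button maps directly onto one of the functions defined above.
        load_btn.click(load_video_url, inputs=url_input, outputs=audio_path)
        transcribe_btn.click(asr_transcript, inputs=audio_path, outputs=[transcript, language, segments])
        summarize_btn.click(text_summarize, inputs=[transcript, language], outputs=summary)

    demo.launch()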