# Relax-Teacher / app.py
import gradio as gr
import onnxruntime
from transformers import AutoTokenizer
import torch
import os
from transformers import pipeline
import subprocess
import moviepy.editor as mp
import base64
token = AutoTokenizer.from_pretrained('distilroberta-base')
inf_session = onnxruntime.InferenceSession('classifier-quantized2.onnx')
input_name = inf_session.get_inputs()[0].name
output_name = inf_session.get_outputs()[0].name
classes = ['Art', 'Astrology', 'Biology', 'Chemistry', 'Economics', 'History', 'Literature', 'Philosophy', 'Physics', 'Politics', 'Psychology', 'Sociology']
### --- Audio/Video to txt ---###
device = "cuda:0" if torch.cuda.is_available() else "cpu"
pipe = pipeline("automatic-speech-recognition",
model="openai/whisper-tiny.en",
chunk_length_s=30, device=device)
### --- Text Summary --- ###
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device)
def video_identity(video):
transcription = pipe(video)["text"]
return transcription
def summary(text):
    """Summarize *text* with the distilbart summarization pipeline.

    The text is split on '.' and the sentences are packed into chunks of
    at most ~500 words each, so every chunk fits the summarizer's input
    window. Returns the pipeline's list of summary dicts; an empty list
    when *text* contains no sentences.
    """
    chunks = _chunk_sentences(text, max_words=500)
    if not chunks:
        # Avoid calling the summarizer on empty input.
        return []
    return summarizer(chunks, max_length=100)


def _chunk_sentences(text, max_words=500):
    """Split *text* on '.' and pack sentences into space-joined chunks of
    at most *max_words* words each. Blank sentences (e.g. the trailing
    empty string produced by str.split) are skipped."""
    chunks = []        # completed chunks, each a list of words
    current = []       # words accumulated for the chunk in progress
    for sentence in text.split('.'):
        if not sentence.strip():
            continue
        words = sentence.split(' ')
        # Start a new chunk when adding this sentence would overflow.
        if current and len(current) + len(words) > max_words:
            chunks.append(current)
            current = []
        current.extend(words)
    if current:
        chunks.append(current)
    return [' '.join(c) for c in chunks]
def classify(video_file, encoded_video):
    """Transcribe, summarize, and classify a lecture video.

    Accepts either a path in *video_file* or a base64-encoded video in
    *encoded_video* (which takes precedence when non-empty). Returns a
    dict with keys 'text' (full transcription), 'summary', and 'label'
    (one of the module-level `classes`).
    """
    if encoded_video != "":
        # Gradio may deliver the video as base64 text; materialize it on disk.
        decoded_file_data = base64.b64decode(encoded_video)
        with open("temp_video.mp4", "wb") as f:
            f.write(decoded_file_data)
        video_file = "temp_video.mp4"

    # Extract the audio track for speech recognition.
    clip = mp.VideoFileClip(video_file)
    try:
        clip.audio.write_audiofile(r"audio.wav")
    finally:
        # Release the ffmpeg readers even if audio extraction fails.
        clip.close()

    full_text = video_identity(r"audio.wav")
    # Renamed from `sum` to avoid shadowing the builtin.
    summ_text = summary(full_text)[0]['summary_text']

    # Tokenize the summary, truncating to the model's 512-token limit.
    input_ids = token(summ_text)['input_ids'][:512]
    logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
    probs = torch.sigmoid(torch.FloatTensor(logits))[0]
    # argmax over class probabilities picks the predicted subject.
    label = classes[int(torch.argmax(probs))]

    return {
        'text': full_text,
        'summary': summ_text,
        'label': label,
    }
# Build and launch the Gradio UI: a video upload plus a text box for an
# optional base64-encoded video feed `classify`; the result dict is shown
# as JSON. (The previously-defined Textbox widgets were unused dead code
# and have been removed.)
iface = gr.Interface(
    fn=classify,
    inputs=['video', 'text'],
    outputs=['json'],
)
iface.launch(inline=False)