nouamanetazi's picture
nouamanetazi HF staff
linting
c731c61
raw
history blame
3.16 kB
import re
import glob
import pickle
import os
import torch
import numpy as np
from utils.audio import load_spectrograms
from utils.compute_args import compute_args
from utils.tokenize import (
tokenize,
create_dict,
sent_to_ix,
cmumosei_2,
cmumosei_7,
pad_feature,
)
from model_LA import Model_LA
import gradio as gr
# Run on the first CUDA GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- load model ---
ckpts_path = "ckpt"
model_name = "Model_LA_e"

# Checkpoints matching "best*", sorted by file name in descending order;
# the first entry is the one that gets loaded.
ckpts = sorted(
    glob.glob(os.path.join(ckpts_path, model_name, "best*")), reverse=True
)
if not ckpts:
    raise FileNotFoundError(
        "No checkpoint matching 'best*' found in "
        f"{os.path.join(ckpts_path, model_name)!r}"
    )

# Read the checkpoint a single time: it carries both the original training
# args and the model weights.
# NOTE(review): the checkpoint stores a non-tensor `args` object; on PyTorch
# versions where torch.load defaults to weights_only=True this may require
# weights_only=False — confirm against the pinned torch version. Only load
# checkpoints from a trusted source (they are unpickled).
checkpoint = torch.load(ckpts[0], map_location=device)
args = compute_args(checkpoint["args"])
state_dict = checkpoint["state_dict"]

# Pretrained word embeddings and the vocabulary used at training time.
pretrained_emb = np.load("train_glove.npy")
# Local, bundled file; pickle must never be used on untrusted input.
with open("token_to_ix.pkl", "rb") as vocab_file:
    token_to_ix = pickle.load(vocab_file)

net = Model_LA(args, len(token_to_ix), pretrained_emb).to(device)
net.load_state_dict(state_dict)
def inference(source_video, transcription):
    """Predict per-label emotion flags for a video clip and its transcription.

    The transcription is cleaned word by word and mapped to token ids, and a
    mel spectrogram is extracted from ``source_video`` via
    ``load_spectrograms``. The inputs are padded to the sequence lengths
    stored in the checkpoint ``args`` and fed to the module-level ``net``.

    Args:
        source_video: The uploaded video as handed over by the Gradio Video
            input (passed straight to ``load_spectrograms``).
        transcription: Spoken text of the clip. ``None`` is treated as an
            empty transcription.

    Returns:
        dict mapping each of "happy", "sad", "angry", "fear", "disgust",
        "surprise" to a bool that is True when the model's raw output for
        that label is greater than 0.

    Raises:
        ValueError: if no video was provided.
    """
    if source_video is None:
        raise ValueError("No video provided: please upload a video file.")

    # --- text ---
    def clean(w):
        # Lower-case, strip punctuation, and turn '-' and '/' into spaces.
        return (
            re.sub(r"([.,'!?\"()*#:;])", "", w.lower())
            .replace("-", " ")
            .replace("/", " ")
        )

    # Clean every word exactly once and drop the ones that end up empty.
    cleaned = (clean(w) for w in (transcription or "").split())
    s = [word for word in cleaned if word]

    # --- sound ---
    # Only the mel spectrogram is used; the other outputs are ignored.
    _, mel, _ = load_spectrograms(source_video)

    l_max_len = args.lang_seq_len
    a_max_len = args.audio_seq_len
    v_max_len = args.video_seq_len
    L = sent_to_ix(s, token_to_ix, max_token=l_max_len)
    A = pad_feature(mel, a_max_len)
    # NOTE: the "video" input is also built from the mel spectrogram; no
    # visual features are extracted here. Presumably Model_LA only uses the
    # language and audio streams — confirm before relying on it.
    V = pad_feature(mel, v_max_len)

    print(f"Processed text shape from {len(s)} to {L.shape}")
    print(f"Processed audio shape from {mel.shape} to {A.shape}")
    print(f"Processed video shape from {mel.shape} to {V.shape}")

    # Add a leading batch dimension of 1 to every input.
    x = torch.from_numpy(np.expand_dims(L, axis=0)).to(device)
    y = torch.from_numpy(np.expand_dims(A, axis=0)).to(device)
    z = torch.from_numpy(np.expand_dims(V, axis=0)).float().to(device)

    net.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = net(x, y, z).detach().cpu().numpy()[0]

    labels = ["happy", "sad", "angry", "fear", "disgust", "surprise"]
    # Each raw output is thresholded at 0 independently (multi-label).
    return {label: float(pred[i]) > 0 for i, label in enumerate(labels)}
# --- Gradio UI ---
title = "Emotion Recognition"
description = ""

# (video path, transcription) pairs shown as clickable examples in the UI.
example_pairs = (
    (
        "examples/0h-zjBukYpk_2.mp4",
        "NOW IM NOT EVEN GONNA SUGAR COAT THIS THIS MOVIE FRUSTRATED ME TO SUCH AN EXTREME EXTENT THAT I WAS LOUDLY EXCLAIMING WHY AT THE END OF THE FILM",
    ),
    ("examples/0h-zjBukYpk_19.mp4", "NOW OTHER PERFORMANCES ARE BORDERLINE OKAY"),
    ("examples/03bSnISJMiM_1.mp4", "IT WAS REALLY GOOD "),
    ("examples/03bSnISJMiM_5.mp4", "AND THEY SHOULDVE I GUESS "),
)
# Gradio expects a list of [input_1, input_2] lists.
examples = [[video_path, text] for video_path, text in example_pairs]

# The interface takes an uploaded video plus free text and shows a label.
video_input = gr.inputs.Video(type="avi", source="upload")
demo = gr.Interface(
    inference,
    inputs=[video_input, "text"],
    outputs=["label"],
    title=title,
    description=description,
    examples=examples,
)
demo.launch(debug=True)