File size: 2,776 Bytes
d3d8d59 2fe95f1 d3d8d59 2fe95f1 d3d8d59 bb83661 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import torch
import cv2
import videotransforms
import numpy as np
import gradio as gr
from einops import rearrange
from torchvision import transforms
from pytorch_i3d import InceptionI3d
def preprocess(vidpath):
    """Decode a video file into a normalised clip tensor for the I3D model.

    Each decoded frame is upscaled (aspect ratio preserved) when its shorter
    side is below 226 pixels and its pixel values are mapped from [0, 255]
    to [-1, 1]. The stacked frames are then passed through
    ``videotransforms.CenterCrop(224)`` and rearranged for the network.

    Args:
        vidpath: Path of the video file, handed to ``cv2.VideoCapture``.

    Returns:
        A float32 ``torch.Tensor`` laid out as ``(1, c, t, h, w)``.

    Raises:
        ValueError: If the video cannot be opened or yields no frames.
    """
    min_side = 226.
    frames = []
    cap = cv2.VideoCapture(vidpath)
    try:
        if not cap.isOpened():
            raise ValueError(f"could not open video: {vidpath!r}")
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in range(num):
            ok, img = cap.read()
            # The reported frame count is only an estimate taken from the
            # container, so stop at the first failed read instead of
            # dereferencing a missing frame.
            if not ok or img is None:
                break
            # ndarray.shape is (rows, cols, ...); only the shorter side matters.
            shorter = min(img.shape[:2])
            if shorter < min_side:
                d = min_side - shorter
                sc = 1 + d / shorter
                img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
            img = (img / 255.) * 2 - 1
            frames.append(img)
    finally:
        # Always give the decoder handle back, even when decoding fails.
        cap.release()
    if not frames:
        raise ValueError(f"no frames could be decoded from: {vidpath!r}")
    frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
    transform = transforms.Compose([videotransforms.CenterCrop(224)])
    frames = transform(frames)
    frames = rearrange(frames, 't h w c-> 1 c t h w')
    return frames
# Fine-tuned checkpoints selectable from the UI, keyed by dataset name.
_CHECKPOINTS = {
    'WLASL100':{'logits':100,'path':'weights/asl100/FINAL_nslt_100_iters=896_top1=65.89_top5=84.11_top10=89.92.pt'},
    'WLASL2000':{'logits':2000,'path':'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt'}
}

# Networks already built, keyed by dataset name, so that repeated requests
# do not re-read the checkpoints from disk.
_MODEL_CACHE = {}


def _load_model(dataset):
    """Return the eval-mode I3D network for *dataset*, building it on first use."""
    if dataset not in _MODEL_CACHE:
        spec = _CHECKPOINTS[dataset]
        model = InceptionI3d()
        # classify() converts the outputs with .numpy(), which needs CPU
        # tensors, so map every checkpoint onto the CPU regardless of the
        # device it was saved from.
        model.load_state_dict(torch.load('weights/rgb_imagenet.pt', map_location='cpu'))
        model.replace_logits(spec['logits'])
        model.load_state_dict(torch.load(spec['path'], map_location='cpu'))
        model.eval()
        _MODEL_CACHE[dataset] = model
    return _MODEL_CACHE[dataset]


def _read_labels(path='wlasl_class_list.txt'):
    """Parse the class list file into a ``{class index: label}`` dict."""
    idx2label = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if len(fields) < 2:
                # Skip blank or malformed lines instead of crashing on them.
                continue
            idx2label[int(fields[0])] = fields[1]
    return idx2label


def classify(video,dataset='WLASL100'):
    """Classify the sign shown in a video and return the most likely labels.

    Args:
        video: Path of the video file; forwarded to ``preprocess``.
        dataset: Key selecting the fine-tuned checkpoint, either
            ``'WLASL100'`` or ``'WLASL2000'``.

    Returns:
        A dict mapping the labels of the ten highest-scoring classes to
        their softmax probability (softmax taken over all classes).

    Raises:
        KeyError: If *dataset* is not a known checkpoint key.
    """
    clip = preprocess(video)
    model = _load_model(dataset)
    with torch.no_grad():
        per_frame_logits = model(clip)
    # Drop the batch axis and average the logits over the remaining last axis
    # to get a single score per class.
    predictions = rearrange(per_frame_logits,'1 j k -> j k')
    predictions = torch.mean(predictions, dim = 1)
    _, index = torch.topk(predictions,10)
    idx2label = _read_labels()
    predictions = torch.nn.functional.softmax(predictions, dim=0).numpy()
    return {idx2label[i]:float(predictions[i]) for i in index.tolist()}
# Static text and example clips shown in the Gradio page.
title = "I3D Sign Language Recognition"
description = "Description here"
examples = [
    ['videos/no.mp4', 'WLASL100'],
    ['videos/all.mp4', 'WLASL100'],
    ['videos/blue.mp4', 'WLASL2000'],
    ['videos/white.mp4', 'WLASL2000'],
    ['videos/accident.mp4', 'WLASL2000'],
]

# Input widgets: the clip to classify and the checkpoint to classify it with.
video_input = gr.inputs.Video(label="VIDEO")
dataset_input = gr.inputs.Dropdown(
    choices=['WLASL100', 'WLASL2000'],
    default='WLASL100',
    label='DATASET USED',
)

# Output widget: shows the five best entries of the dict returned by classify.
top_predictions = gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions')

demo = gr.Interface(
    fn=classify,
    inputs=[video_input, dataset_input],
    outputs=[top_predictions],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(cache_examples=True)
|