import torch
import cv2
import videotransforms
import numpy as np
import gradio as gr
from einops import rearrange
from torchvision import transforms
from pytorch_i3d import InceptionI3d
def preprocess(vidpath):
    # Open the video file
    cap = cv2.VideoCapture(vidpath)
    frames = []
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Extract frames from the video
    for _ in range(num):
        _, img = cap.read()
        # Skip frames OpenCV failed to decode
        if img is None:
            continue
        # Upscale so the shorter side is at least 226 px (OpenCV shape is (h, w, c))
        h, w, c = img.shape
        if h < 226 or w < 226:
            d = 226. - min(h, w)
            sc = 1 + d / min(h, w)
            img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
        # Normalize pixel values to [-1, 1]
        img = (img / 255.) * 2 - 1
        frames.append(img)
    cap.release()
    frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
    # Center-crop to 224x224 and reshape to (1, c, t, h, w)
    transform = transforms.Compose([videotransforms.CenterCrop(224)])
    frames = transform(frames)
    frames = rearrange(frames, 't h w c -> 1 c t h w')
    return frames
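# A minimal usage sketch (illustrative only, not part of the app flow; assumes
# one of the bundled example clips is present):
#
#     clip = preprocess('videos/no.mp4')
#     print(clip.shape)  # torch.Size([1, 3, T, 224, 224]), T = decoded frames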
def classify(video, dataset='WLASL100'):
    to_load = {
        'WLASL100': {'logits': 100, 'path': 'weights/asl100/FINAL_nslt_100_iters=896_top1=65.89_top5=84.11_top10=89.92.pt'},
        'WLASL2000': {'logits': 2000, 'path': 'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt'}
    }
    # Preprocess the video into a (1, c, t, h, w) tensor
    inputs = preprocess(video)
    # Load the ImageNet-pretrained backbone, then the WLASL fine-tuned weights
    model = InceptionI3d()
    model.load_state_dict(torch.load('weights/rgb_imagenet.pt', map_location=torch.device('cpu')))
    model.replace_logits(to_load[dataset]['logits'])
    model.load_state_dict(torch.load(to_load[dataset]['path'], map_location=torch.device('cpu')))
    # Run on CPU: the free Spaces environment has no GPU
    model.cpu()
    # Evaluation mode
    model.eval()
    with torch.no_grad():  # Disable gradient computation during inference
        per_frame_logits = model(inputs)
    # Average the per-frame logits over time
    predictions = rearrange(per_frame_logits, '1 j k -> j k')
    predictions = torch.mean(predictions, dim=1)
    # Fetch the indices of the top 10 predictions
    _, index = torch.topk(predictions, 10)
    index = index.cpu().numpy()
    # Load the index -> gloss mapping
    with open('wlasl_class_list.txt') as f:
        idx2label = dict()
        for line in f:
            parts = line.split()
            idx2label[int(parts[0])] = parts[1]
    # Convert logits to probabilities
    predictions = torch.nn.functional.softmax(predictions, dim=0).cpu().numpy()
    # Return dict {label: probability} for the top predictions
    return {idx2label[i]: float(predictions[i]) for i in index}
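# A minimal usage sketch (illustrative only; runs the classifier outside the
# Gradio UI, assuming the example clip exists):
#
#     preds = classify('videos/no.mp4', dataset='WLASL100')
#     print(max(preds, key=preds.get))  # highest-probability gloss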
# Gradio App config
title = "I3D Sign Language Recognition"
description = "Gradio demo of word-level sign language classification using an I3D model pretrained on the WLASL video dataset. " \
              "WLASL is a large-scale dataset containing more than 2000 words in American Sign Language. " \
              "The examples used in the demo are videos from the test subset. " \
              "Note that WLASL100 contains 100 words while WLASL2000 contains 2000."
examples = [
    ['videos/no.mp4', 'WLASL100'],
    ['videos/all.mp4', 'WLASL100'],
    ['videos/before.mp4', 'WLASL100'],
    ['videos/blue.mp4', 'WLASL2000'],
    ['videos/white.mp4', 'WLASL2000'],
    ['videos/accident2.mp4', 'WLASL2000']
]
article = "NOTE: This is not the official demonstration of I3D sign language classification on the WLASL dataset. " \
          "More information about the WLASL dataset and pretrained I3D models can be found <a href='https://github.com/dxli94/WLASL'>here</a>."
# Gradio App interface
gr.Interface(fn=classify,
             inputs=[gr.inputs.Video(label="Video (*.mp4)"),
                     gr.inputs.Radio(choices=['WLASL100', 'WLASL2000'], default='WLASL100', label='Trained on:')],
             outputs=[gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions')],
             allow_flagging="never",
             title=title,
             description=description,
             examples=examples,
             article=article).launch()