"""Gradio demo that classifies a sign-language video with an I3D model.

Two fine-tuned checkpoints are available (WLASL100 and WLASL2000); the user
picks one from a dropdown and the app shows the highest-scoring glosses.
"""
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
import torch
from einops import rearrange
from torchvision import transforms

import videotransforms
from pytorch_i3d import InceptionI3d

# Fine-tuned checkpoint per dataset: number of output classes and weight file.
_WEIGHTS = {
    'WLASL100': {
        'logits': 100,
        'path': 'weights/asl100/FINAL_nslt_100_iters=896_top1=65.89_top5=84.11_top10=89.92.pt',
    },
    'WLASL2000': {
        'logits': 2000,
        'path': 'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt',
    },
}
# Weights loaded into the freshly built model before its logits layer is replaced.
_IMAGENET_WEIGHTS = 'weights/rgb_imagenet.pt'
# Text file mapping class index -> gloss, one "<index> <gloss>" entry per line.
_CLASS_LIST = 'wlasl_class_list.txt'
# Frames whose shorter side is below this are upscaled before cropping.
_MIN_SIDE = 226
_CROP_SIZE = 224
# Number of candidate glosses returned per video.
_TOP_K = 10


def preprocess(vidpath):
    """Decode the video at *vidpath* into a model-ready clip tensor.

    Every frame is upscaled (if needed) so its shorter side is at least 226 px,
    mapped from [0, 255] to [-1, 1], then the stacked clip is center-cropped
    to 224 px. Channels stay in OpenCV's BGR order; no colour conversion is
    done here.

    Returns:
        A float32 tensor laid out as (1, c, t, h, w). h and w are 224 assuming
        ``videotransforms.CenterCrop(224)`` crops to a square of that size.

    Raises:
        ValueError: if no frame could be decoded from *vidpath*.
    """
    cap = cv2.VideoCapture(vidpath)
    frames = []
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in range(num):
            ok, img = cap.read()
            if not ok:
                # The reported frame count can exceed what is actually
                # decodable; stop here instead of crashing on img=None.
                break
            # OpenCV images are (rows, cols, channels) = (h, w, c).
            h, w = img.shape[:2]
            short = min(w, h)
            if short < _MIN_SIDE:
                d = float(_MIN_SIDE) - short
                sc = 1 + d / short
                img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
            # Normalise in float64, then keep float32 copies (the same values
            # the final float32 cast would produce) to halve memory on long clips.
            img = (img / 255.) * 2 - 1
            frames.append(img.astype(np.float32))
    finally:
        # Release the decoder even if reading or resizing raises.
        cap.release()

    if not frames:
        raise ValueError(f'could not read any frames from {vidpath!r}')

    clip = torch.from_numpy(np.stack(frames))
    clip = transforms.Compose([videotransforms.CenterCrop(_CROP_SIZE)])(clip)
    return rearrange(clip, 't h w c -> 1 c t h w')


@lru_cache(maxsize=len(_WEIGHTS))
def _load_model(dataset):
    """Build the I3D classifier for *dataset* once and reuse it on later calls.

    Raises:
        KeyError: if *dataset* is not a key of ``_WEIGHTS`` (lru_cache does not
            cache the failure, so the cache stays bounded by ``_WEIGHTS``).
    """
    spec = _WEIGHTS[dataset]
    model = InceptionI3d()
    model.cpu()
    # map_location keeps loading working on CPU-only hosts even if the
    # checkpoints were saved from GPU tensors.
    model.load_state_dict(torch.load(_IMAGENET_WEIGHTS, map_location='cpu'))
    model.replace_logits(spec['logits'])
    model.load_state_dict(torch.load(spec['path'], map_location='cpu'))
    model.eval()
    return model


@lru_cache(maxsize=1)
def _load_labels():
    """Return the ``{class index: gloss}`` map read from the class-list file.

    Each non-blank line must start with an integer index followed by the
    gloss; only the first whitespace-separated token after the index is kept.
    """
    idx2label = {}
    with open(_CLASS_LIST, encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines, e.g. a trailing newline
            idx2label[int(parts[0])] = parts[1]
    return idx2label


def classify(video, dataset='WLASL100'):
    """Return ``{gloss: probability}`` for the top-10 classes of *video*.

    Args:
        video: video file path, passed straight to ``cv2.VideoCapture``.
        dataset: which fine-tuned checkpoint to use; a key of ``_WEIGHTS``.

    The probabilities come from a softmax over *all* classes of the
    time-averaged logits. Only the top 10 are returned, so they need not
    sum to 1.
    """
    clip = preprocess(video)
    model = _load_model(dataset)
    with torch.no_grad():
        per_frame_logits = model(clip)
    # Drop the batch axis, then average over the second axis so that one
    # score per class (the axis topk/softmax run over) remains.
    predictions = rearrange(per_frame_logits, '1 j k -> j k').mean(dim=1)
    _, index = torch.topk(predictions, min(_TOP_K, predictions.numel()))
    probs = torch.nn.functional.softmax(predictions, dim=0).numpy()
    idx2label = _load_labels()
    return {idx2label[i]: float(probs[i]) for i in index.tolist()}


title = "I3D Sign Language Recognition"
# TODO: replace this placeholder with a real description of the demo.
description = "Description here"
examples = [
    ['videos/no.mp4', 'WLASL100'],
    ['videos/all.mp4', 'WLASL100'],
    ['videos/blue.mp4', 'WLASL2000'],
    ['videos/white.mp4', 'WLASL2000'],
    ['videos/accident.mp4', 'WLASL2000'],
]

# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in 3.x, removed in 4.0); keep in sync with the
# pinned gradio version. The Label shows 5 of the 10 classes classify() returns.
demo = gr.Interface(
    fn=classify,
    inputs=[
        gr.inputs.Video(label="VIDEO"),
        gr.inputs.Dropdown(choices=list(_WEIGHTS), default='WLASL100', label='DATASET USED'),
    ],
    outputs=[gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions')],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()