File size: 2,776 Bytes
d3d8d59
 
 
 
2fe95f1
d3d8d59
 
 
2fe95f1
 
d3d8d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb83661
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import cv2
import videotransforms
import numpy as np
import gradio as gr
from einops import rearrange
from torchvision import transforms
from pytorch_i3d import InceptionI3d


def preprocess(vidpath):
    """Decode a video file into a single clip tensor for the I3D model.

    Each frame is read with OpenCV, upscaled when either spatial side is
    smaller than 226 pixels, mapped through ``(img / 255) * 2 - 1``,
    centre-cropped to 224 and finally rearranged to ``1 c t h w``.

    Args:
        vidpath: Path of the video file, handed to ``cv2.VideoCapture``.

    Returns:
        A float32 ``torch.Tensor`` laid out as ``(1, c, t, h, w)``.

    Raises:
        ValueError: If no frame could be decoded from ``vidpath``.
    """
    cap = cv2.VideoCapture(vidpath)
    frames = []
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        for _ in range(num):
            ok, img = cap.read()
            if not ok or img is None:
                # The reported frame count may exceed what can actually be
                # decoded; stop at the first failed read instead of
                # crashing on ``None.shape``.
                break

            # OpenCV arrays are (rows, cols, channels), i.e. height first.
            h, w = img.shape[:2]
            if w < 226 or h < 226:
                # Scale so that the shorter side becomes 226 pixels.
                d = 226. - min(w, h)
                sc = 1 + d / min(w, h)
                img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
            img = (img / 255.) * 2 - 1

            frames.append(img)
    finally:
        # Always free the decoder handle, even if a frame fails to process.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be decoded from {vidpath!r}")

    frames = torch.Tensor(np.asarray(frames, dtype=np.float32))

    transform = transforms.Compose([videotransforms.CenterCrop(224)])
    frames = transform(frames)
    frames = rearrange(frames, 't h w c-> 1 c t h w')

    return frames

def classify(video, dataset='WLASL100'):
    """Classify a sign-language video and return its top predictions.

    The video is preprocessed, run through an ``InceptionI3d`` model whose
    logits layer is replaced to match the chosen dataset, and the resulting
    logits are averaged over their second axis before a softmax is applied.

    Args:
        video: Path of the video file, forwarded to ``preprocess``.
        dataset: Key selecting the checkpoint and class count; one of
            ``'WLASL100'`` or ``'WLASL2000'``.

    Returns:
        A dict mapping class label to softmax probability (``float``) for
        the highest-scoring classes (at most 10).

    Raises:
        KeyError: If ``dataset`` is not a known key, or a predicted class
            index is missing from ``wlasl_class_list.txt``.
    """
    to_load = {
        'WLASL100': {'logits': 100, 'path': 'weights/asl100/FINAL_nslt_100_iters=896_top1=65.89_top5=84.11_top10=89.92.pt'},
        'WLASL2000': {'logits': 2000, 'path': 'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt'},
    }
    config = to_load[dataset]

    clip = preprocess(video)

    model = InceptionI3d()
    # map_location='cpu' lets checkpoints that were saved from a GPU load on
    # CPU-only hosts; the model itself is never moved off the CPU here.
    model.load_state_dict(torch.load('weights/rgb_imagenet.pt', map_location='cpu'))
    model.replace_logits(config['logits'])
    model.load_state_dict(torch.load(config['path'], map_location='cpu'))
    model.eval()

    with torch.no_grad():
        per_frame_logits = model(clip)

    # Drop the leading batch axis of size 1, then average over the second
    # remaining axis (presumably time -- confirm against InceptionI3d).
    predictions = rearrange(per_frame_logits, '1 j k -> j k')
    predictions = torch.mean(predictions, dim=1)

    # Never ask topk for more entries than there are classes.
    k = min(10, predictions.numel())
    index = torch.topk(predictions, k).indices.tolist()

    idx2label = {}
    with open('wlasl_class_list.txt') as f:
        for line in f:
            # Split only once so that a multi-word label is kept whole
            # instead of being truncated to its first word.
            parts = line.split(maxsplit=1)
            if len(parts) < 2:
                continue  # blank line or line without a label
            idx2label[int(parts[0])] = parts[1].strip()

    probabilities = torch.nn.functional.softmax(predictions, dim=0).numpy()

    return {idx2label[i]: float(probabilities[i]) for i in index}

# Static text shown on the demo page.
title = "I3D Sign Language Recognition"
description = "Description here"

# Each example row is [video path, dataset choice], matching the input order.
examples = [
    ['videos/no.mp4', 'WLASL100'],
    ['videos/all.mp4', 'WLASL100'],
    ['videos/blue.mp4', 'WLASL2000'],
    ['videos/white.mp4', 'WLASL2000'],
    ['videos/accident.mp4', 'WLASL2000'],
]

# Input widgets: the video to classify and the dataset/checkpoint to use.
video_input = gr.inputs.Video(label="VIDEO")
dataset_input = gr.inputs.Dropdown(
    choices=['WLASL100', 'WLASL2000'],
    default='WLASL100',
    label='DATASET USED',
)

# Output widget: shows the five most probable labels returned by classify.
label_output = gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions')

demo = gr.Interface(
    fn=classify,
    inputs=[video_input, dataset_input],
    outputs=[label_output],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(cache_examples=True)