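# Gradio demo: word-level American Sign Language recognition with an I3D model
# fine-tuned on the WLASL dataset (WLASL100 / WLASL2000 checkpoints).
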
import torch
import cv2
import videotransforms
import numpy as np
import gradio as gr
from einops import rearrange
from torchvision import transforms
from pytorch_i3d import InceptionI3d


def preprocess(vidpath):
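    """Read a video file and return a float tensor of shape (1, c, t, h, w) scaled to [-1, 1]."""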
    # Fetch video
    cap = cv2.VideoCapture(vidpath)

    frames = []
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    # Extract frames from video
    for _ in range(num):
        _, img = cap.read()
        
        # Skip NoneType frames
        if img is None:
            continue

        # Upscale so the shorter side is at least 226 px (cv2 shapes are (h, w, c))
        h, w, c = img.shape
        if h < 226 or w < 226:
            d = 226. - min(h, w)
            sc = 1 + d / min(h, w)
            img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)

        # Normalize
        img = (img / 255.) * 2 - 1

        frames.append(img)
    
    cap.release()
    frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
    
    # Center-crop to 224x224 and reshape to (1, c, t, h, w)
    transform = transforms.Compose([videotransforms.CenterCrop(224)])
    frames = transform(frames)
    frames = rearrange(frames, 't h w c -> 1 c t h w')

    return frames

def classify(video, dataset='WLASL100'):
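    """Classify a sign video with I3D and return a {label: probability} dict for its top-10 classes."""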
    # Checkpoints fine-tuned on the 100-word and 2000-word WLASL subsets
    to_load = {
        'WLASL100': {'logits': 100, 'path': 'weights/asl100/FINAL_nslt_100_iters=896_top1=65.89_top5=84.11_top10=89.92.pt'},
        'WLASL2000': {'logits': 2000, 'path': 'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt'}
    }

    # Preprocess video into a (1, c, t, h, w) tensor
    inp = preprocess(video)

    # Load model: ImageNet-pretrained I3D backbone, then WLASL fine-tuned weights
    model = InceptionI3d()
    model.load_state_dict(torch.load('weights/rgb_imagenet.pt', map_location=torch.device('cpu')))
    model.replace_logits(to_load[dataset]['logits'])
    model.load_state_dict(torch.load(to_load[dataset]['path'], map_location=torch.device('cpu')))

    # Run on CPU. The Spaces environment is limited to CPU for free users.
    model.cpu()

    # Evaluation mode
    model.eval()

    with torch.no_grad():  # Disable gradient computation
        per_frame_logits = model(inp)  # Inference

    # Average the per-frame logits over time to get one score per class
    predictions = rearrange(per_frame_logits, '1 j k -> j k')
    predictions = torch.mean(predictions, dim=1)

    # Fetch the top-10 predictions
    _, index = torch.topk(predictions, 10)
    index = index.cpu().numpy()

    # Load the class index -> label mapping
    with open('wlasl_class_list.txt') as f:
        idx2label = dict()
        for line in f:
            idx, label = line.split()[:2]
            idx2label[int(idx)] = label
    
    # Convert scores to probabilities
    predictions = torch.nn.functional.softmax(predictions, dim=0).cpu().numpy()

    # Return {label: probability} for the top-10 classes
    return {idx2label[i]: float(predictions[i]) for i in index}

# Gradio App config
title = "I3D Sign Language Recognition"
description = "Gradio demo of word-level sign language classification using an I3D model pretrained on the WLASL video dataset. " \
              "WLASL is a large-scale dataset containing more than 2000 words in American Sign Language. " \
              "The examples used in the demo are videos from the test subset. " \
              "Note that WLASL100 contains 100 words while WLASL2000 contains 2000."
examples = [
        ['videos/no.mp4','WLASL100'],
        ['videos/all.mp4','WLASL100'],
        ['videos/before.mp4','WLASL100'],
        ['videos/blue.mp4','WLASL2000'],
        ['videos/white.mp4','WLASL2000'],
        ['videos/accident2.mp4','WLASL2000']
    ]

article = "NOTE: This is not the official demonstration of I3D sign language classification on the WLASL dataset. " \
          "More information about the WLASL dataset and pretrained I3D models can be found <a href=https://github.com/dxli94/WLASL>here</a>."

# Gradio App interface
gr.Interface(
    fn=classify,
    inputs=[
        gr.inputs.Video(label="Video (*.mp4)"),
        gr.inputs.Radio(choices=['WLASL100', 'WLASL2000'], default='WLASL100', label='Trained on:'),
    ],
    outputs=[gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions')],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
    article=article,
).launch()