archit11 commited on
Commit
52a058d
·
verified ·
1 Parent(s): 53b7033

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -1
app.py CHANGED
@@ -1,3 +1,146 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.load("models/archit11/videomae-base-finetuned-fight-nofight-subset2").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
  import gradio as gr
3
+ import imutils
4
+ import numpy as np
5
+ import torch
6
+ from pytorchvideo.transforms import (
7
+ ApplyTransformToKey,
8
+ Normalize,
9
+ RandomShortSideScale,
10
+ RemoveKey,
11
+ ShortSideScale,
12
+ UniformTemporalSubsample,
13
+ )
14
+ from torchvision.transforms import (
15
+ Compose,
16
+ Lambda,
17
+ RandomCrop,
18
+ RandomHorizontalFlip,
19
+ Resize,
20
+ )
21
+ from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
22
 
23
+ MODEL_CKPT = "archit11/videomae-base-finetuned-fight-nofight-subset2"
24
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
27
+ PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
28
+
29
+ RESIZE_TO = PROCESSOR.size["shortest_edge"]
30
+ NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
31
+ IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}
32
+ VAL_TRANSFORMS = Compose(
33
+ [
34
+ UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
35
+ Lambda(lambda x: x / 255.0),
36
+ Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
37
+ Resize((RESIZE_TO, RESIZE_TO)),
38
+ ]
39
+ )
40
+ LABELS = list(MODEL.config.label2id.keys())
41
+
42
+
43
+ def parse_video(video_file):
44
+ """A utility to parse the input videos.
45
+
46
+ Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/
47
+ """
48
+ vs = cv2.VideoCapture(video_file)
49
+
50
+ # try to determine the total number of frames in the video file
51
+ try:
52
+ prop = (
53
+ cv2.cv.CV_CAP_PROP_FRAME_COUNT
54
+ if imutils.is_cv2()
55
+ else cv2.CAP_PROP_FRAME_COUNT
56
+ )
57
+ total = int(vs.get(prop))
58
+ print("[INFO] {} total frames in video".format(total))
59
+
60
+ # an error occurred while trying to determine the total
61
+ # number of frames in the video file
62
+ except:
63
+ print("[INFO] could not determine # of frames in video")
64
+ print("[INFO] no approx. completion time can be provided")
65
+ total = -1
66
+
67
+ frames = []
68
+
69
+ # loop over frames from the video file stream
70
+ while True:
71
+ # read the next frame from the file
72
+ (grabbed, frame) = vs.read()
73
+ if frame is not None:
74
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
75
+ frames.append(frame)
76
+ # if the frame was not grabbed, then we have reached the end
77
+ # of the stream
78
+ if not grabbed:
79
+ break
80
+
81
+ return frames
82
+
83
+
84
+ def preprocess_video(frames: list):
85
+ """Utility to apply preprocessing transformations to a video tensor."""
86
+ # Each frame in the `frames` list has the shape: (height, width, num_channels).
87
+ # Collated together the `frames` has the the shape: (num_frames, height, width, num_channels).
88
+ # So, after converting the `frames` list to a torch tensor, we permute the shape
89
+ # such that it becomes (num_channels, num_frames, height, width) to make
90
+ # the shape compatible with the preprocessing transformations. After applying the
91
+ # preprocessing chain, we permute the shape to (num_frames, num_channels, height, width)
92
+ # to make it compatible with the model. Finally, we add a batch dimension so that our video
93
+ # classification model can operate on it.
94
+ video_tensor = torch.tensor(np.array(frames).astype(frames[0].dtype))
95
+ video_tensor = video_tensor.permute(
96
+ 3, 0, 1, 2
97
+ ) # (num_channels, num_frames, height, width)
98
+ video_tensor_pp = VAL_TRANSFORMS(video_tensor)
99
+ video_tensor_pp = video_tensor_pp.permute(
100
+ 1, 0, 2, 3
101
+ ) # (num_frames, num_channels, height, width)
102
+ video_tensor_pp = video_tensor_pp.unsqueeze(0)
103
+ return video_tensor_pp.to(DEVICE)
104
+
105
+
106
+ def infer(video_file):
107
+ frames = parse_video(video_file)
108
+ video_tensor = preprocess_video(frames)
109
+ inputs = {"pixel_values": video_tensor}
110
+
111
+ # forward pass
112
+ with torch.no_grad():
113
+ outputs = MODEL(**inputs)
114
+ logits = outputs.logits
115
+ softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
116
+ confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
117
+ return confidences
118
+
119
+
120
+ gr.Interface(
121
+ fn=infer,
122
+ inputs=gr.Video(type="file"),
123
+ outputs=gr.Label(num_top_classes=3),
124
+ examples=[
125
+ ["examples/fight.mp4"],
126
+ ["examples/baseball.mp4"],
127
+ ["examples/balancebeam.mp4"],
128
+ ["./examples/no-fight1.mp4"],
129
+ ["./examples/no-fight2.mp4"],
130
+ ["./examples/no-fight3.mp4"],
131
+ ["./examples/no-fight4.mp4"],
132
+
133
+
134
+ ],
135
+ title="VideoMAE fin-tuned on a subset of Fight / No Fight dataset",
136
+ description=(
137
+ "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
138
+ " examples to load them. Read more at the links below."
139
+ ),
140
+ article=(
141
+ "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
142
+ " <center><a href='https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset' target='_blank'>Fine-tuned Model</a></center></div>"
143
+ ),
144
+ allow_flagging=False,
145
+ allow_screenshot=False,
146
+ ).launch()