omermazig's picture
Make sure we are using CUDA if available
614fa07
raw
history blame
4.01 kB
import gradio as gr
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
Normalize,
UniformTemporalSubsample, RandomShortSideScale,
)
from torchvision.transforms import (
Compose,
Lambda,
Resize, RandomCrop,
)
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
# Checkpoint identifier used to load both the classifier and its preprocessing config.
MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass"
# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of clips sampled from each input video.
CLIPS_FROM_SINGLE_VIDEO = 5

trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

# Normalisation statistics come from the checkpoint's preprocessing config.
mean, std = image_processor.image_mean, image_processor.image_std

# The processor's size is keyed either by "shortest_edge" (square target)
# or by explicit "height" and "width" entries.
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30  # NOTE(review): hard-coded; presumably the expected source frame rate - confirm.
# Length of one sampled clip in seconds.
clip_duration = num_frames_to_sample * sample_rate / fps
# Preprocessing applied to every sampled clip before it is fed to the model.
# NOTE: the scale and crop steps are random, so the same clip may be
# preprocessed differently from one call to the next.
inference_transform = Compose(
    [
        # Keep a fixed number of frames per clip, matching the model config.
        UniformTemporalSubsample(num_frames_to_sample),
        # Rescale pixel values before normalising (assumes 0-255 input - confirm).
        Lambda(lambda frames: frames / 255.0),
        Normalize(mean=mean, std=std),
        RandomShortSideScale(min_size=256, max_size=320),
        # Crop to the spatial size taken from the image processor.
        RandomCrop(size=resize_to),
    ]
)
# Class names ordered by class id, so that labels[i] names the class scored by
# logit i. Sorting on the id avoids relying on the key order of label2id,
# which need not follow the ids.
labels = sorted(trained_model.config.label2id, key=trained_model.config.label2id.get)
def parse_video_to_clips(video_file):
    """Sample clips from a video and return them as one preprocessed batch.

    ``CLIPS_FROM_SINGLE_VIDEO`` clips of ``clip_duration`` seconds are drawn
    with the "random_multi" clip sampler, each clip is passed through
    ``inference_transform``, and the results are stacked into one tensor.

    Args:
        video_file: Location of the video, in whatever form
            ``VideoPathHandler.video_from_path`` accepts (presumably a
            filesystem path - confirm against the Gradio input component).

    Returns:
        A tensor on ``DEVICE`` whose first dimension indexes the sampled clips.

    Raises:
        ValueError: If a sampled time window yields no video frames.
    """
    video: EncodedVideoPyAV = VideoPathHandler().video_from_path(video_file)
    try:
        clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
        # noinspection PyTypeChecker
        clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

        processed_clips = []
        for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec):
            frames = video.get_clip(clip_start, clip_end)["video"]
            if frames is None:
                # Fail with a clear message instead of an opaque error inside the transform.
                raise ValueError(
                    f"No video frames could be decoded between {clip_start} and {clip_end} seconds."
                )
            # Swap the first two axes of the transformed clip
            # (assumes channels-first output, giving frames-first clips - confirm).
            processed_clips.append(inference_transform(frames).permute(1, 0, 2, 3))
    finally:
        # Release the underlying video resources even when sampling or decoding fails.
        video.close()

    return torch.stack(processed_clips).to(DEVICE)
def infer(video_file):
    """Classify a video and return a confidence score for every label.

    The video is split into clips by ``parse_video_to_clips``, all clips are
    run through ``trained_model`` in one batch, the per-clip logits are summed,
    and a softmax over the summed logits produces the scores.

    Args:
        video_file: The video to classify, passed through unchanged to
            ``parse_video_to_clips``.

    Returns:
        A dict mapping each entry of ``labels`` to its softmax score as a
        Python float.
    """
    videos_tensor = parse_video_to_clips(video_file)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        multiple_logits = trained_model(pixel_values=videos_tensor).logits

    # Pool the clips by summing their logits along the clip axis, then
    # normalise the pooled logits into scores.
    logits = multiple_logits.sum(dim=0)
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1)

    # Convert all scores to Python floats in one call rather than one tensor
    # element at a time.
    return dict(zip(labels, softmax_scores.tolist()))
# Build the demo UI around `infer` and start serving it.
# NOTE(review): `type="file"`, the boolean `allow_flagging` and
# `allow_screenshot` look like options from older Gradio releases - confirm
# against the Gradio version this app is pinned to.
demo = gr.Interface(
    fn=infer,
    inputs=gr.Video(type="file"),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on nba data",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        # Derive the model link from MODEL_CKPT so it always points at the
        # checkpoint this app actually loads.
        f" <center><a href='https://huggingface.co/{MODEL_CKPT}' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging=False,
    allow_screenshot=False,
)
demo.launch()