Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
import random | |
import os | |
import torch | |
import torch.nn.functional as F | |
from mobilenet import MobileNetLarge3D, MobileNetSmall3D | |
from torchvision.io import read_video | |
import time | |
def classify_pitch(confidence_scores):
    """Map model confidence scores to an umpire-style call.

    Args:
        confidence_scores: tensor of per-class scores. ``torch.argmax`` is
            taken over the flattened tensor, so this is only meaningful for a
            single clip (e.g. shape (1, 2)). Index 0 means ball, 1 means strike.

    Returns:
        'Ball', 'Strike', or 'Error' when the winning index is neither 0 nor 1.
    """
    predicted = int(torch.argmax(confidence_scores))
    if predicted == 0:
        return 'Ball'
    if predicted == 1:
        return 'Strike'
    # Unexpected class index: report it and fall back to the same 'Error'
    # label get_demo_call uses, rather than returning an unbound local.
    print("that's odd, something is wrong")
    return 'Error'
def get_demo_call(video_name):
    """Infer the umpire's original call from a clip's file name.

    The file name (directories stripped, case-insensitive) is expected to
    start with "strike" or "ball"; anything else yields "Error".
    """
    stem = os.path.basename(video_name).lower()
    for prefix, label in (("strike", "Strike"), ("ball", "Ball")):
        if stem.startswith(prefix):
            return label
    return "Error"
# Loaded models keyed by weights path, so the checkpoint is read from disk
# once per process instead of on every play event.
_MODEL_CACHE = {}


def _load_model(weights_path='weights/MobileNetSmall.pth'):
    """Return a CPU MobileNetSmall3D in eval mode, loading its weights at most once."""
    model = _MODEL_CACHE.get(weights_path)
    if model is None:
        model = MobileNetSmall3D()
        model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
        model.eval()
        _MODEL_CACHE[weights_path] = model
    return model


def call_pitch(pitch):
    """Run Picklebot on a pitch clip and return its call plus the umpire's call.

    Args:
        pitch: path to the video file shown in the player.

    Returns:
        (picklebot_call, umpire_call), where picklebot_call comes from
        classify_pitch and umpire_call from get_demo_call (file-name based).
        The return is delayed until the clip has nearly finished playing.
    """
    start = time.time()
    # Per-channel normalization statistics, shaped to broadcast over (N, C, T, H, W).
    mean = torch.tensor((0.3939, 0.3817, 0.3314)).view(1, 3, 1, 1, 1)
    std = torch.tensor((0.2104, 0.1986, 0.1829)).view(1, 3, 1, 1, 1)

    ump_out = get_demo_call(pitch)

    # read_video yields frames as (T, H, W, C) uint8; reorder to (C, T, H, W),
    # add a batch dimension and scale to [0, 1] before normalizing.
    frames = read_video(pitch, pts_unit='sec')[0]
    pitch_tensor = frames.permute(-1, 0, 1, 2).unsqueeze(0) / 255
    pitch_tensor = (pitch_tensor - mean) / std
    num_frames = pitch_tensor.shape[2]

    model = _load_model()
    with torch.no_grad():
        output = F.softmax(model(pitch_tensor), dim=1)
    final_call = classify_pitch(output)

    # Hold the answer until the clip has nearly finished playing; returning
    # 5 frames early makes the call feel snappy. Assumes 15 fps clips, which
    # matches the `-r 15` re-encode in download_pitch.
    remaining = (num_frames - 5) / 15 - (time.time() - start)
    if remaining > 0:
        time.sleep(remaining)
    return final_call, ump_out
def generate_random_pitch():
    """Pick a random pitch from the CSV index, download it and return the cropped clip path.

    Each CSV row is read as (video link, label) with label 0 -> "Ball",
    1 -> "Strike"; any other label is passed on as "Error".
    """
    # Skip a random number of leading rows so only one row is parsed instead
    # of loading the whole CSV into memory.
    row_offset = random.randint(1, 2645342)
    row = pd.read_csv('picklebot_2m.csv', skiprows=row_offset, nrows=1, header=None).values[0]
    video_link, label = row[0], row[1]
    call = {0: "Ball", 1: "Strike"}.get(label, "Error")
    return download_pitch(video_link, call)
def download_pitch(video_link, ump_out):
    """Download a pitch video and re-encode it as a 224x224, 15 fps clip.

    Args:
        video_link: URL handed to yt-dlp (read from the pitch CSV).
        ump_out: the umpire's call ("Ball", "Strike" or "Error"); used only to
            name the output files.

    Returns:
        Path of the cropped clip,
        ``demo_files/downloaded_videos/{ump_out}_cropped_video.mp4``.

    Raises:
        subprocess.CalledProcessError: if yt-dlp or ffmpeg exits non-zero, so a
            failed download is never replaced by a stale clip from an earlier run.
        FileNotFoundError: if yt-dlp or ffmpeg is not on PATH.
    """
    import subprocess  # local import: only this helper shells out

    out_dir = "demo_files/downloaded_videos"
    raw_path = f"{out_dir}/{ump_out}_demo_video.mp4"
    cropped_path = f"{out_dir}/{ump_out}_cropped_video.mp4"
    os.makedirs(out_dir, exist_ok=True)

    # argv lists (no shell) keep a link containing spaces, quotes or shell
    # metacharacters from injecting commands; "--" stops a link that starts
    # with "-" from being parsed as a yt-dlp option.
    subprocess.run(
        ["yt-dlp", "-q", "--no-warnings", "--force-overwrites", "-f", "mp4",
         "-o", raw_path, "--", str(video_link)],
        check=True,
    )
    # Center-crop 700x700, downscale to 224x224 and resample to 15 fps.
    subprocess.run(
        ["ffmpeg", "-y", "-nostats", "-loglevel", "0", "-i", raw_path,
         "-vf", "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224",
         "-r", "15", "-c:v", "libx264", "-crf", "23",
         "-c:a", "aac", "-strict", "experimental", cropped_path],
        check=True,
    )
    return cropped_path
# Bundled example clips shown under the player, sorted so their order is
# stable (os.listdir order is arbitrary).
demo_files = sorted(
    os.path.join("demo_files/", file)
    for file in os.listdir("demo_files/")
    if file.endswith(".mp4")
)

with gr.Blocks(title="Picklebot") as demo:
    with gr.Row():
        with gr.Column(scale=3):
            inp = gr.Video(interactive=False, label="Pitch Video")
        with gr.Column(scale=2):
            ump_out = gr.Label(label="Umpire's Original Call")
            pb_out = gr.Label(label="Picklebot's Call")
            random_button = gr.Button("🔀 Load a Random Pitch")
            # Load a random pitch into the player; the calls appear when it plays.
            random_button.click(fn=generate_random_pitch, outputs=[inp])
            # Starting playback runs the model and fills in both calls.
            inp.play(fn=call_pitch, inputs=inp, outputs=[pb_out, ump_out])
            gr.ClearButton([inp, pb_out, ump_out])
    gr.Examples(demo_files, inputs=inp, outputs=[inp])

if __name__ == "__main__":
    demo.launch()