Spaces:
Runtime error
Runtime error
File size: 7,386 Bytes
6d5c70d 29ba5a8 6d5c70d 93d1391 6d5c70d 6f1d1cd 56fb0e8 6d5c70d 56fb0e8 29ba5a8 6d5c70d 6f1d1cd 5db56a7 6f1d1cd 56fb0e8 6d5c70d 93d1391 6d5c70d 6f1d1cd 93d1391 6d5c70d 6f1d1cd 93d1391 6f1d1cd 93d1391 0ccb4ae 6f1d1cd 93d1391 6d5c70d 93d1391 4df5f4c 2b37f62 1d264ea 377d6e7 2b37f62 93d1391 6f1d1cd 93d1391 6f1d1cd 93d1391 df4c6a1 93d1391 6f1d1cd 93d1391 6f1d1cd 93d1391 6f1d1cd 93d1391 6d5c70d 93d1391 c9a9e98 c4a3414 c9a9e98 c4a3414 a737ae9 6d5c70d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import functools
import os
import random
import subprocess
import time

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision.io import read_video

from mobilenet import MobileNetLarge3D, MobileNetSmall3D
def classify_pitch(confidence_scores):
    """Turn the model's class scores into a human-readable call.

    Args:
        confidence_scores: tensor of per-class scores. The arg-max over the
            whole (flattened) tensor selects the call: index 0 -> "Ball",
            index 1 -> "Strike". In this file it is the softmax output of
            the model (presumably shape (1, 2) -- TODO confirm).

    Returns:
        "Ball" or "Strike".

    Raises:
        ValueError: if the arg-max index is not 0 or 1. (Previously this
            printed a message and then failed with UnboundLocalError.)
    """
    calls = ("Ball", "Strike")
    # Compute the arg-max once instead of once per branch.
    predicted = int(torch.argmax(confidence_scores))
    if not 0 <= predicted < len(calls):
        raise ValueError(
            f"unexpected class index {predicted}; expected 0 (Ball) or 1 (Strike)"
        )
    return calls[predicted]
def get_demo_call(video_name):
    """Recover the umpire's original call encoded in a demo video's filename.

    Only the basename is inspected, case-insensitively. A name beginning
    with "strike" gives "Strike", and one beginning with "ball" gives
    "Ball". The "strike" prefix is checked first. Any other name gives
    None.
    """
    stem = os.path.basename(video_name).lower()
    for prefix, call in (("strike", "Strike"), ("ball", "Ball")):
        if stem.startswith(prefix):
            return call
    return None
# Downloaded clips are re-encoded at 15 fps by download_pitch (ffmpeg -r 15).
# Bundled example clips are assumed to match -- TODO confirm.
_CLIP_FPS = 15
# Reveal the call this many frames before the clip ends.
_EARLY_FRAMES = 10
# Granularity of the progress-bar wait loop, in seconds.
_WAIT_STEP_S = 0.1


@functools.lru_cache(maxsize=1)
def _load_small_model():
    """Build MobileNetSmall3D with CPU weights in eval mode, once per process."""
    model = MobileNetSmall3D()
    model.load_state_dict(
        torch.load('weights/MobileNetSmall.pth', map_location=torch.device('cpu'))
    )
    model.eval()
    return model


def call_pitch(pitch, progress=gr.Progress()):
    """Classify a pitch video and reveal the call around when playback ends.

    Args:
        pitch: path of the video file supplied by the gr.Video component.
        progress: Gradio progress tracker used to show the wait.

    Returns:
        (picklebot_call, umpire_call) when the filename encodes the umpire's
        call (see get_demo_call); otherwise just picklebot_call.
    """
    start = time.time()
    # Per-channel normalisation statistics, shaped to broadcast over (N, C, T, H, W).
    mean = torch.tensor((0.3939, 0.3817, 0.3314)).view(1, 3, 1, 1, 1)
    std = torch.tensor((0.2104, 0.1986, 0.1829)).view(1, 3, 1, 1, 1)
    ump_out = get_demo_call(pitch)

    # read_video yields frames as (T, H, W, C). Reorder to (1, C, T, H, W)
    # and scale to [0, 1] before normalising.
    frames = read_video(pitch, pts_unit='sec')[0]
    pitch_tensor = frames.permute(-1, 0, 1, 2).unsqueeze(0) / 255
    pitch_tensor = (pitch_tensor - mean) / std
    num_frames = pitch_tensor.shape[2]

    # Reuse the cached model rather than re-reading the weights on every play.
    model = _load_small_model()
    with torch.no_grad():
        output = F.softmax(model(pitch_tensor), dim=1)
    final_call = classify_pitch(output)

    # Hold the answer until the clip has (almost) finished playing. Convert
    # frames to seconds before subtracting the time already spent on
    # inference. The old formula subtracted seconds directly from a frame
    # count.
    elapsed = time.time() - start
    remaining_s = (num_frames - _EARLY_FRAMES) / _CLIP_FPS - elapsed
    wait_steps = max(0, int(remaining_s / _WAIT_STEP_S))
    for _ in progress.tqdm(range(wait_steps)):
        time.sleep(_WAIT_STEP_S)

    if ump_out is not None:
        return final_call, ump_out
    return final_call
def generate_random_pitch():
    """Pick a random row from 'picklebot_2m.csv' and download that pitch.

    Column 0 of the row is the video link. Column 1 is the label:
    0 -> "Ball", 1 -> "Strike", anything else -> "Error". The label only
    determines the downloaded file's name prefix.

    Returns:
        (video_path, None, None). The two Nones clear the umpire and
        Picklebot label outputs in the UI.
    """
    # randint is inclusive. Skipping at least one line means row 0 is never
    # read (presumably a header -- TODO confirm).
    rows_to_skip = random.randint(1, 2645342)
    # Load only the single chosen row instead of the whole file.
    record = pd.read_csv('picklebot_2m.csv', skiprows=rows_to_skip, nrows=1, header=None).values[0]
    link, label = record[0], record[1]
    call_for_label = {0: "Ball", 1: "Strike"}
    umpire_call = call_for_label.get(label, "Error")
    clip_path = download_pitch(link, umpire_call)
    return clip_path, None, None
# Overall time budget (seconds) for the download plus the ffmpeg re-encode.
_DOWNLOAD_TIMEOUT_S = 30


def download_pitch(video_link, ump_out=None):
    """Download a pitch video with yt-dlp, then crop and rescale it with ffmpeg.

    The link may come straight from the user's textbox. Both tools are
    therefore run without a shell, and each argument is passed as a
    separate argv item, so the link cannot inject shell commands.

    Args:
        video_link: URL of the clip to download.
        ump_out: optional call label used as the filename prefix.
            None uses the prefix "own".

    Returns:
        Path of the 224x224, 15 fps re-encoded clip.

    Raises:
        gr.Error: if the two steps together exceed the time budget, or if
            either tool exits with a non-zero status.
    """
    start = time.time()
    prefix = "own" if ump_out is None else ump_out
    video_out = f"demo_files/downloaded_videos/{prefix}_demo_video.mp4"
    cropped_out = f"{video_out}_cropped_video.mp4"
    commands = [
        # "--" ends option parsing, so a link that starts with "-" cannot add flags.
        ["yt-dlp", "-q", "--no-warnings", "--force-overwrites", "-f", "mp4",
         "-o", video_out, "--", str(video_link)],
        # Centre-crop 700x700, scale to 224x224, re-encode at 15 fps, drop audio.
        ["ffmpeg", "-y", "-nostats", "-loglevel", "0", "-i", video_out,
         "-vf", "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224",
         "-r", "15", "-c:v", "libx264", "-crf", "23", "-c:a", "aac", "-an",
         "-strict", "experimental", cropped_out],
    ]
    too_slow = "Video download is taking too long. Please try again."
    for cmd in commands:
        # Enforce the budget while the tools run, not only after they finish.
        remaining = _DOWNLOAD_TIMEOUT_S - (time.time() - start)
        if remaining <= 0:
            raise gr.Error(too_slow)
        try:
            result = subprocess.run(cmd, timeout=remaining, check=False)
        except subprocess.TimeoutExpired:
            raise gr.Error(too_slow) from None
        if result.returncode != 0:
            raise gr.Error("Video download failed. Please check the link and try again.")
    return cropped_out
# Bundled example clips: every .mp4 directly inside demo_files/, in listdir order.
demo_files = [
    os.path.join("demo_files/", entry)
    for entry in os.listdir("demo_files/")
    if entry.endswith(".mp4")
]
# Gradio UI with three tabs:
# - random/example pitches,
# - a pitch downloaded from a user-supplied link,
# - an About page.
# Components appear in the order they are created inside each `with`
# context, so the statements below are order-sensitive.
# NOTE(review): the original indentation was lost; the nesting below is a
# reconstruction -- confirm the layout against the repository.
with gr.Blocks(title="Picklebot") as demo:
    # Tab 1: load a random pitch from the CSV index or pick a bundled example.
    with gr.Tab("Picklebot"):
        gr.Markdown(value="""
To load a video, click the random button or choose from the examples.
Play the video by clicking anywhere on the video; Picklebot's call will appear after the video is done playing.
To use your own video, click the "Use Your Own Video!" tab and follow the instructions.""")
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Video(interactive=False, label="Pitch Video")
            with gr.Column(scale=2):
                ump_out = gr.Label(label="Umpire's Original Call")
                pb_out = gr.Label(label="Picklebot's Call")
                # TODO: show a loading indicator while generate_random_pitch downloads the clip.
                random_button = gr.Button("🔀 Load a Random Pitch")
                # generate_random_pitch returns (path, None, None); the Nones clear both labels.
                random_button.click(fn=generate_random_pitch,outputs=[inp,ump_out,pb_out])
                # Classify on play. trigger_mode="once" ignores extra play events while a call is pending.
                # NOTE(review): call_pitch returns a single value when the filename starts with neither
                # "strike" nor "ball"; such an example would not fill both outputs here.
                inp.play(fn=call_pitch,inputs=inp,outputs=[pb_out,ump_out],trigger_mode="once")
                gr.ClearButton([inp,pb_out,ump_out])
        # Clicking an example loads that .mp4 path into the video component.
        gr.Examples(demo_files,inputs=inp,outputs=[inp])
    # Tab 2: the user pastes a Baseball Savant video link, which is downloaded and classified.
    with gr.Tab("Use Your Own Video!"):
        with gr.Row():
            gr.Markdown(value=
"""
# Here\'s how to use your own video:
1. Navigate to [Baseball Savant's statcast search](https://baseballsavant.mlb.com/statcast_search)
2. Choose the situation and pitch you want (just keep in mind that the network was only trained on called balls and called strikes)
3. Click on the pitcher or batter's name to get to list of pitches
4. Right click on the camera icon to the right, and copy the link
5. Paste the video url in below, and press go to download the pitch
6. Watch the video and see what Picklebot thinks!
""")
        with gr.Row():
            with gr.Column(scale=3):
                vid_inp = gr.Video(interactive=False, label="Pitch Video")
            with gr.Column(scale=2):
                # Textbox for the user's video link.
                input_txt = gr.Textbox(placeholder="Paste your video link here",label="Video Link")
                # Enter and the Go! button both download. download_pitch gets only the link,
                # so ump_out=None and the file is saved with the "own" prefix.
                input_txt.submit(fn=download_pitch,inputs=input_txt,outputs=vid_inp)
                submit_button = gr.Button("Go!")
                submit_button.click(fn=download_pitch,inputs=input_txt,outputs=vid_inp)
                # Rebinds the module-level name pb_out to this tab's label; tab 1's wiring above is already done.
                pb_out = gr.Label(label="Picklebot's Call")
                # Only one output: for "own_..." files get_demo_call returns None, so call_pitch returns one value.
                vid_inp.play(fn=call_pitch,inputs=vid_inp,outputs=pb_out,trigger_mode="once")
                gr.ClearButton([vid_inp,pb_out,input_txt])
    # Tab 3: static project description and links.
    with gr.Tab("About"):
        gr.Markdown(value=
"""Picklebot is a 3D Convolutional Neural Network based on MobileNetV3 that classifies baseball pitches as balls or strikes.
The network was trained on the [Picklebot-50K Dataset](https://huggingface.co/datasets/hbfreed/Picklebot-50K),
comprised of over fifty thousand pitches to achieve ~80% accuracy.
Here's the [GitHub](https://github.com/hbfreed/Picklebot) for the project.
Here's my [LinkedIn](https://www.linkedin.com/in/hbfreed/) if you want to connect.""")
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    # (Removes the stray " |" after demo.launch(), which was a syntax error.)
    demo.launch()