picklebot_demo / app.py
hbfreed's picture
Upload app.py
56fb0e8 verified
raw
history blame
3.87 kB
import gradio as gr
import pandas as pd
import random
import os
import torch
import torch.nn.functional as F
from mobilenet import MobileNetLarge3D, MobileNetSmall3D
from torchvision.io import read_video
import time
def classify_pitch(confidence_scores):
    """Translate the model's class scores into a human-readable call.

    Args:
        confidence_scores: Tensor of class scores. The index of the largest
            element (taken over the flattened tensor, which is what
            ``torch.argmax`` does when no ``dim`` is given) selects the call:
            index 0 is a ball, index 1 is a strike.

    Returns:
        ``'Ball'``, ``'Strike'``, or ``'Error'`` when the winning index is
        neither 0 nor 1.
    """
    # Compute the argmax once rather than once per branch.
    winner = int(torch.argmax(confidence_scores))
    if winner == 0:
        return 'Ball'
    if winner == 1:
        return 'Strike'
    # This branch used to fall through to `return call` with `call` never
    # assigned, which raised UnboundLocalError. Return the same "Error" label
    # get_demo_call uses so the caller still receives a displayable value.
    print("that's odd, something is wrong")
    return 'Error'
def get_demo_call(video_name):
    """Infer the umpire's call from a clip's file name.

    Only the final path component is inspected, case-insensitively.

    Args:
        video_name: Path (or bare file name) of the clip.

    Returns:
        ``"Strike"`` if the file name starts with "strike", ``"Ball"`` if it
        starts with "ball", otherwise ``"Error"``.
    """
    stem = os.path.basename(video_name).lower()
    # Checked in order; the first matching prefix wins.
    for prefix, call in (("strike", "Strike"), ("ball", "Ball")):
        if stem.startswith(prefix):
            return call
    return "Error"
# Lazily-created classifier shared by every call_pitch invocation.
_PITCH_MODEL = None


def _get_pitch_model():
    """Return the MobileNetSmall3D classifier, loading its weights on first use."""
    global _PITCH_MODEL
    if _PITCH_MODEL is None:
        model = MobileNetSmall3D()
        model.load_state_dict(
            torch.load('weights/MobileNetSmall.pth', map_location=torch.device('cpu'))
        )
        model.eval()
        _PITCH_MODEL = model
    return _PITCH_MODEL


def call_pitch(pitch):
    """Classify a pitch video and return the model's call with the umpire's.

    The function deliberately does not return until roughly the time the
    clip would finish playing, so the call appears as the pitch ends.

    Args:
        pitch: Path to the video file. Its file name is also used to derive
            the umpire's call (see ``get_demo_call``).

    Returns:
        A ``(final_call, ump_out)`` tuple: the model's call as produced by
        ``classify_pitch`` and the umpire's call from the file name.
    """
    start = time.time()

    # Per-channel normalization statistics, reshaped so they broadcast
    # against a 5-D tensor whose second dimension holds the 3 channels.
    std = torch.tensor((0.2104, 0.1986, 0.1829)).view(1, 3, 1, 1, 1)
    mean = torch.tensor((0.3939, 0.3817, 0.3314)).view(1, 3, 1, 1, 1)

    ump_out = get_demo_call(pitch)

    # read_video's first return value holds the frames; the last axis is
    # moved to the front, a leading batch axis is added, and the values are
    # scaled by 1/255 before normalization.
    frames = read_video(pitch, pts_unit='sec')[0]
    pitch_tensor = frames.permute(-1, 0, 1, 2).unsqueeze(0) / 255
    pitch_tensor = (pitch_tensor - mean) / std
    video_length = pitch_tensor.shape[2]  # frame count, used for the wait below

    # The model used to be rebuilt and its weights re-read from disk on
    # every call; it is now loaded once and reused.
    model = _get_pitch_model()
    with torch.no_grad():
        output = F.softmax(model(pitch_tensor), dim=1)
    final_call = classify_pitch(output)

    # Hold the result until shortly before playback would end. This treats
    # the clip as 15 fps (the rate download_pitch re-encodes to) and stops
    # 5 frames early so the call feels prompt.
    while time.time() - start < (video_length - 5) / 15:
        time.sleep(0.1)
    return final_call, ump_out
def generate_random_pitch():
    """Pick one random row of ``picklebot_2m.csv`` and fetch its video.

    The first column of the row is used as the video link and the second
    as the label (0 means ball, 1 means strike; anything else maps to
    "Error").

    Returns:
        Whatever ``download_pitch`` returns for the selected link and call.
    """
    # Number of leading lines to skip so that exactly one row is parsed.
    row_offset = random.randint(1, 2645342)
    row = pd.read_csv(
        'picklebot_2m.csv', skiprows=row_offset, nrows=1, header=None
    ).values[0]
    video_link, label = row[0], row[1]

    # Default to "Error" and overwrite for the two recognized labels.
    ump_out = "Error"
    if label == 0:
        ump_out = "Ball"
    elif label == 1:
        ump_out = "Strike"

    return download_pitch(video_link, ump_out)
def download_pitch(video_link,ump_out):
    """Download a pitch video and produce a cropped, rescaled copy of it.

    yt-dlp fetches the video as mp4, then ffmpeg crops a centered 700x700
    window, scales it to 224x224, and re-encodes it at 15 fps.

    Args:
        video_link: URL passed to yt-dlp.
        ump_out: Call label used as the prefix of both output file names.

    Returns:
        Path of the cropped clip,
        ``demo_files/downloaded_videos/{ump_out}_cropped_video.mp4``.

    Raises:
        subprocess.CalledProcessError: If yt-dlp or ffmpeg exits non-zero.
        FileNotFoundError: If either executable is not installed.
    """
    # Function-scope import keeps this block self-contained.
    import subprocess

    out_dir = "demo_files/downloaded_videos"
    raw_path = f"{out_dir}/{ump_out}_demo_video.mp4"
    cropped_path = f"{out_dir}/{ump_out}_cropped_video.mp4"
    os.makedirs(out_dir, exist_ok=True)

    # Argument lists (no shell) so characters such as '&', ';', '?' or spaces
    # in the link can neither break the command nor inject another one.
    # check=True surfaces a failed download instead of silently returning a
    # stale or missing file left over from an earlier run.
    subprocess.run(
        [
            "yt-dlp", "-q", "--no-warnings", "--force-overwrites",
            "-f", "mp4",
            "-o", raw_path,
            "--", str(video_link),  # "--" stops a link from being read as an option
        ],
        check=True,
    )
    subprocess.run(
        [
            "ffmpeg", "-y", "-nostats", "-loglevel", "0",
            "-i", raw_path,
            "-vf", "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224",
            "-r", "15",
            "-c:v", "libx264", "-crf", "23",
            "-c:a", "aac", "-strict", "experimental",
            cropped_path,
        ],
        check=True,
    )
    return cropped_path
# Paths of the bundled example clips: every .mp4 directly inside demo_files/.
demo_files = [
    os.path.join("demo_files/", entry)
    for entry in os.listdir("demo_files/")
    if entry.endswith(".mp4")
]
# Gradio UI: a video player, two labels (umpire's call and Picklebot's call),
# a button that loads a random pitch, a clear button, and the example clips.
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# of the layout below must be restored from the original before this runs.
with gr.Blocks(title="Picklebot") as demo:
with gr.Row():
with gr.Column(scale=3):
inp = gr.Video(interactive=False, label="Pitch Video")
with gr.Column(scale=2):
ump_out = gr.Label(label="Umpire's Original Call")
pb_out = gr.Label(label="Picklebot's Call")
random_button = gr.Button("🔀 Load a Random Pitch")
# NOTE(review): `ball` is not referenced anywhere else in the visible file;
# it looks like leftover state and can probably be removed.
ball = 0
# Clicking the button runs generate_random_pitch and sends its return value to the video.
random_button.click(fn=generate_random_pitch,outputs=[inp])
# On the video's play event, call_pitch receives the video value; its
# (final_call, ump_out) return is routed to pb_out and ump_out in that order.
inp.play(fn=call_pitch,inputs=inp,outputs=[pb_out,ump_out])
# Clear button targeting the video and both labels.
gr.ClearButton([inp,pb_out,ump_out])
# Offer the bundled .mp4 clips as examples that populate the video.
gr.Examples(demo_files,inputs=inp,outputs=[inp])
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
demo.launch()