"""Gradio demo app: play a pitch video and compare the model's call to the umpire's."""

import functools
import os
import random
import subprocess
import time

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision.io import read_video

from mobilenet import MobileNetLarge3D, MobileNetSmall3D

# Per-channel normalization statistics, reshaped so they broadcast over a
# 5-D tensor whose second axis has 3 entries.
_MEAN = torch.tensor((0.3939, 0.3817, 0.3314)).view(1, 3, 1, 1, 1)
_STD = torch.tensor((0.2104, 0.1986, 0.1829)).view(1, 3, 1, 1, 1)

_WEIGHTS_PATH = 'weights/MobileNetSmall.pth'
_PITCH_CSV = 'picklebot_2m.csv'
_MAX_SKIPROWS = 2645342  # upper bound of the random row offset into the CSV
_DOWNLOAD_DIR = "demo_files/downloaded_videos"

# Index of the arg-max class -> displayed call.
_CALLS = {0: "Ball", 1: "Strike"}


def classify_pitch(confidence_scores):
    """Turn the model's confidence scores into a call string.

    Args:
        confidence_scores: tensor of scores; the index of its largest element
            (over the flattened tensor) selects the call.

    Returns:
        'Ball' for index 0, 'Strike' for index 1, and 'Error' for anything
        else (previously this path raised UnboundLocalError).
    """
    winner = int(torch.argmax(confidence_scores))
    if winner not in _CALLS:
        print("that's odd, something is wrong")
        return "Error"
    return _CALLS[winner]


def get_demo_call(video_name):
    """Read the umpire's call from a video's file name.

    Only the base name is inspected, case-insensitively: a name starting with
    "strike" gives "Strike", one starting with "ball" gives "Ball", anything
    else gives "Error".
    """
    base = os.path.basename(video_name).lower()
    for prefix, call in (("strike", "Strike"), ("ball", "Ball")):
        if base.startswith(prefix):
            return call
    return "Error"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build MobileNetSmall3D once, load its CPU weights, and put it in eval mode."""
    model = MobileNetSmall3D()
    model.load_state_dict(
        torch.load(_WEIGHTS_PATH, map_location=torch.device('cpu'))
    )
    model.eval()
    return model


def call_pitch(pitch):
    """Run the model on a pitch video and report both calls.

    Args:
        pitch: path of the video file to classify.

    Returns:
        A ``(final_call, ump_out)`` tuple: the model's call and the umpire's
        call as derived from the file name by ``get_demo_call``.
    """
    ump_out = get_demo_call(pitch)

    # Move the last axis of the decoded frames to the front, add a leading
    # batch axis, and scale the values by 1/255 before normalizing.
    frames = read_video(pitch, pts_unit='sec')[0]
    pitch_tensor = frames.permute(-1, 0, 1, 2).unsqueeze(0) / 255
    pitch_tensor = (pitch_tensor - _MEAN) / _STD

    # The model is cached so repeated plays do not re-read the weights file.
    with torch.no_grad():
        output = F.softmax(_load_model()(pitch_tensor), dim=1)

    return classify_pitch(output), ump_out


def generate_random_pitch():
    """Pick one random row of the pitch CSV, download its video, and return the clip path.

    The first column of the row is used as the video link and the second as
    the label (0 -> "Ball", 1 -> "Strike", anything else -> "Error").
    """
    offset = random.randint(1, _MAX_SKIPROWS)
    # Read a single row so the whole CSV is never held in memory.
    df = pd.read_csv(_PITCH_CSV, skiprows=offset, nrows=1, header=None)
    video_link = df.values[0][0]
    label = df.values[0][1]

    if label == 0:
        ump_out = "Ball"
    elif label == 1:
        ump_out = "Strike"
    else:
        ump_out = "Error"

    return download_pitch(video_link, ump_out)


def _run_tool(command):
    """Run an external command given as an argument list, best-effort.

    The exit status is ignored, as it was with ``os.system``; a command that
    cannot be launched at all is reported on stdout instead of raising.
    """
    try:
        subprocess.run(command, check=False)
    except OSError as err:
        print(f"could not run {command[0]}: {err}")


def download_pitch(video_link, ump_out):
    """Download a pitch video with yt-dlp and crop/rescale it with ffmpeg.

    Args:
        video_link: URL (or id) handed to yt-dlp.
        ump_out: call string used as the file-name prefix of both outputs.

    Returns:
        Path of the cropped clip. The path is returned even if a tool failed,
        so the file may be missing or left over from an earlier run.
    """
    os.makedirs(_DOWNLOAD_DIR, exist_ok=True)
    raw_path = f"demo_files/downloaded_videos/{ump_out}_demo_video.mp4"
    cropped_path = f"demo_files/downloaded_videos/{ump_out}_cropped_video.mp4"

    # Argument lists (no shell) so a link containing quotes, spaces or shell
    # metacharacters cannot inject commands; "--" stops a link that starts
    # with "-" from being parsed as an option.
    _run_tool([
        "yt-dlp", "-q", "--no-warnings", "--force-overwrites",
        "-f", "mp4",
        "-o", raw_path,
        "--", str(video_link),
    ])
    _run_tool([
        "ffmpeg", "-y", "-nostats", "-loglevel", "0",
        "-i", raw_path,
        "-vf", "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224",
        "-r", "15",
        "-c:v", "libx264", "-crf", "23",
        "-c:a", "aac", "-strict", "experimental",
        cropped_path,
    ])
    return cropped_path


# Example clips offered in the UI: every .mp4 directly inside demo_files/.
demo_files = os.listdir("demo_files/")
demo_files = [
    os.path.join("demo_files/", entry)
    for entry in demo_files
    if entry.endswith(".mp4")
]

# NOTE(review): the original indentation was lost, so the nesting of the
# components below is a reconstruction - confirm the intended layout.
with gr.Blocks(title="Picklebot") as demo:
    with gr.Row():
        with gr.Column(scale=3):
            inp = gr.Video(interactive=False, label="Pitch Video")
        with gr.Column(scale=2):
            ump_out = gr.Label(label="Umpire's Original Call")
            pb_out = gr.Label(label="Picklebot's Call")
            random_button = gr.Button("🔀 Load a Random Pitch")
            ball = 0  # NOTE(review): not referenced anywhere in this file

    # Loading a random pitch fills the video; playing it triggers the call.
    random_button.click(fn=generate_random_pitch, outputs=[inp])
    inp.play(fn=call_pitch, inputs=inp, outputs=[pb_out, ump_out])

    gr.ClearButton([inp, pb_out, ump_out])
    gr.Examples(demo_files, inputs=inp, outputs=[inp])

if __name__ == "__main__":
    demo.launch()