File size: 7,386 Bytes
6d5c70d
 
 
 
 
 
29ba5a8
6d5c70d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93d1391
6d5c70d
 
6f1d1cd
56fb0e8
6d5c70d
 
 
 
 
 
 
 
56fb0e8
29ba5a8
 
6d5c70d
 
 
 
 
6f1d1cd
 
5db56a7
6f1d1cd
56fb0e8
6d5c70d
93d1391
 
 
 
6d5c70d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f1d1cd
93d1391
 
 
6d5c70d
6f1d1cd
93d1391
 
6f1d1cd
93d1391
 
 
 
0ccb4ae
6f1d1cd
 
 
93d1391
6d5c70d
 
 
 
93d1391
4df5f4c
2b37f62
1d264ea
 
377d6e7
2b37f62
93d1391
 
 
 
 
 
6f1d1cd
93d1391
6f1d1cd
 
93d1391
df4c6a1
93d1391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f1d1cd
93d1391
6f1d1cd
93d1391
6f1d1cd
93d1391
6d5c70d
93d1391
 
c9a9e98
 
 
c4a3414
c9a9e98
c4a3414
a737ae9
6d5c70d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
import pandas as pd
import random
import os
import torch
import torch.nn.functional as F
from mobilenet import MobileNetLarge3D, MobileNetSmall3D
from torchvision.io import read_video
import time

def classify_pitch(confidence_scores):
    """Translate the model's confidence scores into a call.

    Args:
        confidence_scores: Tensor of class scores. The index of the largest
            value (taken over the flattened tensor) selects the call:
            0 means 'Ball', 1 means 'Strike'.

    Returns:
        The string 'Ball' or 'Strike'.

    Raises:
        ValueError: If the winning index is neither 0 nor 1, e.g. when the
            tensor has more than two entries.
    """
    # Compute the argmax once instead of once per branch.
    predicted_class = int(torch.argmax(confidence_scores))
    if predicted_class == 0:
        return 'Ball'
    if predicted_class == 1:
        return 'Strike'
    # Previously this path printed a message and then crashed with an
    # UnboundLocalError on `return call`; fail with a clear error instead.
    raise ValueError(
        f"that's odd, something is wrong: expected class 0 or 1, got {predicted_class}"
    )

def get_demo_call(video_name):
    """Infer the umpire's call from a video's file name.

    Only the final path component is inspected, case-insensitively.

    Args:
        video_name: Path (or bare file name) of the video.

    Returns:
        "Strike" if the file name starts with "strike", "Ball" if it starts
        with "ball", otherwise None.
    """
    file_name = os.path.basename(video_name).lower()
    # Checked in order; the first matching prefix wins.
    for prefix, call in (("strike", "Strike"), ("ball", "Ball")):
        if file_name.startswith(prefix):
            return call
    return None

def call_pitch(pitch,progress=gr.Progress()):
    """Run Picklebot on a video and return its call once the video is nearly over.

    The video is loaded, scaled to [0, 1], normalized per channel, and passed
    through MobileNetSmall3D on the CPU. The function then sleeps (while
    advancing the progress bar) so the call shows up shortly before playback
    would finish rather than immediately.

    Args:
        pitch: Path to the video file to classify.
        progress: Gradio progress tracker used to display the wait.

    Returns:
        ``(final_call, ump_out)`` when the umpire's call can be inferred from
        the file name (see get_demo_call), otherwise just ``final_call``.
    """
    start = time.time()
    early_frames = 10  # reveal the call this many frames before the video ends
    default_fps = 15  # download_pitch re-encodes every clip at 15 fps

    # Per-channel normalization statistics, shaped to broadcast over a
    # (batch, channel, time, height, width) tensor.
    mean = torch.tensor((0.3939, 0.3817, 0.3314)).view(1, 3, 1, 1, 1)
    std = torch.tensor((0.2104, 0.1986, 0.1829)).view(1, 3, 1, 1, 1)

    ump_out = get_demo_call(pitch)

    # torchvision documents the frames as (T, H, W, C); move channels first
    # and add a batch dimension.
    frames, _, info = read_video(pitch,pts_unit='sec')
    pitch_tensor = frames.permute(-1,0,1,2).unsqueeze(0)/255
    pitch_tensor = (pitch_tensor-mean)/std #normalize the pitch tensor
    video_length = pitch_tensor.shape[2] #number of frames, used to wait for the video to finish
    fps = info.get("video_fps") or default_fps

    model = MobileNetSmall3D()
    model.load_state_dict(torch.load('weights/MobileNetSmall.pth',map_location=torch.device('cpu')))
    model.eval()
    with torch.no_grad():
        output = model(pitch_tensor)
    output = F.softmax(output,dim=1)
    final_call = classify_pitch(output)

    # Time left until (almost) the end of the video. The previous formula
    # subtracted elapsed *seconds* from a *frame* count, which made the call
    # arrive late; convert frames to seconds before subtracting.
    elapsed = time.time() - start
    remaining_seconds = (video_length - early_frames)/fps - elapsed
    wait_steps = max(0, int(remaining_seconds*10))  # one step per 0.1 s
    for _ in progress.tqdm(range(wait_steps)):
        time.sleep(0.1)

    if ump_out is not None:
        return final_call,ump_out
    else:
        return final_call

def generate_random_pitch():
    """Pick a random pitch from picklebot_2m.csv and download its video.

    Returns:
        A 3-tuple ``(video_path, None, None)``. The two ``None`` values clear
        the umpire and Picklebot labels in the UI.
    """
    row_offset = random.randint(1,2645342) #random number in our range to select a random pitch

    # Skip straight to the chosen row so only a single line is parsed.
    row = pd.read_csv('picklebot_2m.csv',skiprows=row_offset,nrows=1,header=None).values[0]
    video_link, label = row[0], row[1]

    # 0 -> Ball, 1 -> Strike; anything else is flagged as an error.
    ump_out = "Ball" if label == 0 else "Strike" if label == 1 else "Error"

    # The call is passed along so it ends up in the downloaded file's name.
    video_path = download_pitch(video_link,ump_out)
    return video_path,None,None


def download_pitch(video_link,ump_out=None):
    """Download a pitch video with yt-dlp and crop/rescale it with ffmpeg.

    The raw download is written to
    ``demo_files/downloaded_videos/<prefix>_demo_video.mp4`` where ``<prefix>``
    is ``ump_out`` when given and ``own`` otherwise. ffmpeg then crops the
    centre 700x700 region, scales it to 224x224 at 15 fps without audio, and
    writes ``<raw path>_cropped_video.mp4``.

    Args:
        video_link: URL of the video. This may come straight from a user
            textbox, so it is never passed through a shell.
        ump_out: Optional umpire call used as the file-name prefix.

    Returns:
        Path of the cropped video.

    Raises:
        gr.Error: If no link is given, a tool fails or cannot be started, or
            the whole process exceeds 30 seconds.
    """
    import subprocess  # local import: only this function needs it

    time_limit = 30  # seconds allowed for download + processing combined

    link = str(video_link).strip() if video_link is not None else ""
    if not link:
        raise gr.Error("Please provide a video link.")

    prefix = ump_out if ump_out is not None else "own"
    video_out = f"demo_files/downloaded_videos/{prefix}_demo_video.mp4"
    cropped_out = f"{video_out}_cropped_video.mp4"

    # Argument lists (no shell) so a crafted link cannot inject commands;
    # "--" stops yt-dlp from treating a link that starts with "-" as an option.
    commands = (
        ["yt-dlp", "-q", "--no-warnings", "--force-overwrites", "-f", "mp4",
         "-o", video_out, "--", link],
        ["ffmpeg", "-y", "-nostats", "-loglevel", "0", "-i", video_out,
         "-vf", "crop=700:700:in_w/2-350:in_h/2-350,scale=224:224",
         "-r", "15", "-c:v", "libx264", "-crf", "23", "-c:a", "aac", "-an",
         "-strict", "experimental", cropped_out],
    )

    deadline = time.monotonic() + time_limit
    try:
        for command in commands:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise subprocess.TimeoutExpired(command, time_limit)
            # Enforce the limit while the tool runs instead of only checking
            # the clock after everything has already finished.
            completed = subprocess.run(command, timeout=remaining)
            if completed.returncode != 0:
                # Don't hand back a stale or missing file when a step fails.
                raise gr.Error("Could not download or process the video. Please check the link and try again.")
    except subprocess.TimeoutExpired:
        raise gr.Error("Video download is taking too long. Please try again.") from None
    except OSError as err:
        # e.g. yt-dlp or ffmpeg is not installed
        raise gr.Error(f"Could not run {command[0]}: {err}") from err
    return cropped_out

# Example clips offered in the UI: every .mp4 sitting directly in demo_files/.
demo_files = [
    os.path.join("demo_files/", entry)
    for entry in os.listdir("demo_files/")
    if entry.endswith(".mp4")
]
# Build the Gradio interface: a demo tab, a bring-your-own-video tab and an
# about tab. Components appear in the order they are created below.
with gr.Blocks(title="Picklebot") as demo:
    # Tab 1: random pitches and bundled example clips.
    with gr.Tab("Picklebot"):
        gr.Markdown(value="""
                    To load a video, click the random button or choose from the examples. 
                    Play the video by clicking anywhere on the video; Picklebot's call will appear after the video is done playing.

                    
                    To use your own video, click the "Use Your Own Video!" tab and follow the instructions.""")
        with gr.Row():
            with gr.Column(scale=3):    
                inp = gr.Video(interactive=False, label="Pitch Video")
            with gr.Column(scale=2):
                ump_out = gr.Label(label="Umpire's Original Call")
                pb_out = gr.Label(label="Picklebot's Call")
                #when random is clicked, we want to somehow have a loading something to show that it's working
                random_button = gr.Button("🔀 Load a Random Pitch")
                # generate_random_pitch returns (video, None, None), so loading a
                # new pitch also blanks both labels.
                random_button.click(fn=generate_random_pitch,outputs=[inp,ump_out,pb_out])
                # call_pitch is wired to the video's play event; its results map to
                # Picklebot's label first, then the umpire's.
                # NOTE(review): call_pitch returns a single value when the file name
                # has no strike/ball prefix, while two outputs are listed here.
                # Confirm how Gradio handles that case.
                inp.play(fn=call_pitch,inputs=inp,outputs=[pb_out,ump_out],trigger_mode="once")
                gr.ClearButton([inp,pb_out,ump_out])
                # Bundled example clips; selecting one fills the video player.
                gr.Examples(demo_files,inputs=inp,outputs=[inp])
    # Tab 2: the user supplies a video link to download and classify.
    with gr.Tab("Use Your Own Video!"):
        with gr.Row():
            gr.Markdown(value=
                    """
                    # Here\'s how to use your own video: 
                    1. Navigate to [Baseball Savant's statcast search](https://baseballsavant.mlb.com/statcast_search)
                    2. Choose the situation and pitch you want (just keep in mind that the network was only trained on called balls and called strikes) 
                    3. Click on the pitcher or batter's name to get to list of pitches
                    4. Right click on the camera icon to the right, and copy the link
                    5. Paste the video url in below, and press go to download the pitch
                    6. Watch the video and see what Picklebot thinks!
                    """)
        with gr.Row():
            with gr.Column(scale=3):    
                vid_inp = gr.Video(interactive=False, label="Pitch Video")
            with gr.Column(scale=2):
                #make a textbox to take in the user's video link
                input_txt = gr.Textbox(placeholder="Paste your video link here",label="Video Link")
                # Submitting the textbox and clicking "Go!" do the same thing:
                # download_pitch is called with only the link, so its ump_out
                # argument keeps its default of None.
                input_txt.submit(fn=download_pitch,inputs=input_txt,outputs=vid_inp)
                submit_button = gr.Button("Go!")
                submit_button.click(fn=download_pitch,inputs=input_txt,outputs=vid_inp)
                # NOTE: this rebinds the name pb_out to a new label for this tab; the
                # first tab's events were already registered with its own label.
                pb_out = gr.Label(label="Picklebot's Call")
                # Only one output here: there is no umpire label on this tab.
                vid_inp.play(fn=call_pitch,inputs=vid_inp,outputs=pb_out,trigger_mode="once")
                gr.ClearButton([vid_inp,pb_out,input_txt])

    # Tab 3: static project description and links.
    with gr.Tab("About"):
        gr.Markdown(value=
                    """Picklebot is a 3D Convolutional Neural Network based on MobileNetV3 that classifies baseball pitches as balls or strikes.
                    The network was trained on the [Picklebot-50K Dataset](https://huggingface.co/datasets/hbfreed/Picklebot-50K),
                    comprised of over fifty thousand pitches to achieve ~80% accuracy.

                    Here's the [GitHub](https://github.com/hbfreed/Picklebot) for the project.
                    
                    Here's my [LinkedIn](https://www.linkedin.com/in/hbfreed/) if you want to connect.""")
# Start the app only when this file is run as a script.
if __name__ == "__main__":
  demo.launch()