Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import random
|
|
4 |
import os
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
-
from mobilenet import MobileNetLarge3D
|
8 |
from torchvision.io import read_video
|
9 |
import time
|
10 |
|
@@ -38,8 +38,8 @@ def call_pitch(pitch):
|
|
38 |
pitch_tensor = (read_video(pitch,pts_unit='sec')[0].permute(-1,0,1,2)).unsqueeze(0)/255
|
39 |
pitch_tensor = (pitch_tensor-mean)/std #normalize the pitch tensor
|
40 |
video_length = pitch_tensor.shape[2]/15
|
41 |
-
model =
|
42 |
-
model.load_state_dict(torch.load('weights/
|
43 |
model.eval()
|
44 |
|
45 |
#run the model
|
|
|
4 |
import os
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
+
from mobilenet import MobileNetLarge3D, MobileNetSmall3D
|
8 |
from torchvision.io import read_video
|
9 |
import time
|
10 |
|
|
|
38 |
pitch_tensor = (read_video(pitch,pts_unit='sec')[0].permute(-1,0,1,2)).unsqueeze(0)/255
|
39 |
pitch_tensor = (pitch_tensor-mean)/std #normalize the pitch tensor
|
40 |
video_length = pitch_tensor.shape[2]/15
|
41 |
+
model = MobileNetSmall3D()
|
42 |
+
model.load_state_dict(torch.load('weights/MobileNetSmall.pth',map_location=torch.device('cpu')))
|
43 |
model.eval()
|
44 |
|
45 |
#run the model
|