hbfreed commited on
Commit
29ba5a8
1 Parent(s): c33bd6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -4,7 +4,7 @@ import random
4
  import os
5
  import torch
6
  import torch.nn.functional as F
7
- from mobilenet import MobileNetLarge3D
8
  from torchvision.io import read_video
9
  import time
10
 
@@ -38,8 +38,8 @@ def call_pitch(pitch):
38
  pitch_tensor = (read_video(pitch,pts_unit='sec')[0].permute(-1,0,1,2)).unsqueeze(0)/255
39
  pitch_tensor = (pitch_tensor-mean)/std #normalize the pitch tensor
40
  video_length = pitch_tensor.shape[2]/15
41
- model = MobileNetLarge3D()
42
- model.load_state_dict(torch.load('weights/MobileNetLarge.pth',map_location=torch.device('cpu')))
43
  model.eval()
44
 
45
  #run the model
 
4
  import os
5
  import torch
6
  import torch.nn.functional as F
7
+ from mobilenet import MobileNetLarge3D, MobileNetSmall3D
8
  from torchvision.io import read_video
9
  import time
10
 
 
38
  pitch_tensor = (read_video(pitch,pts_unit='sec')[0].permute(-1,0,1,2)).unsqueeze(0)/255
39
  pitch_tensor = (pitch_tensor-mean)/std #normalize the pitch tensor
40
  video_length = pitch_tensor.shape[2]/15
41
+ model = MobileNetSmall3D()
42
+ model.load_state_dict(torch.load('weights/MobileNetSmall.pth',map_location=torch.device('cpu')))
43
  model.eval()
44
 
45
  #run the model