Commit
·
769a90a
1
Parent(s):
8a6f09e
added gradio_demo.py
Browse files- gradio_demo.py +42 -0
gradio_demo.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import requests
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
def predict(mfile,patch_file,positions_file,expName):
|
7 |
+
positions = torch.from_numpy(np.load(positions_file))
|
8 |
+
|
9 |
+
patch_h5 = h5py.File(patch_file,'r')
|
10 |
+
n_frames = len(patch_h5)
|
11 |
+
|
12 |
+
transform_norm = transforms.ToTensor()
|
13 |
+
patches = torch.zeros(size=(len(positions),15,15))
|
14 |
+
frame_nr = np.zeros(shape=(len(positions),))
|
15 |
+
patch_nr = np.zeros(shape=(len(positions),))
|
16 |
+
|
17 |
+
j=0
|
18 |
+
for i in range(1,n_frames+1):
|
19 |
+
k=0
|
20 |
+
for patch in patch_h5['frame_nr%s' %i]:
|
21 |
+
patches[j] = transform_norm(np.array(patch,dtype=np.uint8))
|
22 |
+
frame_nr[j] = i
|
23 |
+
patch_nr[j] = k
|
24 |
+
j+=1
|
25 |
+
k+=1
|
26 |
+
|
27 |
+
transformed_patches = patches
|
28 |
+
transformed_patches = torch.reshape(transformed_patches,(len(patches),1,15,15))
|
29 |
+
|
30 |
+
inp = transforms.ToTensor()(inp).unsqueeze(0)
|
31 |
+
model=torch.load(m_file)
|
32 |
+
with torch.no_grad():
|
33 |
+
for i, (inputs, labels) in enumerate(dl_valid):
|
34 |
+
inputs = inputs.to(device)
|
35 |
+
y_pred = model(inputs)
|
36 |
+
y_pred = y_pred.cpu().numpy()
|
37 |
+
labels = labels.cpu().numpy()
|
38 |
+
plot_error(y_pred,labels,args.expName)
|
39 |
+
mse += np.power(y_pred[:,0] - labels[:,0],2) + np.power(y_pred[:,1] - labels[:,1],2)
|
40 |
+
|
41 |
+
interface=gr.Interface(fn=predict, inputs={"m_file": "upload", "patch_file": "upload", "positions_file":"upload", "expName": "text"}, outputs={"output":"text"})
|
42 |
+
interface.launch()
|