dennistrujillo commited on
Commit
769a90a
·
1 Parent(s): 8a6f09e

added gradio_demo.py

Browse files
Files changed (1) hide show
  1. gradio_demo.py +42 -0
gradio_demo.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import requests
4
+ from torchvision import transforms
5
+
6
+ def predict(mfile,patch_file,positions_file,expName):
7
+ positions = torch.from_numpy(np.load(positions_file))
8
+
9
+ patch_h5 = h5py.File(patch_file,'r')
10
+ n_frames = len(patch_h5)
11
+
12
+ transform_norm = transforms.ToTensor()
13
+ patches = torch.zeros(size=(len(positions),15,15))
14
+ frame_nr = np.zeros(shape=(len(positions),))
15
+ patch_nr = np.zeros(shape=(len(positions),))
16
+
17
+ j=0
18
+ for i in range(1,n_frames+1):
19
+ k=0
20
+ for patch in patch_h5['frame_nr%s' %i]:
21
+ patches[j] = transform_norm(np.array(patch,dtype=np.uint8))
22
+ frame_nr[j] = i
23
+ patch_nr[j] = k
24
+ j+=1
25
+ k+=1
26
+
27
+ transformed_patches = patches
28
+ transformed_patches = torch.reshape(transformed_patches,(len(patches),1,15,15))
29
+
30
+ inp = transforms.ToTensor()(inp).unsqueeze(0)
31
+ model=torch.load(m_file)
32
+ with torch.no_grad():
33
+ for i, (inputs, labels) in enumerate(dl_valid):
34
+ inputs = inputs.to(device)
35
+ y_pred = model(inputs)
36
+ y_pred = y_pred.cpu().numpy()
37
+ labels = labels.cpu().numpy()
38
+ plot_error(y_pred,labels,args.expName)
39
+ mse += np.power(y_pred[:,0] - labels[:,0],2) + np.power(y_pred[:,1] - labels[:,1],2)
40
+
41
+ interface=gr.Interface(fn=predict, inputs={"m_file": "upload", "patch_file": "upload", "positions_file":"upload", "expName": "text"}, outputs={"output":"text"})
42
+ interface.launch()