Spaces:
Runtime error
Runtime error
dalexanderch
commited on
Commit
·
44c8341
1
Parent(s):
69ea74f
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from glycowork.ml.processing import dataset_to_dataloader
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def fn(model, class_list):
|
7 |
+
def f(glycan):
|
8 |
+
glycan = [glycan]
|
9 |
+
label = [0]
|
10 |
+
data = next(iter(dataset_to_dataloader(glycan, label, batch_size=1)))
|
11 |
+
device = "cpu"
|
12 |
+
if torch.cuda.is_available():
|
13 |
+
device = "cuda:0"
|
14 |
+
x = data.labels
|
15 |
+
edge_index = data.edge_index
|
16 |
+
batch = data.batch
|
17 |
+
x = x.to(device)
|
18 |
+
edge_index = edge_index.to(device)
|
19 |
+
batch = batch.to(device)
|
20 |
+
pred = model(x,edge_index, batch).cpu().detach().numpy()
|
21 |
+
pred = np.argmax(pred)
|
22 |
+
pred = class_list[pred]
|
23 |
+
return pred
|
24 |
+
return f
|
25 |
+
|
26 |
+
model = torch.load("model.pt")
|
27 |
+
model.eval()
|
28 |
+
class_list=['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista', 'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
|
29 |
+
'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria']
|
30 |
+
|
31 |
+
f = fn(model, class_list)
|
32 |
+
|
33 |
+
demo = gr.Interface(
|
34 |
+
fn=f,
|
35 |
+
inputs=[gr.Textbox(label="Glycan sequence")],
|
36 |
+
outputs=[gr.Textbox(label="Predicted Class")],
|
37 |
+
allow_flagging=False,
|
38 |
+
title="SweetNet demo",
|
39 |
+
examples=["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
|
40 |
+
"Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
|
41 |
+
"Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol"]
|
42 |
+
)
|
43 |
+
demo.launch(debug=True)
|