File size: 1,671 Bytes
c20f071
 
 
 
44c8341
 
 
 
 
5e62770
44c8341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import subprocess
import sys

# Install CPU-only builds of PyTorch and the PyTorch Geometric extensions at
# start-up, before the imports below. `sys.executable -m pip` targets the
# interpreter actually running this app (a bare `pip` on PATH may belong to a
# different environment).
# NOTE(review): torch is unpinned while the PyG wheel index below targets
# torch 1.12.0; if pip resolves a newer torch, the compiled extensions may not
# match it -- confirm and pin the versions consistently.
for _pip_args in (
    ["torch", "torchvision", "torchaudio",
     "--extra-index-url", "https://download.pytorch.org/whl/cpu"],
    ["torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
     "torch-geometric", "-f", "https://data.pyg.org/whl/torch-1.12.0+cpu.html"],
):
    _result = subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args])
    # Stay best-effort (the packages may already be present), but surface a
    # failed install instead of silently ignoring its exit status.
    if _result.returncode != 0:
        print(f"WARNING: pip install {' '.join(_pip_args)} exited with status "
              f"{_result.returncode}", file=sys.stderr)

import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch


def fn(model, class_list):
  """Wrap a trained glycan classifier in a one-sequence prediction function.

  Args:
    model: trained torch module, called as ``model(x, edge_index, batch)`` on
      the graph tensors produced by glycowork's ``dataset_to_dataloader``.
    class_list: class names; position ``i`` names the model's output ``i``.

  Returns:
    ``f(glycan)``: takes one glycan sequence string and returns the predicted
    element of ``class_list``. ``f`` raises ``ValueError`` for a blank sequence.
  """
  def f(glycan):
    # Trim stray whitespace from pasted input and reject blank input up front,
    # instead of letting it fail somewhere inside graph construction.
    glycan = (glycan or "").strip()
    if not glycan:
      raise ValueError("Please enter a glycan sequence.")
    # dataset_to_dataloader takes parallel lists of sequences and labels; the
    # label 0 is a placeholder -- only the graph tensors are used below.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))
    # Send the inputs to wherever the model's weights live so the devices
    # always match (the model is not moved here, so choosing cuda:0 whenever a
    # GPU exists would break a CPU-resident model).
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
      pred = model(data.labels.to(device), data.edge_index.to(device),
                   data.batch.to(device))
    # One graph in the batch, so the flat argmax is the predicted class index.
    return class_list[int(np.argmax(pred.cpu().numpy()))]
  return f

import inspect

# Load onto the GPU when one is available, otherwise the CPU; map_location lets
# a checkpoint that was saved on a GPU still load on a CPU-only host.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
load_kwargs = {"map_location": device}
# model.pt holds a whole pickled module (it is .eval()-ed and called directly,
# not loaded into a state_dict). torch >= 2.6 defaults torch.load to
# weights_only=True, which rejects such files, so opt out explicitly wherever
# the parameter exists (older torch releases do not accept it).
# SECURITY: weights_only=False unpickles arbitrary objects -- acceptable only
# for a trusted file presumably shipped with this app; never use it on uploads.
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
model = torch.load("model.pt", **load_kwargs)
model.eval()  # switch layers such as dropout/batch-norm to inference behaviour

# Class names indexed by the model's output position (presumably the label
# encoding used at training time -- do not reorder).
class_list = ['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista',
              'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
              'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota',
              'Protista', 'Riboviria']

f = fn(model, class_list)

demo = gr.Interface(
    fn=f,
    inputs=[gr.Textbox(label="Glycan sequence")],
    outputs=[gr.Textbox(label="Predicted Class")],
    # Gradio documents allow_flagging as a string ("never"/"auto"/"manual");
    # the boolean False is a deprecated spelling of "never".
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        "GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
        "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol",
    ],
)
demo.launch(debug=True)