File size: 3,134 Bytes
c20f071
 
 
 
44c8341
 
 
 
bdac097
44c8341
bdac097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e62770
44c8341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import subprocess
import sys


def _pip_install(*args):
  """Run ``pip install`` with ``args`` for the interpreter executing this script.

  pip is invoked as ``<sys.executable> -m pip`` so the packages land in the
  environment of the running interpreter, not in whichever ``pip`` happens to
  be first on PATH. The command is passed as an argument list (no shell).

  Args:
    *args: command-line arguments placed after ``pip install``.

  Returns:
    The exit status of the pip process. A non-zero status is reported on
    stderr but does not raise, so start-up continues as before.
  """
  command = [sys.executable, "-m", "pip", "install", *args]
  status = subprocess.run(command, check=False).returncode
  if status != 0:
    print(
      "WARNING: '{}' exited with status {}".format(" ".join(command), status),
      file=sys.stderr,
    )
  return status


# These installs run before the torch / glycowork imports further down.
_pip_install(
  "torch", "torchvision", "torchaudio",
  "--extra-index-url", "https://download.pytorch.org/whl/cpu",
)
# NOTE(review): torch above is unpinned while this wheel index targets
# torch 1.12.0+cpu -- confirm the two stay compatible.
_pip_install(
  "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv", "torch-geometric",
  "-f", "https://data.pyg.org/whl/torch-1.12.0+cpu.html",
)

import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
from glycowork.glycan_data.loader import lib

# Groups of monosaccharide names. Every name listed here is appended to
# glycowork's `lib` further down if it is not already present.
# NOTE(review): the rows look like the SNFG monosaccharide classes
# (hexoses, HexNAc, ...) -- confirm before relying on the grouping.
equivalence_classes = [
    ["Glc", "Man", "Gal", "Gul", "Alt", "All", "Tal", "Ido"],
    ["GlcNAc", "ManNAc", "GalNAc", "GulNAc", "AltNAc", "AllNAc", "TalNAc", "IdoNAc"],
    ["GlcN", "ManN", "GalN", "GulN", "AltN", "AllN", "TalN", "IdoN"],
    ["GlcA", "ManA", "GalA", "GulA", "AltA", "AllA", "TalA", "IdoA"],
    ["Qui", "Rha", "6dGul", "6dAlt", "6dTal", "Fuc"],
    ["QuiNAc", "RhaNAc", "6dAltNAc", "6dTalNAc", "FucNAc"],
    ["Oli", "Tyv", "Abe", "Par", "Dig", "Col"],
    ["Ara", "Lyx", "Xyl", "Rib"],
    ["Kdn", "Neu5Ac", "Neu5Gc", "Neu", "Sia"],
    ["Pse", "Leg", "Aci", "4eLeg"],
    [
        "Bac", "LDmanHep", "Kdo", "Dha", "DDmanHep", "MurNAc", "MurNGc",
        "Mur", "Api", "Fru", "Tag", "Sor", "Psi",
    ],
]

# One row per concrete linkage, in the order a1-{2,3,4,6}, b1-{2,3,4,6},
# a2-{3,6,8}. Each row lists the linkage itself followed by the same linkage
# with the first letter and/or the final position replaced by "z".
# NOTE(review): "z" presumably marks an unspecified anomer/position -- confirm.
linkage_classes = [
    [
        f"{anomer}{carbon}-{position}",
        f"{anomer}{carbon}-z",
        f"z{carbon}-{position}",
        f"z{carbon}-z",
    ]
    for anomer, carbon, positions in (
        ("a", 1, (2, 3, 4, 6)),
        ("b", 1, (2, 3, 4, 6)),
        ("a", 2, (3, 6, 8)),
    )
    for position in positions
]

# Extend glycowork's `lib` in place with every monosaccharide and linkage
# token from the tables above that it does not already contain. The size of
# `lib` is printed before and after so the number of additions is visible.
print(len(lib))
for token_group in equivalence_classes + linkage_classes:
  for token in token_group:
    # Tokens repeat across linkage rows (e.g. "a1-z"), so check every time.
    if token in lib:
      continue
    lib.append(token)
print(len(lib))

def fn(model, class_list):
  """Return a function that predicts the class of a single glycan string.

  Args:
    model: callable invoked as ``model(x, edge_index, batch)`` that returns a
      tensor of per-class scores.
    class_list: sequence of class names; position ``i`` names the class whose
      score is at index ``i`` of the model output.

  Returns:
    A function ``f(glycan)`` that returns the entry of ``class_list`` with the
    highest score for ``glycan``. ``f`` raises ``ValueError`` when ``glycan``
    is ``None``, empty or whitespace only.
  """
  def _model_device():
    # Send inputs to wherever the model's weights live, so inputs and weights
    # can never end up on different devices.
    try:
      return next(model.parameters()).device
    except (AttributeError, StopIteration):
      # No inspectable parameters: fall back to "CUDA if available".
      return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  def f(glycan):
    if glycan is None or not glycan.strip():
      raise ValueError("Please enter a glycan sequence.")
    glycan = glycan.strip()
    # dataset_to_dataloader wants one label per glycan; 0 is a placeholder
    # that is never read back.
    loader = dataset_to_dataloader([glycan], [0], batch_size=1)
    data = next(iter(loader))
    device = _model_device()
    x = data.labels.to(device)
    edge_index = data.edge_index.to(device)
    batch = data.batch.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
      scores = model(x, edge_index, batch).detach().cpu().numpy()
    return class_list[int(np.argmax(scores))]
  return f

# Place the model on cuda:0 when CUDA is available and on the CPU otherwise,
# regardless of the device it was saved from. Without map_location a
# checkpoint saved on a GPU cannot be loaded on a CPU-only host.
# NOTE: torch.load unpickles arbitrary objects -- "model.pt" must be a
# trusted file.
model = torch.load(
    "model.pt",
    map_location="cuda:0" if torch.cuda.is_available() else "cpu",
)
model.eval()  # switch the model to evaluation mode for inference

# Class names indexed by the position of the winning model output.
# NOTE(review): presumably the label order used at training time -- confirm.
class_list = [
    'Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista',
    'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae', 'Orthornavirae',
    'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria',
]

# Prediction callback handed to the Gradio interface.
f = fn(model, class_list)

# Web UI: one text box for the glycan sequence, one for the predicted class.
demo = gr.Interface(
    fn=f,
    inputs=[gr.Textbox(label="Glycan sequence")],
    outputs=[gr.Textbox(label="Predicted Class")],
    # Interface takes a string flagging mode ("never", "auto" or "manual");
    # the boolean form is deprecated.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        "GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
        "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol",
    ],
)
demo.launch(debug=True)