SweetNet / app.py
dalexanderch's picture
Upload app.py
bdac097
raw
history blame
3.13 kB
import os
def _pip_install(*args):
    """Run ``pip install`` with ``args`` for the interpreter running this app.

    Invoking pip as ``<sys.executable> -m pip`` guarantees the packages are
    installed into the same environment that will import them (a bare ``pip``
    on PATH may belong to a different Python). Arguments are passed as a list,
    so no shell is involved. The install stays best-effort like the original
    ``os.system`` call: a failure prints a warning instead of aborting start-up.
    """
    import subprocess
    import sys

    result = subprocess.run([sys.executable, "-m", "pip", "install", *args], check=False)
    if result.returncode != 0:
        print(f"WARNING: pip install {' '.join(args)} exited with code {result.returncode}")


# CPU builds of PyTorch.
_pip_install(
    "torch", "torchvision", "torchaudio",
    "--extra-index-url", "https://download.pytorch.org/whl/cpu",
)
# PyTorch Geometric and its extension packages.
# NOTE(review): the wheel index below targets torch 1.12.0, while torch above is
# unpinned -- confirm the installed torch version matches these wheels.
_pip_install(
    "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv", "torch-geometric",
    "-f", "https://data.pyg.org/whl/torch-1.12.0+cpu.html",
)
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
from glycowork.glycan_data.loader import lib
# Token groups that are appended to glycowork's ``lib`` vocabulary further down.
# Only the individual names are used in this file; the grouping itself is not.
#
# Monosaccharide names, one whitespace-separated string per group.
equivalence_classes = [
    group.split()
    for group in (
        "Glc Man Gal Gul Alt All Tal Ido",
        "GlcNAc ManNAc GalNAc GulNAc AltNAc AllNAc TalNAc IdoNAc",
        "GlcN ManN GalN GulN AltN AllN TalN IdoN",
        "GlcA ManA GalA GulA AltA AllA TalA IdoA",
        "Qui Rha 6dGul 6dAlt 6dTal Fuc",
        "QuiNAc RhaNAc 6dAltNAc 6dTalNAc FucNAc",
        "Oli Tyv Abe Par Dig Col",
        "Ara Lyx Xyl Rib",
        "Kdn Neu5Ac Neu5Gc Neu Sia",
        "Pse Leg Aci 4eLeg",
        "Bac LDmanHep Kdo Dha DDmanHep MurNAc MurNGc Mur Api Fru Tag Sor Psi",
    )
]

# Linkage tokens: each concrete linkage "<p><x>-<y>" is grouped with its
# variants in which "z" replaces one or both sides (presumably an unknown or
# wildcard position -- confirm against glycowork's notation).
linkage_classes = [
    [f"{prefix}{src}-{dst}", f"{prefix}{src}-z", f"z{src}-{dst}", f"z{src}-z"]
    for prefix, src, dst in (
        ("a", 1, 2), ("a", 1, 3), ("a", 1, 4), ("a", 1, 6),
        ("b", 1, 2), ("b", 1, 3), ("b", 1, 4), ("b", 1, 6),
        ("a", 2, 3), ("a", 2, 6), ("a", 2, 8),
    )
]
def _extend_vocabulary(vocab, token_groups):
    """Append every token of ``token_groups`` to ``vocab`` unless already present.

    Tokens are visited group by group in declaration order, and each new token
    is appended exactly once. ``vocab`` is mutated in place, never rebound.
    """
    present = set(vocab)  # mirror of ``vocab`` for constant-time membership tests
    for group in token_groups:
        for token in group:
            if token in present:
                continue
            vocab.append(token)
            present.add(token)


# Extend glycowork's ``lib`` list in place with the monosaccharide and linkage
# tokens from ``equivalence_classes`` and ``linkage_classes``, printing its size
# before and after. NOTE(review): in-place mutation presumably matters because
# ``dataset_to_dataloader`` (called without an explicit library argument) reads
# this same ``lib`` object, and the append order fixes the new token indices --
# confirm against the trained model.
print(len(lib))
_extend_vocabulary(lib, [*equivalence_classes, *linkage_classes])
print(len(lib))
def _model_device(model):
    """Return the device of ``model``'s first parameter, or the CPU if it has none."""
    for param in model.parameters():
        return param.device
    return torch.device("cpu")


def fn(model, class_list):
    """Build a prediction callback around ``model`` for the Gradio interface.

    Args:
        model: A torch module invoked as ``model(x, edge_index, batch)``; its
            output is treated as per-class scores for a single glycan.
        class_list: Class names indexed by the position of the highest score.

    Returns:
        A function ``f(glycan) -> str`` that converts one glycan sequence into a
        graph with glycowork's ``dataset_to_dataloader`` and returns the name of
        the highest-scoring class.

    Raises (from ``f``):
        ValueError: if ``glycan`` is empty or only whitespace.
    """
    def f(glycan):
        # Textboxes often carry stray whitespace/newlines from pasting.
        glycan = (glycan or "").strip()
        if not glycan:
            raise ValueError("Please enter a glycan sequence.")
        # The dataloader expects parallel lists of glycans and labels; the label
        # is a dummy because only the model's prediction is used.
        data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))
        # Put the inputs on whatever device the model actually lives on, rather
        # than assuming "cuda:0" whenever CUDA exists (which fails for a model
        # that was loaded onto the CPU).
        device = _model_device(model)
        with torch.no_grad():
            scores = model(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
        # argmax over the flattened scores; with batch_size=1 this is the class index.
        return class_list[int(scores.argmax())]
    return f
def _load_model(path):
    """Load the pickled model module at ``path`` and switch it to eval mode.

    Tensors are mapped to ``cuda:0`` when CUDA is available and to the CPU
    otherwise, so a checkpoint saved on a GPU still loads on a CPU-only host.
    """
    import inspect

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    load_kwargs = {"map_location": device}
    # The checkpoint is a whole pickled module (it is used directly as a model),
    # not a state_dict. torch >= 2.6 defaults to weights_only=True, which rejects
    # such files, so opt out explicitly wherever the argument exists.
    # SECURITY NOTE: full unpickling can execute arbitrary code; this is only
    # acceptable because model.pt ships with this app -- never point it at
    # user-supplied files.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    loaded = torch.load(path, **load_kwargs)
    loaded.eval()
    return loaded


model = _load_model("model.pt")
# Output position i of the model maps to class_list[i].
class_list = [
    'Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista', 'Euryarchaeota', 'Excavata', 'Fungi',
    'Heunggongvirae', 'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria',
]
f = fn(model, class_list)
# Single-textbox Gradio UI: a glycan sequence goes in, the predicted class comes out.
demo = gr.Interface(
    fn=f,
    inputs=[gr.Textbox(label="Glycan sequence")],
    outputs=[gr.Textbox(label="Predicted Class")],
    # Gradio documents string values ("never", "auto", "manual"); the boolean
    # form is legacy and only mapped to "never" with a deprecation warning.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        "GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
        "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol",
    ],
)

# Launch only when run as a script (`python app.py`), so importing the module
# (e.g. by Gradio's reload mode or tests) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)