# Hugging Face Space status when this file was captured: "Runtime error".
import os
import subprocess
import sys


def _pip_install(*args):
    """Run ``pip install *args`` with the interpreter that is running this app.

    Failures are reported on stderr but do not abort start-up (the original
    script also carried on after a failed install); an import further down
    raises if a required package is genuinely missing.
    """
    # `sys.executable -m pip` installs into this interpreter's environment,
    # unlike a bare `pip`, which may belong to a different Python on PATH.
    cmd = [sys.executable, "-m", "pip", "install", *args]
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        print(
            f"WARNING: {' '.join(cmd)} exited with status {result.returncode}",
            file=sys.stderr,
        )


# The PyG extension wheels below come from an index built for torch 1.12.0 (CPU).
# Compiled extensions must match the installed torch release, so torch (and its
# companion packages) are pinned to that release instead of pulling the newest one.
# NOTE(review): torch 1.12.0 publishes wheels only for Python 3.7-3.10 — confirm the
# Space's Python version. torch-geometric itself is unpinned; confirm the release
# pip picks still supports torch 1.12.
_pip_install(
    "torch==1.12.0",
    "torchvision==0.13.0",
    "torchaudio==0.12.0",
    "--extra-index-url",
    "https://download.pytorch.org/whl/cpu",
)
_pip_install(
    "torch-scatter",
    "torch-sparse",
    "torch-cluster",
    "torch-spline-conv",
    "torch-geometric",
    "-f",
    "https://data.pyg.org/whl/torch-1.12.0+cpu.html",
)

import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
from glycowork.glycan_data.loader import lib
# Monosaccharide name groups, one list per group, written as whitespace-separated
# strings and split into lists. The name suggests members of a group are treated
# as equivalent (the first four groups are the same eight hexose stems with the
# suffixes "", "NAc", "N" and "A") — confirm with the consumer. In this file the
# groups are only used to ensure every name is present in glycowork's `lib`.
equivalence_classes = [
    "Glc Man Gal Gul Alt All Tal Ido".split(),
    "GlcNAc ManNAc GalNAc GulNAc AltNAc AllNAc TalNAc IdoNAc".split(),
    "GlcN ManN GalN GulN AltN AllN TalN IdoN".split(),
    "GlcA ManA GalA GulA AltA AllA TalA IdoA".split(),
    "Qui Rha 6dGul 6dAlt 6dTal Fuc".split(),
    "QuiNAc RhaNAc 6dAltNAc 6dTalNAc FucNAc".split(),
    "Oli Tyv Abe Par Dig Col".split(),
    "Ara Lyx Xyl Rib".split(),
    "Kdn Neu5Ac Neu5Gc Neu Sia".split(),
    "Pse Leg Aci 4eLeg".split(),
    "Bac LDmanHep Kdo Dha DDmanHep MurNAc MurNGc Mur Api Fru Tag Sor Psi".split(),
]

# One group per concrete linkage string "<anomer><n>-<m>", followed by its variants
# with the second position, the anomer, or both replaced by "z" (presumably the
# "unspecified" wildcard in glycowork's notation — confirm).
# Resulting order: a1-2, a1-3, a1-4, a1-6, b1-2, b1-3, b1-4, b1-6, a2-3, a2-6, a2-8.
linkage_classes = [
    [
        f"{anomer}{donor}-{acceptor}",
        f"{anomer}{donor}-z",
        f"z{donor}-{acceptor}",
        f"z{donor}-z",
    ]
    for anomer, donor, acceptors in (
        ("a", 1, (2, 3, 4, 6)),
        ("b", 1, (2, 3, 4, 6)),
        ("a", 2, (3, 6, 8)),
    )
    for acceptor in acceptors
]
# Ensure every token named in `equivalence_classes` and `linkage_classes` exists in
# glycowork's vocabulary list `lib`. The list is extended in place (never rebound),
# so any code holding a reference to this same list object sees the new tokens —
# presumably including the default vocabulary of dataset_to_dataloader; verify
# against the installed glycowork version.
# NOTE(review): tokens appended here were not in the original vocabulary; confirm
# the trained model can accept the extra indices.
print(len(lib))  # vocabulary size before extension
for target in (token for group in equivalence_classes + linkage_classes for token in group):
    if target in lib:
        continue  # already present; duplicates across groups are skipped too
    lib.append(target)
print(len(lib))  # vocabulary size after extension
def fn(model, class_list):
    """Build a single-glycan prediction function for the Gradio interface.

    Args:
        model: Trained graph model, called as ``model(x, edge_index, batch)``
            and returning one score per class (presumably a glycowork SweetNet
            — TODO confirm).
        class_list: Class names, indexed by the position of the model's
            highest score.

    Returns:
        A callable ``f(glycan)`` that takes one glycan sequence string and
        returns the name of the highest-scoring class.
    """
    def f(glycan):
        # Surrounding whitespace from the textbox is not part of the sequence.
        glycan = glycan.strip()
        # dataset_to_dataloader takes parallel lists of glycans and labels; the
        # label is a dummy value because only the graph tensors are used here.
        data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))

        # Put the inputs on whatever device the model's weights live on. Picking
        # "cuda:0" merely because CUDA exists breaks when the model sits on the CPU.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        x = data.labels.to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            scores = model(x, edge_index, batch)
        # Batch size is 1, so argmax over the flattened scores is the class index.
        pred = int(np.argmax(scores.cpu().numpy()))
        return class_list[pred]
    return f
import inspect

# Inference device: the first CUDA GPU when available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# model.pt holds a whole pickled nn.Module (the result is used directly as a model),
# not just a state_dict. Unpickling can execute arbitrary code, so only ever load
# this trusted, bundled file here.
# map_location remaps tensors saved on a GPU so that the file also loads on a
# CPU-only host.
_load_kwargs = {"map_location": device}
# torch >= 2.6 defaults to weights_only=True, which refuses whole pickled modules.
# Older releases (e.g. 1.12) do not accept the argument at all, so pass it only
# when the installed torch supports it.
if "weights_only" in inspect.signature(torch.load).parameters:
    _load_kwargs["weights_only"] = False
model = torch.load("model.pt", **_load_kwargs)
model.to(device)  # keep the weights on the same device the inputs are sent to
model.eval()

# Output classes, in the order of the model's output scores.
class_list = ['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista', 'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
              'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria']
f = fn(model, class_list)
# Web UI: one glycan-sequence textbox in, the predicted class name out.
demo = gr.Interface(
    fn=f,
    inputs=[gr.Textbox(label="Glycan sequence")],
    outputs=[gr.Textbox(label="Predicted Class")],
    # Gradio documents allow_flagging as "never", "auto" or "manual"; a boolean is
    # not one of the documented values.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        "GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
        "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol",
    ],
)

# Start the server only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    demo.launch(debug=True)