"""Gradio demo for SweetNet: predict the taxonomic class of a glycan sequence."""
import os
import subprocess
import sys

# Install the PyTorch / PyTorch Geometric stack at startup. Using
# ``sys.executable -m pip`` makes the packages land in the interpreter that is
# running this script (a bare ``pip`` on PATH may belong to another one).
# check=False keeps this best-effort, as the original os.system calls were.
# NOTE(review): the PyG wheel index below targets torch 1.12.0+cpu, but the
# torch install is unpinned; a newer torch may not match the compiled
# torch-scatter/torch-sparse wheels — confirm and pin versions if needed.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "torch", "torchvision", "torchaudio",
        "--extra-index-url", "https://download.pytorch.org/whl/cpu",
    ],
    check=False,
)
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "torch-scatter", "torch-sparse", "torch-cluster",
        "torch-spline-conv", "torch-geometric",
        "-f", "https://data.pyg.org/whl/torch-1.12.0+cpu.html",
    ],
    check=False,
)

import inspect

import gradio as gr
import numpy as np
import torch
from glycowork.glycan_data.loader import lib
from glycowork.ml.processing import dataset_to_dataloader

# Monosaccharide tokens grouped into classes; every token is added to ``lib``.
equivalence_classes = [
    ["Glc", "Man", "Gal", "Gul", "Alt", "All", "Tal", "Ido"],
    ["GlcNAc", "ManNAc", "GalNAc", "GulNAc", "AltNAc", "AllNAc", "TalNAc", "IdoNAc"],
    ["GlcN", "ManN", "GalN", "GulN", "AltN", "AllN", "TalN", "IdoN"],
    ["GlcA", "ManA", "GalA", "GulA", "AltA", "AllA", "TalA", "IdoA"],
    ["Qui", "Rha", "6dGul", "6dAlt", "6dTal", "Fuc"],
    ["QuiNAc", "RhaNAc", "6dAltNAc", "6dTalNAc", "FucNAc"],
    ["Oli", "Tyv", "Abe", "Par", "Dig", "Col"],
    ["Ara", "Lyx", "Xyl", "Rib"],
    ["Kdn", "Neu5Ac", "Neu5Gc", "Neu", "Sia"],
    ["Pse", "Leg", "Aci", "4eLeg"],
    ["Bac", "LDmanHep", "Kdo", "Dha", "DDmanHep", "MurNAc", "MurNGc", "Mur",
     "Api", "Fru", "Tag", "Sor", "Psi"],
]

# Linkage tokens (including the "z" wildcard variants); also added to ``lib``.
linkage_classes = [
    ["a1-2", "a1-z", "z1-2", "z1-z"],
    ["a1-3", "a1-z", "z1-3", "z1-z"],
    ["a1-4", "a1-z", "z1-4", "z1-z"],
    ["a1-6", "a1-z", "z1-6", "z1-z"],
    ["b1-2", "b1-z", "z1-2", "z1-z"],
    ["b1-3", "b1-z", "z1-3", "z1-z"],
    ["b1-4", "b1-z", "z1-4", "z1-z"],
    ["b1-6", "b1-z", "z1-6", "z1-z"],
    ["a2-3", "a2-z", "z2-3", "z2-z"],
    ["a2-6", "a2-z", "z2-6", "z2-z"],
    ["a2-8", "a2-z", "z2-8", "z2-z"],
]


def _extend_lib(library, groups):
    """Append every token in *groups* to *library*, in order, skipping duplicates.

    *library* is mutated in place so that any other holder of the same list
    object (presumably glycowork's default vocabulary — TODO confirm) sees the
    new tokens. A set mirrors the membership test to avoid a linear scan of
    the growing list for every token.
    """
    known = set(library)
    for group in groups:
        for token in group:
            if token not in known:
                library.append(token)
                known.add(token)


# Update lib
print(len(lib))
_extend_lib(lib, equivalence_classes)
_extend_lib(lib, linkage_classes)
print(len(lib))


def _select_device():
    """Return "cuda:0" when a CUDA device is available, otherwise "cpu"."""
    return "cuda:0" if torch.cuda.is_available() else "cpu"


def fn(model, class_list, device=None):
    """Build a Gradio callback that classifies a single glycan sequence.

    Args:
        model: network called as ``model(x, edge_index, batch)``; it must
            already live on ``device``.
        class_list: class names, indexed by the model's output position.
        device: torch device string for the inputs; defaults to
            ``_select_device()``.

    Returns:
        A function ``f(glycan) -> str`` that returns the predicted class name.
    """
    if device is None:
        device = _select_device()

    def f(glycan):
        glycan = glycan.strip() if glycan else ""
        if not glycan:
            return "Please enter a glycan sequence."
        # dataset_to_dataloader requires labels; the label is never read below.
        data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))
        x = data.labels.to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            pred = model(x, edge_index, batch).cpu().numpy()
        # Assumes output is (1, len(class_list)) for batch_size=1 — TODO confirm;
        # np.argmax flattens, so the flat index is the class index here.
        return class_list[int(np.argmax(pred))]

    return f


device = _select_device()

# model.pt holds a whole pickled module (``.eval()`` is called on it), not a
# state_dict. torch >= 2.6 defaults to weights_only=True, which refuses that,
# so opt out when the keyword exists. This unpickles arbitrary objects:
# only load a trusted local file here.
_load_kwargs = {"map_location": device}
if "weights_only" in inspect.signature(torch.load).parameters:
    _load_kwargs["weights_only"] = False
model = torch.load("model.pt", **_load_kwargs)
# Keep the model on the same device as the inputs prepared in ``f``.
model.to(device)
model.eval()

class_list = ['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista',
              'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
              'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota',
              'Protista', 'Riboviria']

f = fn(model, class_list, device)

demo = gr.Interface(
    fn=f,
    inputs=[gr.Textbox(label="Glycan sequence")],
    outputs=[gr.Textbox(label="Predicted Class")],
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        "GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
        "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol",
    ],
)
demo.launch(debug=True)