Spaces:
Runtime error
Runtime error
File size: 3,553 Bytes
c20f071 44c8341 50edbe9 6506504 1a7661f 633bc62 1a7661f 44c8341 50edbe9 85f8980 840fdaa 555d33d 840fdaa 0fbae15 fc829e4 cb13d0d fc829e4 0fbae15 fc829e4 44c8341 8b25912 47aa6b1 fc829e4 47aa6b1 8b25912 be08073 44c8341 85f8980 44c8341 2b584be 47aa6b1 be08073 44c8341 0fbae15 7a972f8 cb13d0d 44c8341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
# Install the CPU wheels at start-up, before the `import torch` further down.
# The commands run in order; their exit status is not inspected.
for _pip_command in (
    "pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu",
    "pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html",
):
    os.system(_pip_command)
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn
from glycowork.motif.graph import glycan_to_nxGraph
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import write_dot
# import pygraphviz as pgv
class EnsembleModel(nn.Module):
    """Average the raw outputs of several models on one batch.

    Every member is called as ``model(x, edge_index, batch)`` and the
    element-wise mean of the member outputs is returned.
    """

    def __init__(self, models):
        """Store the member models.

        Args:
            models: Iterable of ``nn.Module`` instances to average over.
        """
        super().__init__()
        # A plain list hides the members from nn.Module, so eval()/train()/
        # to()/parameters() would silently skip them. ModuleList registers
        # them as real sub-modules.
        self.models = nn.ModuleList(models)

    def train(self, mode=True):
        """Set training/eval mode on the ensemble and on every member.

        Objects pickled by the earlier version of this class kept the
        members in a plain list; those are not reached by the default
        ``nn.Module.train``, so the mode is pushed to them explicitly.
        """
        super().train(mode)
        for member in getattr(self, "models", ()):
            if isinstance(member, nn.Module):
                member.train(mode)
        return self

    @staticmethod
    def _device_of(model):
        """Return the device holding ``model``'s parameters (CPU if none)."""
        try:
            return next(model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")

    def forward(self, data):
        """Run every member on ``data`` and average the outputs.

        Args:
            data: Batch object exposing ``labels``, ``edge_index`` and
                ``batch`` tensors.

        Returns:
            NumPy array: the mean over members, indexed at position 0 of
            its first axis (presumably the single item of a batch of one
            -- TODO confirm the member output shape).
        """
        member_outputs = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for member in self.models:
                # Send the inputs to wherever this member lives instead of
                # assuming CUDA; a CPU-loaded model on a CUDA machine would
                # otherwise fail with a device mismatch.
                device = self._device_of(member)
                output = member(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                member_outputs.append(output.detach().cpu().numpy())
        return np.mean(member_outputs, axis=0)[0]
# Class names, paired by index with the model's output scores in fn().
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]
# Load the three serialized models onto the CPU, in file order.
# NOTE(review): these look like whole pickled modules rather than
# state_dicts -- confirm before changing how they are loaded.
model1, model2, model3 = (
    torch.load(checkpoint, map_location=torch.device('cpu'))
    for checkpoint in ("model1.pt", "model2.pt", "model3.pt")
)
def fn(glycan, model):
    """Classify one glycan sequence with the selected model.

    Args:
        glycan: Glycan sequence string entered by the user.
        model: Model choice. ``"No data augmentation"`` selects ``model1``,
            ``"Ensemble"`` selects ``model3``; any other value (including
            no selection) falls back to ``model2``.

    Returns:
        A ``(scores, image_path)`` tuple. ``scores`` maps each name in
        ``class_list`` to its softmax probability. ``image_path`` is
        ``"graph.png"`` when that file exists, otherwise ``None``.

    Raises:
        ValueError: If ``glycan`` is empty or only whitespace.
    """
    if not glycan or not glycan.strip():
        raise ValueError("Please provide a glycan sequence.")
    glycan = glycan.strip()

    # NOTE: graph rendering is currently disabled (needs pygraphviz).
    # Draw graph
    # graph = glycan_to_nxGraph(glycan)
    # node_labels = nx.get_node_attributes(graph, 'string_labels')
    # labels = {i:node_labels[i] for i in range(len(graph.nodes))}
    # graph = nx.relabel_nodes(graph, labels)
    # write_dot(graph, "graph.dot")
    # graph=pgv.AGraph("graph.dot")
    # graph.layout(prog='dot')
    # graph.draw("graph.png")

    # Pick the model for this request and switch it to inference mode.
    if model == "No data augmentation":
        model_pred = model1
    elif model == "Ensemble":
        model_pred = model3
    else:
        model_pred = model2
    model_pred.eval()

    # The loader needs a label per glycan; 0 is a placeholder that is
    # never read back here.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))

    if model == "Ensemble":
        # The ensemble takes the whole batch object and returns a NumPy array.
        pred = model_pred(data)
    else:
        device = "cpu"
        with torch.no_grad():
            output = model_pred(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
        pred = output.cpu().detach().numpy()[0]

    # Softmax; subtracting the maximum first avoids overflow in exp()
    # without changing the result.
    exp_scores = np.exp(pred - np.max(pred))
    probabilities = exp_scores / exp_scores.sum()
    scores = {name: float(p) for name, p in zip(class_list, probabilities)}

    # graph.png is only written by the disabled drawing code above; return
    # None when it is absent so the Image output stays empty instead of
    # pointing at a missing file.
    image_path = "graph.png" if os.path.exists("graph.png") else None
    return scores, image_path
# Gradio UI: a sequence textbox and a model selector feed fn(); the outputs
# are the class probabilities and the (optional) glycan graph image.
demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(label="Glycan sequence"),
        gr.Radio(
            label="Model",
            choices=["No data augmentation", "Random node deletion", "Ensemble"],
        ),
    ],
    outputs=[
        gr.Label(num_top_classes=15, label="Prediction"),
        gr.Image(label="Glycan graph"),
    ],
    # String form of the flag; the boolean form is deprecated.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        ["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
        ["Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc", "Random node deletion"],
        ["Man(a1-2)Man(a1-3)[Man(a1-6)]Man(a1-6)[Man(a1-2)Man(a1-2)Man(a1-3)]Man(b1-4)GlcNAc", "Ensemble"],
    ],
)
demo.launch(debug=True)