Spaces:
Runtime error
Runtime error
File size: 3,485 Bytes
c20f071 cc4e3d8 da72438 840fdaa 44c8341 50edbe9 6506504 44c8341 50edbe9 85f8980 840fdaa 6506504 840fdaa 0fbae15 fc829e4 cb13d0d fc829e4 0fbae15 fc829e4 44c8341 8b25912 47aa6b1 fc829e4 47aa6b1 8b25912 6506504 44c8341 85f8980 44c8341 2b584be 47aa6b1 9e05631 44c8341 0fbae15 7a972f8 cb13d0d 44c8341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import subprocess
import sys

# Install the CPU builds of PyTorch and the PyTorch Geometric packages at
# start-up, before the imports that follow need them.
#
# ``sys.executable -m pip`` targets the interpreter that is running this app
# (a bare ``pip`` on PATH may belong to a different Python), and passing an
# argument list avoids shell parsing. Failures are deliberately not fatal,
# matching the previous ``os.system`` calls whose exit codes were ignored; a
# missing package will surface when it is imported below.
#
# NOTE(review): torch is installed unpinned, but the PyG wheels are taken from
# the torch-1.12.0 index; compiled extension wheels generally need a matching
# torch build. Consider pinning torch==1.12.0 (with torchvision==0.13.0 and
# torchaudio==0.12.0) — TODO confirm against the Space's Python version.
_PIP_INSTALL = [sys.executable, "-m", "pip", "install"]
subprocess.run(
    _PIP_INSTALL
    + [
        "torch",
        "torchvision",
        "torchaudio",
        "--extra-index-url",
        "https://download.pytorch.org/whl/cpu",
    ],
    check=False,
)
subprocess.run(
    _PIP_INSTALL
    + [
        "torch-scatter",
        "torch-sparse",
        "torch-cluster",
        "torch-spline-conv",
        "torch-geometric",
        "-f",
        "https://data.pyg.org/whl/torch-1.12.0+cpu.html",
    ],
    check=False,
)
# Optional graphviz layout backend, currently disabled:
# os.system("apt-get install -y graphviz-dev")
# os.system("pip install pygraphviz")
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn
from glycowork.motif.graph import glycan_to_nxGraph
import networkx as nx
import matplotlib.pyplot as plt
class EnsembleModel(nn.Module):
    """Average the outputs of several graph models that share one call signature.

    Every member is called as ``member(x, edge_index, batch)`` with tensors
    taken from a batched graph object (``data.labels``, ``data.edge_index``,
    ``data.batch``). The member outputs are averaged element-wise and the
    first row of that average is returned as a NumPy array.
    """

    def __init__(self, models):
        """Store ``models`` (an iterable of ``nn.Module``) as registered submodules.

        ``nn.ModuleList`` (instead of a plain list) makes ``.eval()``,
        ``.train()``, ``.to()`` and ``.parameters()`` on the ensemble reach
        every member; with a plain list, ``ensemble.eval()`` silently left the
        members in whatever mode they were already in.
        """
        super().__init__()
        self.models = nn.ModuleList(models)

    @staticmethod
    def _device_of(model):
        """Return the device of ``model``'s first parameter, or CPU if it has none."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, data):
        """Return the first row of the members' averaged outputs for ``data``.

        Input tensors are moved to each member's own device rather than to
        ``cuda:0`` whenever CUDA exists, so members loaded onto the CPU keep
        working on GPU machines. Inference runs under ``torch.no_grad()``.
        """
        outputs = []
        with torch.no_grad():
            for member in self.models:
                device = self._device_of(member)
                x = data.labels.to(device)
                edge_index = data.edge_index.to(device)
                batch = data.batch.to(device)
                outputs.append(member(x, edge_index, batch).cpu().detach().numpy())
        # Average across members (axis 0), then keep the first output row.
        return np.mean(outputs, axis=0)[0]
# Class names for the model outputs; ``fn`` maps output index i to
# ``class_list[i]``, so this order must match the models' output columns.
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]
def _load_pickled_model(path):
    """Load a pickled model object from ``path`` with all tensors mapped to the CPU.

    The files hold whole module objects (the loaded values are used directly
    as models), not state dicts. Recent PyTorch releases default
    ``torch.load`` to ``weights_only=True``, which refuses to unpickle such
    objects, so opt out explicitly; releases that predate the keyword reject
    it with ``TypeError``, in which case retry without it.

    Security: ``weights_only=False`` runs arbitrary pickle code, so only use
    this on the trusted checkpoint files shipped with the app.
    """
    cpu = torch.device('cpu')
    try:
        return torch.load(path, map_location=cpu, weights_only=False)
    except TypeError:
        return torch.load(path, map_location=cpu)


model1 = _load_pickled_model("model1.pt")
model2 = _load_pickled_model("model2.pt")
model3 = _load_pickled_model("model3.pt")
def _draw_glycan_graph(glycan, path="graph.png"):
    """Render ``glycan`` as a labelled graph image at ``path`` and return ``path``.

    Nodes are drawn with their ``'string_labels'`` attribute as text.
    """
    graph = glycan_to_nxGraph(glycan)
    # Pass the string labels to the drawer instead of relabelling the nodes:
    # ``nx.relabel_nodes`` with a non-unique mapping (several "Man" or "a1-2"
    # nodes) merges those nodes and draws the wrong topology.
    node_labels = nx.get_node_attributes(graph, 'string_labels')
    # Use a dedicated figure and close it afterwards; drawing on the implicit
    # global figure made each request pile on top of all previous graphs.
    # NOTE(review): a fixed output path is shared by concurrent requests.
    fig, ax = plt.subplots()
    try:
        nx.draw(graph, ax=ax, labels=node_labels, with_labels=True)
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def _softmax(logits):
    """Return the softmax of a 1-D array of logits.

    The maximum is subtracted before exponentiating so large logits cannot
    overflow to ``inf`` (which made the naive ``exp / sum(exp)`` return NaN).
    """
    logits = np.asarray(logits, dtype=np.float64)
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()


def fn(glycan, model):
    """Draw ``glycan`` and return class probabilities from the chosen model.

    Args:
        glycan: glycan sequence string, as entered in the Gradio textbox.
        model: "No data augmentation" (model1), "Ensemble" (model3); any other
            value, including "Random node deletion", selects model2.

    Returns:
        ``(image_path, {class_name: probability})`` for the image and label
        outputs of the interface.
    """
    image_path = _draw_glycan_graph(glycan)

    if model == "No data augmentation":
        model_pred = model1
    elif model == "Ensemble":
        model_pred = model3
    else:
        model_pred = model2
    model_pred.eval()

    # Wrap the single sequence in a one-graph batch with a placeholder label.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))
    with torch.no_grad():
        if model == "Ensemble":
            # A pickled ensemble may keep its members in a plain list, which
            # ``Module.eval()`` does not traverse; put them in eval mode too.
            for member in getattr(model_pred, "models", ()):
                member.eval()
            # The ensemble takes the whole batch object and returns a NumPy row.
            logits = model_pred(data)
        else:
            cpu = torch.device("cpu")
            out = model_pred(
                data.labels.to(cpu), data.edge_index.to(cpu), data.batch.to(cpu)
            )
            logits = out.cpu().detach().numpy()[0]

    probs = _softmax(logits)
    return image_path, {name: float(p) for name, p in zip(class_list, probs)}
# Gradio UI: a sequence textbox plus a model selector in, the rendered glycan
# graph and the per-class prediction out.
demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(label="Glycan sequence"),
        gr.Radio(
            label="Model",
            choices=["No data augmentation", "Random node deletion", "Ensemble"],
        ),
    ],
    outputs=[
        gr.Image(label="Glycan graph"),
        # Show every class the models can predict.
        gr.Label(num_top_classes=len(class_list), label="Prediction"),
    ],
    # Gradio documents allow_flagging as a string ("never"/"auto"/"manual");
    # "never" is the equivalent of the boolean False used previously.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        ["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
        ["Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc", "Random node deletion"],
        ["Man(a1-2)Man(a1-3)[Man(a1-6)]Man(a1-6)[Man(a1-2)Man(a1-2)Man(a1-3)]Man(b1-4)GlcNAc", "Ensemble"],
    ],
)
demo.launch(debug=True)