File size: 2,907 Bytes
c20f071
 
 
44c8341
 
 
 
50edbe9
44c8341
50edbe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f8980
 
 
 
 
 
 
 
0fbae15
fc829e4
 
cb13d0d
fc829e4
 
0fbae15
fc829e4
 
 
44c8341
 
8b25912
fc829e4
 
 
 
 
 
 
 
 
 
 
8b25912
 
 
44c8341
 
85f8980
44c8341
2b584be
0fbae15
 
44c8341
 
0fbae15
7a972f8
cb13d0d
44c8341
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
# Install the CPU build of PyTorch at start-up, before the `import torch`
# further down. The exit status of os.system is not checked, so a failed
# install only surfaces later as an ImportError.
os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
# Install the PyTorch Geometric packages from the wheel index built for
# torch 1.12.0+cpu.
# NOTE(review): the torch install above is unpinned while these wheels target
# torch 1.12.0 -- confirm the two stay compatible.
os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn

class EnsembleModel(nn.Module):
    """Average the predictions of several graph models.

    Each wrapped model is called as ``model(x, edge_index, batch)`` and the
    outputs are averaged element-wise across models.
    """

    def __init__(self, models):
        """Store the models to ensemble.

        Args:
            models: Iterable of callables with the signature
                ``model(x, edge_index, batch)``. When every entry is an
                ``nn.Module`` they are kept in an ``nn.ModuleList`` so that
                ``.eval()``, ``.to()`` and ``.parameters()`` on the ensemble
                reach them; otherwise they are kept in a plain list.
        """
        super().__init__()
        models = list(models)
        if models and all(isinstance(m, nn.Module) for m in models):
            models = nn.ModuleList(models)
        self.models = models

    @staticmethod
    def _device_of(model):
        """Return the device of ``model``'s first parameter, or CPU if it has none."""
        parameters = getattr(model, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device("cpu")

    def forward(self, data):
        """Return the ensemble-averaged prediction for the first graph in ``data``.

        Args:
            data: Batch object exposing ``labels``, ``edge_index`` and
                ``batch`` tensors.

        Returns:
            NumPy array: the mean over models of each model's output, indexed
            at row 0.
        """
        predictions = []
        # Inference only: outputs are converted to NumPy, so no graph is needed.
        with torch.no_grad():
            for model in self.models:
                # Send the inputs to wherever this model lives instead of
                # assuming CUDA; a CPU-loaded model on a GPU host would
                # otherwise fail with a device mismatch.
                device = self._device_of(model)
                output = model(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                predictions.append(output.detach().cpu().numpy())
        return np.mean(predictions, axis=0)[0]
  
# Display names for the prediction scores; entry i labels score i.
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]

# Load the three saved models onto the CPU, in order.
# NOTE(review): these files appear to hold whole pickled model objects (they
# are called directly as models below), so the classes they were saved with
# must be importable here -- confirm before moving or renaming any of them.
model1, model2, model3 = (
    torch.load(path, map_location=torch.device('cpu'))
    for path in ("model1.pt", "model2.pt", "model3.pt")
)

def _softmax(logits):
    """Return the softmax of a 1-D sequence of logits as a float64 array.

    The maximum logit is subtracted before exponentiating so that large
    values cannot overflow to ``inf`` and turn the result into ``nan``.
    """
    logits = np.asarray(logits, dtype=np.float64)
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()


def fn(glycan, model):
    """Predict class probabilities for one glycan sequence.

    Args:
        glycan: Glycan sequence string, passed to ``dataset_to_dataloader``.
        model: Model selector. ``"No data augmentation"`` uses ``model1``,
            ``"Ensemble"`` uses ``model3``; any other value uses ``model2``.

    Returns:
        dict mapping each name in ``class_list`` to its softmax probability
        (a Python float).
    """
    if model == "No data augmentation":
        model_pred = model1
    elif model == "Ensemble":
        model_pred = model3
    else:
        model_pred = model2
    model_pred.eval()

    # The loader wants a label per glycan; 0 is a placeholder because only
    # the graph tensors are read from the batch below.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        if model == "Ensemble":
            # The ensemble takes the whole batch object and returns a 1-D array.
            pred = model_pred(data)
        else:
            device = "cpu"
            x = data.labels.to(device)
            edge_index = data.edge_index.to(device)
            batch = data.batch.to(device)
            pred = model_pred(x, edge_index, batch).cpu().detach().numpy()[0]

    probabilities = _softmax(pred)
    # zip pairs names with scores without hard-coding the number of classes.
    return {name: float(p) for name, p in zip(class_list, probabilities)}


# Gradio UI: a glycan text box plus a model selector, wired to fn, showing
# the scores as a ranked label widget.
demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(label="Glycan sequence"),
        # "Ensemble" is listed because fn handles that value and the third
        # example below selects it; without it the example refers to a
        # choice the radio does not offer.
        gr.Radio(label="Model", choices=["No data augmentation", "Random node deletion", "Ensemble"]),
    ],
    outputs=[gr.Label(num_top_classes=15, label="Prediction")],
    # String form of the flagging option; the boolean form is deprecated.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
    ["Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc", "Random node deletion"],
    ["Man(a1-2)Man(a1-3)[Man(a1-6)]Man(a1-6)[Man(a1-2)Man(a1-2)Man(a1-3)]Man(b1-4)GlcNAc", "Ensemble"]]
)
demo.launch(debug=True)