File size: 2,036 Bytes
b6f1234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import numpy as np
import pandas as pd
import time
from torch_geometric.data import DataLoader
from model.model_concatenation import PLANet
from utils.args import ArgsInit
from utils.model import get_dataset_inference, test_gcn


def main(args):

    if args.use_gpu:
        device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device('cpu')
    
    #Numpy and torch seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)
    print('%s' % args)
    
    
    data_inference = pd.read_csv(
        args.input_file_smiles, 
        names=["Smiles"],
        header=0
    )
 
    print("Data Inference: ", data_inference)

    data_target = pd.read_csv(
        args.target_list, names=["Fasta", "Target", "Label"]
    )
    data_target = data_target[data_target.Target == args.target]

    print("Data Target: ", data_target)

    test = get_dataset_inference(
        data_inference,
        use_prot=args.use_prot,
        target=data_target,
        args=args,
        advs=False,
        saliency=False,
    )
    
    test_loader = DataLoader(test, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers)

    model = PLANet(args).to(device)


    print('Model inference in: {}'.format(args.inference_path))
    start_time = time.time()

    #Load pre-trained molecule model

    print('Evaluating...')
    test_gcn(model, device, test_loader, args)


    end_time = time.time()
    total_time = end_time - start_time
    print('Total time: {}'.format(time.strftime('%H:%M:%S', time.gmtime(total_time))))


if __name__ == "__main__":
    # Fixed settings for inference. They are applied after ArgsInit parses its
    # configuration, so they take precedence over any values it produced for
    # these same attributes.
    inference_overrides = {
        "nclasses": 2,
        "batch_size": 10,
        "use_prot": True,
        "freeze_molecule": True,
        "conv_encode_edge": True,
        "learn_t": True,
        "binary": True,
    }

    args = ArgsInit().args
    for attr_name, attr_value in inference_overrides.items():
        setattr(args, attr_name, attr_value)

    main(args)