File size: 3,644 Bytes
5263bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from functools import lru_cache
from itertools import product
from typing import Dict

import gradio as gr
import joblib
import numpy as np
import torch
import torch.nn as nn

class VirusClassifier(nn.Module):
    """Small feed-forward network producing two logits per input row.

    All layers live in a single ``nn.Sequential`` stored as ``self.network``.
    The layer order fixes the parameter names (``network.0.weight``,
    ``network.2.running_mean``, ...) that a saved state dict must match.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        wide, narrow, drop_p = 64, 32, 0.3
        stack = [
            # block 1: input -> 64
            nn.Linear(input_shape, wide),
            nn.GELU(),
            nn.BatchNorm1d(wide),
            nn.Dropout(drop_p),
            # block 2: 64 -> 32
            nn.Linear(wide, narrow),
            nn.GELU(),
            nn.BatchNorm1d(narrow),
            nn.Dropout(drop_p),
            # block 3: 32 -> 32, no normalisation or dropout
            nn.Linear(narrow, narrow),
            nn.GELU(),
            # head: 32 -> 2 raw logits (softmax is applied by the caller)
            nn.Linear(narrow, 2),
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the logits."""
        logits = self.network(x)
        return logits

@lru_cache(maxsize=8)
def _kmer_index(k: int) -> Dict[str, int]:
    """Map every length-``k`` string over ``ACGT`` to its lexicographic position."""
    # Built once per k instead of on every call (4**6 = 4096 entries for k=6).
    return {''.join(p): i for i, p in enumerate(product("ACGT", repeat=k))}


def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
    """Count occurrences of each ACGT k-mer in ``sequence``.

    Args:
        sequence: Nucleotide string; only uppercase ``A``/``C``/``G``/``T`` count.
        k: k-mer length.

    Returns:
        1-D integer array of length ``4**k`` holding raw counts (not
        normalised frequencies). It is ordered lexicographically:
        index 0 is ``A...A`` and the last index is ``T...T``. Any window
        containing another character (``N``, lowercase bases, gaps) is
        skipped. A sequence shorter than ``k`` gives all zeros.
    """
    index = _kmer_index(k)
    counts = np.zeros(len(index), dtype=int)
    for start in range(len(sequence) - k + 1):
        slot = index.get(sequence[start:start + k])
        if slot is not None:  # only count windows made purely of ACGT
            counts[slot] += 1
    return counts

def parse_fasta(fasta_content: str):
    """Split FASTA-formatted text into ``(header, sequence)`` records.

    Each line beginning with ``>`` (after stripping surrounding whitespace)
    starts a new record. Its header is the text after the ``>``. All later
    non-blank lines up to the next header are uppercased and concatenated
    to form the sequence. Blank lines are ignored. Lines that appear before
    the first header are discarded. Records are returned in input order.
    """
    records = []  # each entry: (header, list of sequence chunks)
    for raw_line in fasta_content.split('\n'):
        text = raw_line.strip()
        if not text:
            continue
        if text.startswith('>'):
            records.append((text[1:], []))
        elif records:
            records[-1][1].append(text.upper())
    return [(name, ''.join(chunks)) for name, chunks in records]

@lru_cache(maxsize=2)
def _load_model_and_scaler(device: str):
    """Load the trained classifier and its feature scaler once per device.

    The classifier takes 4096 inputs (4**6, one per 6-mer); its weights are
    read from ``model.pt`` and the scaler from ``scaler.pkl``, both relative
    to the working directory. Caching avoids re-reading both files on every
    request.

    NOTE: ``joblib.load`` unpickles its input, so ``scaler.pkl`` must be trusted.
    """
    model = VirusClassifier(4096).to(device)  # 4096 = 4^6 for 6-mers
    model.load_state_dict(torch.load('model.pt', map_location=device))
    model.eval()  # inference mode: BatchNorm uses running stats, Dropout is off
    scaler = joblib.load('scaler.pkl')
    return model, scaler


def _read_fasta_input(fasta_input) -> str:
    """Return FASTA text from raw text, bytes, a file path, or an uploaded-file object.

    Depending on the Gradio version, a ``gr.File`` input hands the callback a
    file path (str) or a tempfile wrapper exposing ``.name``, not the file's
    contents. Plain FASTA text (or any string that is not an existing file)
    is passed through unchanged; ``None`` becomes an empty string.
    """
    import os

    if fasta_input is None:
        return ''
    if isinstance(fasta_input, (bytes, bytearray)):
        return bytes(fasta_input).decode('utf-8', errors='replace')
    if isinstance(fasta_input, os.PathLike):
        path = os.fspath(fasta_input)
    elif isinstance(fasta_input, str):
        if fasta_input.lstrip().startswith('>') or not os.path.isfile(fasta_input):
            return fasta_input
        path = fasta_input
    else:
        path = fasta_input.name  # tempfile-style wrapper from older gr.File
    with open(path, encoding='utf-8', errors='replace') as handle:
        return handle.read()


def predict_sequence(fasta_content: str) -> str:
    """Classify each FASTA record as human- or non-human-host and format the results.

    Args:
        fasta_content: FASTA text, or what the ``gr.File`` input delivers:
            a path or uploaded-file object whose contents are FASTA text.

    Returns:
        One formatted block per record: header, predicted label, confidence,
        and both class probabilities. Blocks are joined by newlines. If no
        ``>`` header is found, a short notice is returned instead.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    k = 6

    model, scaler = _load_model_and_scaler(device)

    sequences = parse_fasta(_read_fasta_input(fasta_content))
    if not sequences:
        return "No FASTA records found. Each sequence must start with a '>' header line."

    results = []
    for header, seq in sequences:
        # k-mer counts -> scaled single-row feature matrix
        kmer_vector = sequence_to_kmer_vector(seq, k)
        kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))

        with torch.no_grad():
            features = torch.as_tensor(kmer_vector, dtype=torch.float32, device=device)
            probs = torch.softmax(model(features), dim=1)[0]

        # Class index 1 is 'human', 0 is 'non-human'; a tie resolves to non-human.
        pred_label = 'human' if probs[1] > probs[0] else 'non-human'

        result = f"""
Sequence: {header}
Prediction: {pred_label}
Confidence: {float(max(probs)):0.4f}
Human probability: {float(probs[1]):0.4f}
Non-human probability: {float(probs[0]):0.4f}
"""
        results.append(result)

    return "\n".join(results)

# Create Gradio interface.
# NOTE(review): cache_examples=True presumably runs predict_sequence on
# example.fasta ahead of time, so that file plus model.pt / scaler.pkl must be
# present at startup. Confirm the timing for the installed Gradio version.
iface = gr.Interface(
    fn=predict_sequence,
    inputs=gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"]),
    outputs=gr.Textbox(label="Prediction Results"),
    title="Virus Host Classifier",
    description="Upload a FASTA file to predict whether a virus sequence is likely to infect human or non-human hosts.",
    examples=[["example.fasta"]],
    cache_examples=True
)

# Launch the server only when executed as a script (e.g. `python app.py`).
# Importing this module (tests, tooling) then does not block on a web server.
if __name__ == "__main__":
    iface.launch()