File size: 5,618 Bytes
5263bd3
 
 
 
 
 
2243c0c
 
 
 
 
5263bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33ae1e8
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5263bd3
2243c0c
5263bd3
2243c0c
 
 
 
 
 
 
 
 
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5263bd3
2243c0c
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5263bd3
 
 
 
 
 
2243c0c
 
 
 
 
 
 
 
 
 
 
 
5263bd3
 
 
 
 
2243c0c
5263bd3
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
import torch.nn as nn
import logging

# Configure the root logger to emit INFO and above, and create the named
# logger every function in this module writes to.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

class VirusClassifier(nn.Module):
    """Feed-forward classifier mapping a feature vector to two class logits.

    Layout of ``self.network`` (indices matter: they form the state-dict keys
    ``network.<i>.*`` that a saved checkpoint must match):

        0 Linear(input_shape, 64)  1 GELU  2 BatchNorm1d(64)  3 Dropout(0.3)
        4 Linear(64, 32)           5 GELU  6 BatchNorm1d(32)  7 Dropout(0.3)
        8 Linear(32, 32)           9 GELU  10 Linear(32, 2)

    Args:
        input_shape: Number of input features per sample.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        # Two identical Linear -> GELU -> BatchNorm -> Dropout stages; built in
        # the same order as an explicit listing so parameter init is unchanged.
        stages = []
        for width_in, width_out in ((input_shape, 64), (64, 32)):
            stages += [
                nn.Linear(width_in, width_out),
                nn.GELU(),
                nn.BatchNorm1d(width_out),
                nn.Dropout(0.3),
            ]
        head = [nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 2)]
        self.network = nn.Sequential(*stages, *head)

    def forward(self, x):
        """Return unnormalised logits whose last dimension has size 2."""
        return self.network(x)

def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Count every length-*k* window of *sequence* over the ``ACGT`` alphabet.

    Args:
        sequence: Nucleotide string. Only windows made entirely of the
            upper-case characters ``A``, ``C``, ``G``, ``T`` are counted;
            windows containing anything else (e.g. ``N``) are skipped.
        k: Window length.

    Returns:
        Integer array with one entry per k-mer, ordered as
        ``itertools.product("ACGT", repeat=k)`` yields them (``AA..A`` first,
        ``TT..T`` last). A sequence shorter than *k* gives all zeros.

    Raises:
        Any exception raised while counting is logged and re-raised.
    """
    try:
        # Column index of each k-mer in the output vector.
        column_of = {
            "".join(letters): column
            for column, letters in enumerate(product("ACGT", repeat=k))
        }
        counts = np.zeros(len(column_of), dtype=int)
        for start in range(len(sequence) - k + 1):
            column = column_of.get(sequence[start:start + k])
            if column is not None:  # window has only A/C/G/T
                counts[column] += 1
        return counts
    except Exception as e:
        logger.error(f"Error in sequence_to_kmer_vector: {str(e)}")
        raise

def _read_fasta_text(file_obj) -> str:
    """Return the decoded text content of an uploaded FASTA input.

    Accepted inputs:
      * ``bytes`` / ``bytearray`` - used as-is;
      * ``str`` / ``os.PathLike`` - treated as a filesystem path and read;
      * an object whose ``.name`` is a path to an existing file (temp-file
        wrappers) - that file is read;
      * any other object with ``.read()`` - its content is read.

    Bytes are decoded as ``utf-8-sig`` so a leading byte-order mark is
    dropped instead of being glued onto the first ``>`` header.

    Raises:
        TypeError: if *file_obj* is none of the above.
        OSError / UnicodeDecodeError: if reading or decoding fails.
    """
    import os

    if isinstance(file_obj, (bytes, bytearray)):
        raw = bytes(file_obj)
    elif isinstance(file_obj, (str, os.PathLike)):
        with open(file_obj, "rb") as handle:
            raw = handle.read()
    elif isinstance(getattr(file_obj, "name", None), str) and os.path.isfile(file_obj.name):
        with open(file_obj.name, "rb") as handle:
            raw = handle.read()
    elif hasattr(file_obj, "read"):
        raw = file_obj.read()
    else:
        raise TypeError(f"Unsupported FASTA input type: {type(file_obj).__name__}")

    if isinstance(raw, str):
        # A text-mode handle has already decoded; just drop a stray BOM.
        return raw.lstrip("\ufeff")
    return raw.decode("utf-8-sig")


def parse_fasta(file_obj) -> list:
    """Parse FASTA records from an uploaded file.

    Args:
        file_obj: Raw bytes, a filesystem path (``str``/``os.PathLike``), an
            object with a ``.name`` path, or a readable file-like object.

    Returns:
        List of ``(header, sequence)`` tuples in file order. ``header`` is the
        ``>`` line without its leading ``>``. ``sequence`` is the concatenation
        of the record's non-blank lines, upper-cased. Lines that appear before
        the first header belong to no record and are skipped with a warning.

    Raises:
        Any error from reading/decoding (e.g. ``OSError``,
        ``UnicodeDecodeError``, ``TypeError``); it is logged, then re-raised.
    """
    try:
        content = _read_fasta_text(file_obj)
        logger.info(f"Received file content length: {len(content)}")

        sequences = []
        current_header = None
        current_sequence = []
        orphan_lines = 0

        # splitlines() also handles '\r\n' and bare '\r' line endings.
        for line in content.splitlines():
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                if current_header is not None:
                    sequences.append((current_header, ''.join(current_sequence)))
                current_header = line[1:]
                current_sequence = []
            elif current_header is None:
                orphan_lines += 1
            else:
                current_sequence.append(line.upper())

        if current_header is not None:
            sequences.append((current_header, ''.join(current_sequence)))

        if orphan_lines:
            logger.warning(f"Ignored {orphan_lines} line(s) before the first '>' header")
        logger.info(f"Parsed {len(sequences)} sequences from FASTA")
        return sequences
    except Exception as e:
        logger.error(f"Error parsing FASTA: {str(e)}")
        raise

# (model, scaler) pairs keyed by device string, filled lazily by
# _load_model_and_scaler so the artifacts are read from disk once per process
# instead of on every request.
_MODEL_CACHE = {}


def _load_model_and_scaler(device: str):
    """Return ``(model, scaler)`` for *device*, loading them on first use.

    Reads ``model.pt`` (a state dict for ``VirusClassifier(256)``) and
    ``scaler.pkl`` from the working directory. Exceptions propagate and
    nothing is cached, so a failed load is retried on the next call.
    """
    if device not in _MODEL_CACHE:
        logger.info("Loading model and scaler")
        model = VirusClassifier(256).to(device)  # 256 = 4^4 for 4-mers
        # NOTE(review): torch.load and joblib.load unpickle their input; these
        # must stay trusted local artifacts, never user-supplied files.
        model.load_state_dict(torch.load('model.pt', map_location=device))
        scaler = joblib.load('scaler.pkl')
        model.eval()
        _MODEL_CACHE[device] = (model, scaler)
    return _MODEL_CACHE[device]


def predict_sequence(file_obj) -> str:
    """Classify every record of an uploaded FASTA file as human / non-human host.

    Args:
        file_obj: The uploaded file as delivered by the Gradio input (passed
            through to ``parse_fasta``), or ``None`` when nothing was uploaded.

    Returns:
        A text report with one block per sequence (label, confidence, and both
        class probabilities), or a help / error message. Errors derived from
        ``Exception`` are logged and returned as text rather than raised; a
        failure on one sequence is reported inline and does not stop the rest.
    """
    try:
        logger.info("Starting prediction process")

        if file_obj is None:
            return "Please upload a FASTA file"

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using device: {device}")
        k = 4  # must match the 256-feature input the model was built with

        try:
            model, scaler = _load_model_and_scaler(device)
        except Exception as e:
            logger.error(f"Error loading model or scaler: {str(e)}")
            return f"Error loading model: {str(e)}"

        try:
            sequences = parse_fasta(file_obj)
        except Exception as e:
            logger.error(f"Error parsing FASTA file: {str(e)}")
            return f"Error parsing FASTA file: {str(e)}"

        if not sequences:
            # An empty report would look like nothing happened; say why.
            return "No FASTA records (lines starting with '>') were found in the uploaded file."

        results = []

        for header, seq in sequences:
            logger.info(f"Processing sequence: {header}")
            try:
                # Convert sequence to k-mer vector
                kmer_vector = sequence_to_kmer_vector(seq, k)
                if not kmer_vector.any():
                    # Too short, or no window made only of A/C/G/T: the model
                    # sees nothing but the scaled all-zero vector.
                    logger.warning(f"Sequence {header} has no valid {k}-mers; prediction is not meaningful")
                kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))

                # Get prediction
                with torch.no_grad():
                    output = model(torch.FloatTensor(kmer_vector).to(device))
                    probs = torch.softmax(output, dim=1)

                # Format result; column 1 is "human", column 0 "non-human".
                pred_class = 1 if probs[0][1] > probs[0][0] else 0
                pred_label = 'human' if pred_class == 1 else 'non-human'

                result = f"""
Sequence: {header}
Prediction: {pred_label}
Confidence: {float(max(probs[0])):0.4f}
Human probability: {float(probs[0][1]):0.4f}
Non-human probability: {float(probs[0][0]):0.4f}
"""
                results.append(result)
                logger.info(f"Processed sequence {header} successfully")

            except Exception as e:
                logger.error(f"Error processing sequence {header}: {str(e)}")
                results.append(f"Error processing sequence {header}: {str(e)}")

        return "\n".join(results)

    except Exception as e:
        logger.error(f"Unexpected error in predict_sequence: {str(e)}")
        return f"An unexpected error occurred: {str(e)}"

# Create Gradio interface.
# type="binary" asks gr.File to pass the upload's raw bytes to predict_sequence;
# parse_fasta decodes its input as bytes. NOTE(review): without an explicit
# type, gr.File passes a path / temp-file object (depending on the Gradio
# version) - confirm against the installed Gradio.
# NOTE(review): cache_examples=True presumably runs predict_sequence on
# example.fasta when the app starts, so model.pt, scaler.pkl and example.fasta
# must be present then - confirm.
iface = gr.Interface(
    fn=predict_sequence,
    inputs=gr.File(
        label="Upload FASTA file",
        file_types=[".fasta", ".fa", ".txt"],
        type="binary",
    ),
    outputs=gr.Textbox(label="Prediction Results", lines=10),
    title="Virus Host Classifier",
    description="Upload a FASTA file to predict whether a virus sequence is likely to infect human or non-human hosts.",
    examples=[["example.fasta"]],
    cache_examples=True,
)

# Launch the interface
iface.launch()