Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -24,6 +24,20 @@ class VirusClassifier(nn.Module):
|
|
24 |
|
25 |
def forward(self, x):
|
26 |
return self.network(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
29 |
"""Convert sequence to k-mer frequency vector"""
|
@@ -73,7 +87,6 @@ def predict(file_obj):
|
|
73 |
|
74 |
# Read the file content
|
75 |
try:
|
76 |
-
# Handle both string and file object cases
|
77 |
if isinstance(file_obj, str):
|
78 |
text = file_obj
|
79 |
else:
|
@@ -81,6 +94,11 @@ def predict(file_obj):
|
|
81 |
except Exception as e:
|
82 |
return f"Error reading file: {str(e)}"
|
83 |
|
|
|
|
|
|
|
|
|
|
|
84 |
# Load model and scaler
|
85 |
try:
|
86 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
@@ -106,20 +124,47 @@ def predict(file_obj):
|
|
106 |
# Get k-mer vector
|
107 |
kmer_vector = sequence_to_kmer_vector(seq)
|
108 |
kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
|
|
|
109 |
|
110 |
-
#
|
111 |
with torch.no_grad():
|
112 |
-
output = model(
|
113 |
probs = torch.softmax(output, dim=1)
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
# Format results
|
116 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
117 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
|
|
118 |
result = f"""Sequence: {header}
|
119 |
Prediction: {pred_label}
|
120 |
Confidence: {float(max(probs[0])):0.4f}
|
121 |
Human probability: {float(probs[0][1]):0.4f}
|
122 |
-
Non-human probability: {float(probs[0][0]):0.4f}
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
results.append(result)
|
124 |
except Exception as e:
|
125 |
return f"Error processing sequences: {str(e)}"
|
@@ -136,4 +181,4 @@ iface = gr.Interface(
|
|
136 |
|
137 |
# Launch the interface
|
138 |
if __name__ == "__main__":
|
139 |
-
iface.launch()
|
|
|
24 |
|
25 |
def forward(self, x):
|
26 |
return self.network(x)
|
27 |
+
|
28 |
+
def get_feature_importance(self, x):
|
29 |
+
"""Calculate feature importance using gradient-based method"""
|
30 |
+
x.requires_grad_(True)
|
31 |
+
output = self.network(x)
|
32 |
+
importance = torch.zeros_like(x)
|
33 |
+
|
34 |
+
for i in range(output.shape[1]):
|
35 |
+
if x.grad is not None:
|
36 |
+
x.grad.zero_()
|
37 |
+
output[..., i].sum().backward(retain_graph=True)
|
38 |
+
importance += torch.abs(x.grad)
|
39 |
+
|
40 |
+
return importance
|
41 |
|
42 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
43 |
"""Convert sequence to k-mer frequency vector"""
|
|
|
87 |
|
88 |
# Read the file content
|
89 |
try:
|
|
|
90 |
if isinstance(file_obj, str):
|
91 |
text = file_obj
|
92 |
else:
|
|
|
94 |
except Exception as e:
|
95 |
return f"Error reading file: {str(e)}"
|
96 |
|
97 |
+
# Generate k-mer dictionary
|
98 |
+
k = 4 # k-mer size
|
99 |
+
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
100 |
+
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
101 |
+
|
102 |
# Load model and scaler
|
103 |
try:
|
104 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
124 |
# Get k-mer vector
|
125 |
kmer_vector = sequence_to_kmer_vector(seq)
|
126 |
kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
|
127 |
+
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
128 |
|
129 |
+
# Get predictions and feature importance
|
130 |
with torch.no_grad():
|
131 |
+
output = model(X_tensor)
|
132 |
probs = torch.softmax(output, dim=1)
|
133 |
|
134 |
+
# Calculate feature importance
|
135 |
+
importance = model.get_feature_importance(X_tensor)
|
136 |
+
kmer_importance = importance[0].cpu().numpy()
|
137 |
+
|
138 |
+
# Weight importance by actual k-mer frequency
|
139 |
+
kmer_importance *= kmer_vector[0]
|
140 |
+
|
141 |
+
# Get top 10 k-mers
|
142 |
+
top_k = 10
|
143 |
+
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
144 |
+
important_kmers = [
|
145 |
+
{
|
146 |
+
'kmer': list(kmer_dict.keys())[list(kmer_dict.values()).index(i)],
|
147 |
+
'importance': float(kmer_importance[i]),
|
148 |
+
'frequency': float(kmer_vector[0][i])
|
149 |
+
}
|
150 |
+
for i in top_indices
|
151 |
+
]
|
152 |
+
|
153 |
# Format results
|
154 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
155 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
156 |
+
|
157 |
result = f"""Sequence: {header}
|
158 |
Prediction: {pred_label}
|
159 |
Confidence: {float(max(probs[0])):0.4f}
|
160 |
Human probability: {float(probs[0][1]):0.4f}
|
161 |
+
Non-human probability: {float(probs[0][0]):0.4f}
|
162 |
+
|
163 |
+
Most influential k-mers:"""
|
164 |
+
|
165 |
+
for kmer in important_kmers:
|
166 |
+
result += f"\n {kmer['kmer']}: importance={kmer['importance']:.4f}, frequency={kmer['frequency']:.4f}"
|
167 |
+
|
168 |
results.append(result)
|
169 |
except Exception as e:
|
170 |
return f"Error processing sequences: {str(e)}"
|
|
|
181 |
|
182 |
# Launch the interface
|
183 |
if __name__ == "__main__":
|
184 |
+
iface.launch()
|