hiyata commited on
Commit
cdd8a58
·
verified ·
1 Parent(s): 18779ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -5
app.py CHANGED
@@ -24,6 +24,20 @@ class VirusClassifier(nn.Module):
24
 
25
  def forward(self, x):
26
  return self.network(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
29
  """Convert sequence to k-mer frequency vector"""
@@ -73,7 +87,6 @@ def predict(file_obj):
73
 
74
  # Read the file content
75
  try:
76
- # Handle both string and file object cases
77
  if isinstance(file_obj, str):
78
  text = file_obj
79
  else:
@@ -81,6 +94,11 @@ def predict(file_obj):
81
  except Exception as e:
82
  return f"Error reading file: {str(e)}"
83
 
 
 
 
 
 
84
  # Load model and scaler
85
  try:
86
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -106,20 +124,47 @@ def predict(file_obj):
106
  # Get k-mer vector
107
  kmer_vector = sequence_to_kmer_vector(seq)
108
  kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
 
109
 
110
- # Predict
111
  with torch.no_grad():
112
- output = model(torch.FloatTensor(kmer_vector).to(device))
113
  probs = torch.softmax(output, dim=1)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  # Format results
116
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
117
  pred_label = 'human' if pred_class == 1 else 'non-human'
 
118
  result = f"""Sequence: {header}
119
  Prediction: {pred_label}
120
  Confidence: {float(max(probs[0])):0.4f}
121
  Human probability: {float(probs[0][1]):0.4f}
122
- Non-human probability: {float(probs[0][0]):0.4f}"""
 
 
 
 
 
 
123
  results.append(result)
124
  except Exception as e:
125
  return f"Error processing sequences: {str(e)}"
@@ -136,4 +181,4 @@ iface = gr.Interface(
136
 
137
  # Launch the interface
138
  if __name__ == "__main__":
139
- iface.launch() # Remove share=True for Hugging Face Spaces
 
24
 
25
  def forward(self, x):
26
  return self.network(x)
27
+
28
+ def get_feature_importance(self, x):
29
+ """Calculate feature importance using gradient-based method"""
30
+ x.requires_grad_(True)
31
+ output = self.network(x)
32
+ importance = torch.zeros_like(x)
33
+
34
+ for i in range(output.shape[1]):
35
+ if x.grad is not None:
36
+ x.grad.zero_()
37
+ output[..., i].sum().backward(retain_graph=True)
38
+ importance += torch.abs(x.grad)
39
+
40
+ return importance
41
 
42
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
43
  """Convert sequence to k-mer frequency vector"""
 
87
 
88
  # Read the file content
89
  try:
 
90
  if isinstance(file_obj, str):
91
  text = file_obj
92
  else:
 
94
  except Exception as e:
95
  return f"Error reading file: {str(e)}"
96
 
97
+ # Generate k-mer dictionary
98
+ k = 4 # k-mer size
99
+ kmers = [''.join(p) for p in product("ACGT", repeat=k)]
100
+ kmer_dict = {km: i for i, km in enumerate(kmers)}
101
+
102
  # Load model and scaler
103
  try:
104
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
124
  # Get k-mer vector
125
  kmer_vector = sequence_to_kmer_vector(seq)
126
  kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
127
+ X_tensor = torch.FloatTensor(kmer_vector).to(device)
128
 
129
+ # Get predictions and feature importance
130
  with torch.no_grad():
131
+ output = model(X_tensor)
132
  probs = torch.softmax(output, dim=1)
133
 
134
+ # Calculate feature importance
135
+ importance = model.get_feature_importance(X_tensor)
136
+ kmer_importance = importance[0].cpu().numpy()
137
+
138
+ # Weight importance by actual k-mer frequency
139
+ kmer_importance *= kmer_vector[0]
140
+
141
+ # Get top 10 k-mers
142
+ top_k = 10
143
+ top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
144
+ important_kmers = [
145
+ {
146
+ 'kmer': list(kmer_dict.keys())[list(kmer_dict.values()).index(i)],
147
+ 'importance': float(kmer_importance[i]),
148
+ 'frequency': float(kmer_vector[0][i])
149
+ }
150
+ for i in top_indices
151
+ ]
152
+
153
  # Format results
154
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
155
  pred_label = 'human' if pred_class == 1 else 'non-human'
156
+
157
  result = f"""Sequence: {header}
158
  Prediction: {pred_label}
159
  Confidence: {float(max(probs[0])):0.4f}
160
  Human probability: {float(probs[0][1]):0.4f}
161
+ Non-human probability: {float(probs[0][0]):0.4f}
162
+
163
+ Most influential k-mers:"""
164
+
165
+ for kmer in important_kmers:
166
+ result += f"\n {kmer['kmer']}: importance={kmer['importance']:.4f}, frequency={kmer['frequency']:.4f}"
167
+
168
  results.append(result)
169
  except Exception as e:
170
  return f"Error processing sequences: {str(e)}"
 
181
 
182
  # Launch the interface
183
  if __name__ == "__main__":
184
+ iface.launch()