hiyata committed on
Commit
723da6d
·
verified ·
1 Parent(s): 63d967d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -39
app.py CHANGED
@@ -25,15 +25,13 @@ class VirusClassifier(nn.Module):
25
  def forward(self, x):
26
  return self.network(x)
27
 
28
- def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
29
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
30
  kmer_dict = {kmer: 0 for kmer in kmers}
31
-
32
  for i in range(len(sequence) - k + 1):
33
  kmer = sequence[i:i+k]
34
  if kmer in kmer_dict:
35
  kmer_dict[kmer] += 1
36
-
37
  return np.array(list(kmer_dict.values()))
38
 
39
  def parse_fasta(text):
@@ -52,10 +50,8 @@ def parse_fasta(text):
52
  current_sequence = []
53
  else:
54
  current_sequence.append(line.upper())
55
-
56
  if current_header:
57
  sequences.append((current_header, ''.join(current_sequence)))
58
-
59
  return sequences
60
 
61
  def predict(file_obj):
@@ -63,51 +59,61 @@ def predict(file_obj):
63
  return "Please upload a FASTA file"
64
 
65
  # Read the file content
66
- text = file_obj.read().decode()
67
-
 
 
 
 
 
 
 
68
  # Load model and scaler
69
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
70
- model = VirusClassifier(4096).to(device)
71
- model.load_state_dict(torch.load('model.pt', map_location=device))
72
- scaler = joblib.load('scaler.pkl')
73
- model.eval()
74
-
 
 
 
75
  # Get predictions
76
  results = []
77
- sequences = parse_fasta(text)
78
-
79
- for header, seq in sequences:
80
- # Get k-mer vector
81
- kmer_vector = sequence_to_kmer_vector(seq)
82
- kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
83
-
84
- # Predict
85
- with torch.no_grad():
86
- output = model(torch.FloatTensor(kmer_vector).to(device))
87
- probs = torch.softmax(output, dim=1)
88
-
89
- # Format results
90
- pred_class = 1 if probs[0][1] > probs[0][0] else 0
91
- pred_label = 'human' if pred_class == 1 else 'non-human'
92
-
93
- result = f"""
94
- Sequence: {header}
95
  Prediction: {pred_label}
96
  Confidence: {float(max(probs[0])):0.4f}
97
  Human probability: {float(probs[0][1]):0.4f}
98
- Non-human probability: {float(probs[0][0]):0.4f}
99
- """
100
- results.append(result)
101
-
102
- return "\n".join(results)
 
103
 
104
  # Create the interface
105
  iface = gr.Interface(
106
  fn=predict,
107
- inputs=gr.File(label="Upload FASTA file"),
108
  outputs=gr.Textbox(label="Results"),
109
  title="Virus Host Classifier"
110
  )
111
 
112
- # Launch with public link
113
- iface.launch(share=True)
 
 
25
  def forward(self, x):
26
  return self.network(x)
27
 
28
+ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
29
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
30
  kmer_dict = {kmer: 0 for kmer in kmers}
 
31
  for i in range(len(sequence) - k + 1):
32
  kmer = sequence[i:i+k]
33
  if kmer in kmer_dict:
34
  kmer_dict[kmer] += 1
 
35
  return np.array(list(kmer_dict.values()))
36
 
37
  def parse_fasta(text):
 
50
  current_sequence = []
51
  else:
52
  current_sequence.append(line.upper())
 
53
  if current_header:
54
  sequences.append((current_header, ''.join(current_sequence)))
 
55
  return sequences
56
 
57
  def predict(file_obj):
 
59
  return "Please upload a FASTA file"
60
 
61
  # Read the file content
62
+ try:
63
+ # Handle both string and file object cases
64
+ if isinstance(file_obj, str):
65
+ text = file_obj
66
+ else:
67
+ text = file_obj.decode('utf-8')
68
+ except Exception as e:
69
+ return f"Error reading file: {str(e)}"
70
+
71
  # Load model and scaler
72
+ try:
73
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
74
+ model = VirusClassifier(4096).to(device)
75
+ model.load_state_dict(torch.load('model.pt', map_location=device))
76
+ scaler = joblib.load('scaler.pkl')
77
+ model.eval()
78
+ except Exception as e:
79
+ return f"Error loading model: {str(e)}"
80
+
81
  # Get predictions
82
  results = []
83
+ try:
84
+ sequences = parse_fasta(text)
85
+ for header, seq in sequences:
86
+ # Get k-mer vector
87
+ kmer_vector = sequence_to_kmer_vector(seq)
88
+ kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
89
+
90
+ # Predict
91
+ with torch.no_grad():
92
+ output = model(torch.FloatTensor(kmer_vector).to(device))
93
+ probs = torch.softmax(output, dim=1)
94
+
95
+ # Format results
96
+ pred_class = 1 if probs[0][1] > probs[0][0] else 0
97
+ pred_label = 'human' if pred_class == 1 else 'non-human'
98
+ result = f"""Sequence: {header}
 
 
99
  Prediction: {pred_label}
100
  Confidence: {float(max(probs[0])):0.4f}
101
  Human probability: {float(probs[0][1]):0.4f}
102
+ Non-human probability: {float(probs[0][0]):0.4f}"""
103
+ results.append(result)
104
+ except Exception as e:
105
+ return f"Error processing sequences: {str(e)}"
106
+
107
+ return "\n\n".join(results)
108
 
109
  # Create the interface
110
  iface = gr.Interface(
111
  fn=predict,
112
+ inputs=gr.File(label="Upload FASTA file", type="binary"),
113
  outputs=gr.Textbox(label="Results"),
114
  title="Virus Host Classifier"
115
  )
116
 
117
+ # Launch the interface
118
+ if __name__ == "__main__":
119
+ iface.launch() # Remove share=True for Hugging Face Spaces