hiyata commited on
Commit
870813f
·
verified ·
1 Parent(s): 4e29ba7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -128
app.py CHANGED
@@ -4,11 +4,6 @@ import joblib
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
7
- import logging
8
-
9
- # Set up logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
 
13
  class VirusClassifier(nn.Module):
14
  def __init__(self, input_shape: int):
@@ -31,143 +26,88 @@ class VirusClassifier(nn.Module):
31
  return self.network(x)
32
 
33
  def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
34
- """Convert sequence to k-mer frequency vector"""
35
- try:
36
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
37
- kmer_dict = {kmer: 0 for kmer in kmers}
38
-
39
- for i in range(len(sequence) - k + 1):
40
- kmer = sequence[i:i+k]
41
- if kmer in kmer_dict: # only count valid kmers
42
- kmer_dict[kmer] += 1
43
-
44
- return np.array(list(kmer_dict.values()))
45
- except Exception as e:
46
- logger.error(f"Error in sequence_to_kmer_vector: {str(e)}")
47
- raise
48
-
49
- def parse_fasta(content: str) -> list:
50
- """Parse FASTA format from string content"""
51
- try:
52
- logger.info(f"Received file content length: {len(content)}")
53
-
54
- sequences = []
55
- current_header = None
56
- current_sequence = []
57
-
58
- for line in content.split('\n'):
59
- line = line.strip()
60
- if not line:
61
- continue
62
- if line.startswith('>'):
63
- if current_header is not None:
64
- sequences.append((current_header, ''.join(current_sequence)))
65
- current_header = line[1:]
66
- current_sequence = []
67
- else:
68
- current_sequence.append(line.upper())
69
-
70
- if current_header is not None:
71
- sequences.append((current_header, ''.join(current_sequence)))
72
-
73
- logger.info(f"Parsed {len(sequences)} sequences from FASTA")
74
- return sequences
75
- except Exception as e:
76
- logger.error(f"Error parsing FASTA: {str(e)}")
77
- raise
78
 
79
- def predict_sequence(fasta_file) -> str:
80
- """Process FASTA input and return formatted predictions"""
81
- try:
82
- logger.info("Starting prediction process")
83
-
84
- if fasta_file is None:
85
- return "Please upload a FASTA file"
 
 
 
 
 
 
 
 
 
86
 
87
- # Get file content - handle both string and file inputs
88
- try:
89
- if isinstance(fasta_file, str):
90
- content = fasta_file
91
- else:
92
- content = fasta_file.name # For Gradio file upload
93
- except Exception as e:
94
- logger.error(f"Error reading file: {str(e)}")
95
- return f"Error reading file: {str(e)}"
96
 
97
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
98
- logger.info(f"Using device: {device}")
99
- k = 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- # Load model and scaler
102
- try:
103
- logger.info("Loading model and scaler")
104
- model = VirusClassifier(256).to(device) # 256 = 4^4 for 4-mers
105
- model.load_state_dict(torch.load('model.pt', map_location=device))
106
- scaler = joblib.load('scaler.pkl')
107
- model.eval()
108
- except Exception as e:
109
- logger.error(f"Error loading model or scaler: {str(e)}")
110
- return f"Error loading model: {str(e)}"
111
 
112
- # Process sequences
113
- try:
114
- sequences = parse_fasta(content)
115
- except Exception as e:
116
- logger.error(f"Error parsing FASTA file: {str(e)}")
117
- return f"Error parsing FASTA file: {str(e)}"
118
-
119
- results = []
120
 
121
- for header, seq in sequences:
122
- logger.info(f"Processing sequence: {header}")
123
- try:
124
- # Convert sequence to k-mer vector
125
- kmer_vector = sequence_to_kmer_vector(seq, k)
126
- kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
127
-
128
- # Get prediction
129
- with torch.no_grad():
130
- output = model(torch.FloatTensor(kmer_vector).to(device))
131
- probs = torch.softmax(output, dim=1)
132
-
133
- # Format result
134
- pred_class = 1 if probs[0][1] > probs[0][0] else 0
135
- pred_label = 'human' if pred_class == 1 else 'non-human'
136
-
137
- result = f"""
138
  Sequence: {header}
139
  Prediction: {pred_label}
140
  Confidence: {float(max(probs[0])):0.4f}
141
  Human probability: {float(probs[0][1]):0.4f}
142
  Non-human probability: {float(probs[0][0]):0.4f}
143
  """
144
- results.append(result)
145
- logger.info(f"Processed sequence {header} successfully")
146
-
147
- except Exception as e:
148
- logger.error(f"Error processing sequence {header}: {str(e)}")
149
- results.append(f"Error processing sequence {header}: {str(e)}")
150
-
151
- return "\n".join(results)
152
-
153
- except Exception as e:
154
- logger.error(f"Unexpected error in predict_sequence: {str(e)}")
155
- return f"An unexpected error occurred: {str(e)}"
156
 
157
- # Create Gradio interface with both file upload and text input
158
  iface = gr.Interface(
159
- fn=predict_sequence,
160
- inputs=gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"])
161
- outputs=gr.Textbox(label="Prediction Results", lines=10),
162
- title="Virus Host Classifier",
163
- description="""Upload a FASTA file or paste your sequence to predict whether a virus sequence is likely to infect human or non-human hosts.
164
-
165
- Example format:
166
- >sequence_name
167
- ATCGATCGATCG...""",
168
- examples=[["example.fasta", None]],
169
- cache_examples=True
170
  )
171
 
172
- # Launch the interface
173
  iface.launch()
 
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
 
 
 
 
 
7
 
8
  class VirusClassifier(nn.Module):
9
  def __init__(self, input_shape: int):
 
26
  return self.network(x)
27
 
28
  def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
29
+ kmers = [''.join(p) for p in product("ACGT", repeat=k)]
30
+ kmer_dict = {kmer: 0 for kmer in kmers}
31
+
32
+ for i in range(len(sequence) - k + 1):
33
+ kmer = sequence[i:i+k]
34
+ if kmer in kmer_dict:
35
+ kmer_dict[kmer] += 1
36
+
37
+ return np.array(list(kmer_dict.values()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ def parse_fasta(text):
40
+ sequences = []
41
+ current_header = None
42
+ current_sequence = []
43
+
44
+ for line in text.split('\n'):
45
+ line = line.strip()
46
+ if not line:
47
+ continue
48
+ if line.startswith('>'):
49
+ if current_header:
50
+ sequences.append((current_header, ''.join(current_sequence)))
51
+ current_header = line[1:]
52
+ current_sequence = []
53
+ else:
54
+ current_sequence.append(line.upper())
55
 
56
+ if current_header:
57
+ sequences.append((current_header, ''.join(current_sequence)))
 
 
 
 
 
 
 
58
 
59
+ return sequences
60
+
61
+ def predict(file):
62
+ if not file:
63
+ return "Please upload a FASTA file"
64
+
65
+ # Read the file
66
+ text = file.decode('utf-8')
67
+
68
+ # Load model and scaler
69
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
70
+ model = VirusClassifier(4096).to(device)
71
+ model.load_state_dict(torch.load('model.pt', map_location=device))
72
+ scaler = joblib.load('scaler.pkl')
73
+ model.eval()
74
+
75
+ # Get predictions
76
+ results = []
77
+ sequences = parse_fasta(text)
78
+
79
+ for header, seq in sequences:
80
+ # Get k-mer vector
81
+ kmer_vector = sequence_to_kmer_vector(seq)
82
+ kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
83
 
84
+ # Predict
85
+ with torch.no_grad():
86
+ output = model(torch.FloatTensor(kmer_vector).to(device))
87
+ probs = torch.softmax(output, dim=1)
 
 
 
 
 
 
88
 
89
+ # Format results
90
+ pred_class = 1 if probs[0][1] > probs[0][0] else 0
91
+ pred_label = 'human' if pred_class == 1 else 'non-human'
 
 
 
 
 
92
 
93
+ result = f"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  Sequence: {header}
95
  Prediction: {pred_label}
96
  Confidence: {float(max(probs[0])):0.4f}
97
  Human probability: {float(probs[0][1]):0.4f}
98
  Non-human probability: {float(probs[0][0]):0.4f}
99
  """
100
+ results.append(result)
101
+
102
+ return "\n".join(results)
 
 
 
 
 
 
 
 
 
103
 
104
+ # Create the interface
105
  iface = gr.Interface(
106
+ fn=predict,
107
+ inputs=gr.File(label="Upload FASTA file"),
108
+ outputs=gr.Textbox(label="Results"),
109
+ title="Virus Host Classifier"
 
 
 
 
 
 
 
110
  )
111
 
112
+ # Launch
113
  iface.launch()