hiyata committed on
Commit
2243c0c
·
verified ·
1 Parent(s): 33ae1e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -61
app.py CHANGED
@@ -3,8 +3,12 @@ import torch
3
  import joblib
4
  import numpy as np
5
  from itertools import product
6
- from typing import Dict
7
  import torch.nn as nn
 
 
 
 
 
8
 
9
  class VirusClassifier(nn.Module):
10
  def __init__(self, input_shape: int):
@@ -28,84 +32,125 @@ class VirusClassifier(nn.Module):
28
 
29
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
30
  """Convert sequence to k-mer frequency vector"""
31
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
32
- kmer_dict = {kmer: 0 for kmer in kmers}
33
-
34
- for i in range(len(sequence) - k + 1):
35
- kmer = sequence[i:i+k]
36
- if kmer in kmer_dict: # only count valid kmers
37
- kmer_dict[kmer] += 1
38
-
39
- return np.array(list(kmer_dict.values()))
 
 
 
 
40
 
41
- def parse_fasta(fasta_content: str):
42
- """Parse FASTA format string"""
43
- sequences = []
44
- current_header = None
45
- current_sequence = []
46
-
47
- for line in fasta_content.split('\n'):
48
- line = line.strip()
49
- if not line:
50
- continue
51
- if line.startswith('>'):
52
- if current_header is not None:
53
- sequences.append((current_header, ''.join(current_sequence)))
54
- current_header = line[1:]
55
- current_sequence = []
56
- else:
57
- current_sequence.append(line.upper())
58
-
59
- if current_header is not None:
60
- sequences.append((current_header, ''.join(current_sequence)))
61
 
62
- return sequences
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- def predict_sequence(fasta_content: str) -> str:
65
  """Process FASTA input and return formatted predictions"""
66
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
67
- k = 4
68
-
69
- # Load model and scaler
70
- model = VirusClassifier(256).to(device) # 256 = 4^4 for 4-mers
71
- model.load_state_dict(torch.load('model.pt', map_location=device))
72
- scaler = joblib.load('scaler.pkl')
73
- model.eval()
74
-
75
- # Process sequences
76
- sequences = parse_fasta(fasta_content)
77
- results = []
78
-
79
- for header, seq in sequences:
80
- # Convert sequence to k-mer vector
81
- kmer_vector = sequence_to_kmer_vector(seq, k)
82
- kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))
83
 
84
- # Get prediction
85
- with torch.no_grad():
86
- output = model(torch.FloatTensor(kmer_vector).to(device))
87
- probs = torch.softmax(output, dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Format result
90
- pred_class = 1 if probs[0][1] > probs[0][0] else 0
91
- pred_label = 'human' if pred_class == 1 else 'non-human'
92
 
93
- result = f"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  Sequence: {header}
95
  Prediction: {pred_label}
96
  Confidence: {float(max(probs[0])):0.4f}
97
  Human probability: {float(probs[0][1]):0.4f}
98
  Non-human probability: {float(probs[0][0]):0.4f}
99
  """
100
- results.append(result)
101
-
102
- return "\n".join(results)
 
 
 
 
 
 
 
 
 
103
 
104
  # Create Gradio interface
105
  iface = gr.Interface(
106
  fn=predict_sequence,
107
  inputs=gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"]),
108
- outputs=gr.Textbox(label="Prediction Results"),
109
  title="Virus Host Classifier",
110
  description="Upload a FASTA file to predict whether a virus sequence is likely to infect human or non-human hosts.",
111
  examples=[["example.fasta"]],
 
3
  import joblib
4
  import numpy as np
5
  from itertools import product
 
6
  import torch.nn as nn
7
+ import logging
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
 
13
  class VirusClassifier(nn.Module):
14
  def __init__(self, input_shape: int):
 
32
 
33
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Count how often each DNA k-mer occurs in ``sequence``.

    The vector has one slot per k-mer over the alphabet ``ACGT`` (``4**k``
    slots), ordered the way ``itertools.product("ACGT", repeat=k)`` yields
    them. Values are raw occurrence counts, not normalised frequencies.
    Windows containing any other character (``N``, lowercase letters, gaps,
    ...) are skipped, and a sequence shorter than ``k`` yields an all-zero
    vector.

    Args:
        sequence: Nucleotide sequence; matching is case-sensitive.
        k: Length of the k-mers to count; must be at least 1.

    Returns:
        One-dimensional integer array of length ``4**k``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    try:
        if k < 1:
            # k == 0 would otherwise "count" the empty string once per position.
            raise ValueError(f"k must be a positive integer, got {k}")

        # Map each k-mer to its fixed slot so counting is a single dict lookup.
        kmer_slot = {
            ''.join(letters): slot
            for slot, letters in enumerate(product("ACGT", repeat=k))
        }
        counts = np.zeros(len(kmer_slot), dtype=int)

        for start in range(len(sequence) - k + 1):
            slot = kmer_slot.get(sequence[start:start + k])
            if slot is not None:  # only count valid kmers
                counts[slot] += 1

        return counts
    except Exception as e:
        logger.error(f"Error in sequence_to_kmer_vector: {str(e)}")
        raise
48
 
49
def parse_fasta(file_obj) -> list:
    """Parse FASTA data into a list of ``(header, sequence)`` tuples.

    Args:
        file_obj: The FASTA input. Accepted forms are raw ``bytes`` (UTF-8),
            a string holding FASTA text, a filesystem path (``str`` or
            ``os.PathLike``), an object whose ``.name`` is the path of an
            existing file (e.g. a temporary-file wrapper), or any object with
            a ``read()`` method.

    Returns:
        One ``(header, sequence)`` tuple per record, in file order. The header
        is the text after ``>``; the sequence is the record's lines joined
        together and upper-cased. Blank lines are ignored, and sequence lines
        that appear before the first header are discarded.

    Raises:
        TypeError: If ``file_obj`` is none of the accepted forms.
        OSError: If a path is given but cannot be opened.
        UnicodeDecodeError: If the data is not valid UTF-8.
    """
    try:
        # Read the content from whatever form the upload arrived in
        content = _read_fasta_text(file_obj)
        logger.info(f"Received file content length: {len(content)}")

        sequences = []
        current_header = None
        current_sequence = []

        # splitlines() copes with \n, \r\n and bare \r line endings alike.
        for line in content.splitlines():
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # A new header closes the record collected so far.
                if current_header is not None:
                    sequences.append((current_header, ''.join(current_sequence)))
                current_header = line[1:]
                current_sequence = []
            else:
                current_sequence.append(line.upper())

        # Flush the final record, which has no following header to close it.
        if current_header is not None:
            sequences.append((current_header, ''.join(current_sequence)))

        logger.info(f"Parsed {len(sequences)} sequences from FASTA")
        return sequences
    except Exception as e:
        logger.error(f"Error parsing FASTA: {str(e)}")
        raise


def _read_fasta_text(file_obj) -> str:
    """Return the FASTA text carried by bytes, text, a path, or a file-like object."""
    import os  # function-scope import: `os` is not among the file-level imports

    if isinstance(file_obj, (bytes, bytearray)):
        return bytes(file_obj).decode('utf-8')

    if isinstance(file_obj, str) and ('\n' in file_obj or file_obj.lstrip().startswith('>')):
        # Multi-line text or a leading '>' means FASTA content, not a path.
        return file_obj

    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        # Upload wrappers such as temporary-file objects expose their location
        # as `.name`; only trust it when it points at a real file.
        path = getattr(file_obj, 'name', None)
        if not (isinstance(path, str) and os.path.isfile(path)):
            path = None

    if path is not None:
        with open(path, encoding='utf-8') as handle:
            return handle.read()

    if hasattr(file_obj, 'read'):
        data = file_obj.read()
        if isinstance(data, (bytes, bytearray)):
            return bytes(data).decode('utf-8')
        return data

    raise TypeError(f"Unsupported FASTA input type: {type(file_obj).__name__}")
80
 
81
def predict_sequence(file_obj) -> str:
    """Process FASTA input and return formatted predictions.

    Loads the classifier weights (``model.pt``) and the feature scaler
    (``scaler.pkl``), parses the upload with ``parse_fasta``, converts every
    sequence to a 4-mer count vector and reports the predicted host together
    with both class probabilities.

    Args:
        file_obj: The uploaded FASTA data, passed unchanged to ``parse_fasta``;
            ``None`` means nothing was uploaded.

    Returns:
        A human-readable report with one section per sequence. Failures are
        never raised to the caller: a missing upload, a model/scaler loading
        error, a parsing error, an empty FASTA file, or a per-sequence error
        are all reported as text in the returned string.
    """
    try:
        logger.info("Starting prediction process")

        if file_obj is None:
            return "Please upload a FASTA file"

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using device: {device}")
        k = 4

        # Load model and scaler
        try:
            logger.info("Loading model and scaler")
            # One input feature per possible k-mer: 4**k (256 for 4-mers).
            model = VirusClassifier(4 ** k).to(device)
            model.load_state_dict(torch.load('model.pt', map_location=device))
            # NOTE: joblib.load unpickles its input; scaler.pkl must be a trusted file.
            scaler = joblib.load('scaler.pkl')
            model.eval()
        except Exception as e:
            # logger.exception keeps the traceback that the returned text drops.
            logger.exception(f"Error loading model or scaler: {str(e)}")
            return f"Error loading model: {str(e)}"

        # Process sequences
        try:
            sequences = parse_fasta(file_obj)
        except Exception as e:
            logger.exception(f"Error parsing FASTA file: {str(e)}")
            return f"Error parsing FASTA file: {str(e)}"

        if not sequences:
            # Without this the UI would show an empty result box with no explanation.
            logger.warning("No sequences found in FASTA input")
            return "No sequences found in the uploaded FASTA file"

        results = []

        for header, seq in sequences:
            logger.info(f"Processing sequence: {header}")
            try:
                # Convert sequence to k-mer vector
                kmer_vector = sequence_to_kmer_vector(seq, k)
                if not kmer_vector.any():
                    # An all-zero vector carries no signal; a prediction from it
                    # would only look confident.
                    raise ValueError(f"no valid {k}-mers found in sequence")
                features = scaler.transform(kmer_vector.reshape(1, -1))

                # Get prediction
                with torch.no_grad():
                    inputs = torch.tensor(features, dtype=torch.float32).to(device)
                    probs = torch.softmax(model(inputs), dim=1)

                # Format result; index 1 is reported as "human", index 0 as "non-human".
                human_prob = float(probs[0][1])
                non_human_prob = float(probs[0][0])
                pred_label = 'human' if human_prob > non_human_prob else 'non-human'
                confidence = float(probs[0].max())

                result = (
                    "\n"
                    f"Sequence: {header}\n"
                    f"Prediction: {pred_label}\n"
                    f"Confidence: {confidence:0.4f}\n"
                    f"Human probability: {human_prob:0.4f}\n"
                    f"Non-human probability: {non_human_prob:0.4f}\n"
                )
                results.append(result)
                logger.info(f"Processed sequence {header} successfully")

            except Exception as e:
                # One bad record must not abort the remaining sequences.
                logger.exception(f"Error processing sequence {header}: {str(e)}")
                results.append(f"Error processing sequence {header}: {str(e)}")

        return "\n".join(results)

    except Exception as e:
        logger.exception(f"Unexpected error in predict_sequence: {str(e)}")
        return f"An unexpected error occurred: {str(e)}"
148
 
149
  # Create Gradio interface
150
  iface = gr.Interface(
151
  fn=predict_sequence,
152
  inputs=gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"]),
153
+ outputs=gr.Textbox(label="Prediction Results", lines=10),
154
  title="Virus Host Classifier",
155
  description="Upload a FASTA file to predict whether a virus sequence is likely to infect human or non-human hosts.",
156
  examples=[["example.fasta"]],