hiyata committed on
Commit
9d48283
·
verified ·
1 Parent(s): 93d6a57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -7
app.py CHANGED
@@ -25,14 +25,27 @@ class VirusClassifier(nn.Module):
25
  def forward(self, x):
26
  return self.network(x)
27
 
28
- def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
 
 
29
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
30
- kmer_dict = {kmer: 0 for kmer in kmers}
 
 
 
 
 
31
  for i in range(len(sequence) - k + 1):
32
  kmer = sequence[i:i+k]
33
  if kmer in kmer_dict:
34
- kmer_dict[kmer] += 1
35
- return np.array(list(kmer_dict.values()))
 
 
 
 
 
 
36
 
37
  def parse_fasta(text):
38
  sequences = []
@@ -71,12 +84,19 @@ def predict(file_obj):
71
  # Load model and scaler
72
  try:
73
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
74
- model = VirusClassifier(4096).to(device)
75
- model.load_state_dict(torch.load('model.pt', map_location=device))
 
 
 
 
 
76
  scaler = joblib.load('scaler.pkl')
 
 
77
  model.eval()
78
  except Exception as e:
79
- return f"Error loading model: {str(e)}"
80
 
81
  # Get predictions
82
  results = []
 
25
  def forward(self, x):
26
  return self.network(x)
27
 
28
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a nucleotide sequence to a k-mer frequency vector.

    Args:
        sequence: Nucleotide string. Matching is case-insensitive for the
            four bases, so lower-case (e.g. soft-masked) ``acgt`` input is
            counted exactly like ``ACGT``.
        k: Length of each k-mer.

    Returns:
        A ``float32`` array of length ``4**k`` whose positions follow the
        order of ``itertools.product("ACGT", repeat=k)``. Each entry is the
        number of occurrences of that k-mer divided by the total number of
        length-``k`` windows in the sequence. Windows containing any other
        character (such as ``N``) are not counted but still contribute to
        the denominator. The vector is all zeros when the sequence is
        shorter than ``k``.
    """
    # Map every possible k-mer to its slot in the output vector.
    kmer_index = {
        ''.join(bases): slot
        for slot, bases in enumerate(product("ACGT", repeat=k))
    }
    vec = np.zeros(len(kmer_index), dtype=np.float32)

    # Fold only the four bases to upper-case; this keeps the length intact
    # and stops lower-case input from silently yielding an all-zero vector.
    sequence = sequence.translate(str.maketrans("acgt", "ACGT"))

    # Number of sliding windows; zero or negative when the sequence is too short.
    total_kmers = len(sequence) - k + 1

    for start in range(total_kmers):
        slot = kmer_index.get(sequence[start:start + k])
        if slot is not None:
            vec[slot] += 1

    # Convert raw counts to frequencies over all windows.
    if total_kmers > 0:
        vec = vec / total_kmers

    return vec
49
 
50
  def parse_fasta(text):
51
  sequences = []
 
84
  # Load model and scaler
85
  try:
86
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
87
+ model = VirusClassifier(256).to(device) # k=4 -> 4^4 = 256 features
88
+
89
+ # Load model with explicit map_location
90
+ state_dict = torch.load('model.pt', map_location=device)
91
+ model.load_state_dict(state_dict)
92
+
93
+ # Load scaler
94
  scaler = joblib.load('scaler.pkl')
95
+
96
+ # Set model to evaluation mode
97
  model.eval()
98
  except Exception as e:
99
+ return f"Error loading model: {str(e)}\nFull traceback: {str(e.__traceback__)}"
100
 
101
  # Get predictions
102
  results = []