hiyata commited on
Commit
33ae1e8
·
verified ·
1 Parent(s): 6168c46

Change from kmer 6 to kmer 4

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -26,7 +26,7 @@ class VirusClassifier(nn.Module):
26
  def forward(self, x):
27
  return self.network(x)
28
 
29
- def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
30
  """Convert sequence to k-mer frequency vector"""
31
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
32
  kmer_dict = {kmer: 0 for kmer in kmers}
@@ -64,10 +64,10 @@ def parse_fasta(fasta_content: str):
64
  def predict_sequence(fasta_content: str) -> str:
65
  """Process FASTA input and return formatted predictions"""
66
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
67
- k = 6
68
 
69
  # Load model and scaler
70
- model = VirusClassifier(4096).to(device) # 4096 = 4^6 for 6-mers
71
  model.load_state_dict(torch.load('model.pt', map_location=device))
72
  scaler = joblib.load('scaler.pkl')
73
  model.eval()
 
26
  def forward(self, x):
27
  return self.network(x)
28
 
29
+ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
30
  """Convert sequence to k-mer frequency vector"""
31
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
32
  kmer_dict = {kmer: 0 for kmer in kmers}
 
64
  def predict_sequence(fasta_content: str) -> str:
65
  """Process FASTA input and return formatted predictions"""
66
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
67
+ k = 4
68
 
69
  # Load model and scaler
70
+ model = VirusClassifier(256).to(device) # 256 = 4^4 for 4-mers
71
  model.load_state_dict(torch.load('model.pt', map_location=device))
72
  scaler = joblib.load('scaler.pkl')
73
  model.eval()