# HostClassifier / app.py
# hiyata's picture
# Create app.py
# 5263bd3 verified
# raw
# history blame
# 3.64 kB
import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
from typing import Dict
import torch.nn as nn
class VirusClassifier(nn.Module):
    """Fully connected network that maps a feature vector to two logits.

    The layers live in a single ``nn.Sequential`` stored as ``self.network``:
    three hidden ``Linear`` + ``GELU`` stages (64, 32 and 32 units), the first
    two followed by batch normalisation and dropout (p=0.3), and a final
    ``Linear`` producing 2 outputs.

    Args:
        input_shape: Number of input features per sample.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        # Layer order is significant: the Sequential indices become part of
        # the state-dict keys ("network.0.weight", ...).
        stages = [
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw outputs."""
        logits = self.network(x)
        return logits
def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
    """Count occurrences of every A/C/G/T k-mer in ``sequence``.

    Args:
        sequence: Sequence to scan with a sliding window of width ``k``.
        k: Window (k-mer) length.

    Returns:
        A 1-D array with one count per k-mer, ordered as
        ``itertools.product("ACGT", repeat=k)`` enumerates them.
    """
    all_kmers = [''.join(letters) for letters in product("ACGT", repeat=k)]
    # Map each k-mer to its slot in the output vector.
    slot_of = {kmer: position for position, kmer in enumerate(all_kmers)}
    counts = [0] * len(all_kmers)

    last_start = len(sequence) - k
    for start in range(last_start + 1):
        position = slot_of.get(sequence[start:start + k])
        # Windows containing anything other than A/C/G/T have no slot.
        if position is not None:
            counts[position] += 1

    return np.array(counts)
def parse_fasta(fasta_content: str):
    """Split FASTA-formatted text into ``(header, sequence)`` tuples.

    A line starting with ``>`` opens a new record; its header is the rest of
    that line. The following non-blank lines are upper-cased and concatenated
    into the record's sequence. Blank lines are ignored, and sequence lines
    appearing before the first header are dropped.

    Args:
        fasta_content: The FASTA text to parse.

    Returns:
        A list of ``(header, sequence)`` tuples in input order.
    """
    # Each entry pairs a header with the list of sequence fragments seen so far.
    entries = []
    for raw_line in fasta_content.split('\n'):
        text = raw_line.strip()
        if not text:
            continue
        if text[0] == '>':
            entries.append((text[1:], []))
        elif entries:
            entries[-1][1].append(text.upper())
    return [(header, ''.join(fragments)) for header, fragments in entries]
def _read_fasta_input(fasta_input) -> str:
    """Return FASTA text from whatever the upload component handed over.

    Depending on the Gradio version and the ``type`` setting of ``gr.File``,
    the callback may receive a path string, a tempfile-like object exposing
    the path as ``.name``, or raw bytes. Plain FASTA text is passed through
    unchanged so existing callers that supply a string keep working.
    """
    import os  # local import: only this helper needs it

    if fasta_input is None:
        return ""
    if isinstance(fasta_input, (bytes, bytearray)):
        return bytes(fasta_input).decode("utf-8", errors="replace")

    if isinstance(fasta_input, str):
        candidate_path = fasta_input
    else:
        candidate_path = getattr(fasta_input, "name", None)

    looks_like_path = (
        isinstance(candidate_path, str)
        and "\n" not in candidate_path
        and not candidate_path.lstrip().startswith(">")
    )
    if looks_like_path and os.path.isfile(candidate_path):
        with open(candidate_path, "r", encoding="utf-8", errors="replace") as handle:
            return handle.read()

    if isinstance(fasta_input, str):
        return fasta_input
    raise TypeError(
        f"Unsupported FASTA input type: {type(fasta_input).__name__}"
    )


def predict_sequence(fasta_content: str) -> str:
    """Classify every record of a FASTA input and return a text report.

    Args:
        fasta_content: FASTA text, or the value delivered by a ``gr.File``
            input (a file path, a tempfile-like object with ``.name``, or
            bytes), which is read and decoded before parsing.

    Returns:
        One formatted section per sequence (label, confidence and both class
        probabilities) joined by newlines, or a short message when the input
        contains no FASTA records.
    """
    sequences = parse_fasta(_read_fasta_input(fasta_content))
    if not sequences:
        # Nothing to classify: skip loading the model and say so explicitly
        # instead of returning an empty report.
        return "No FASTA records found in the input."

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    k = 6

    # Load model and scaler; the input width is the number of possible k-mers.
    model = VirusClassifier(4 ** k).to(device)
    model.load_state_dict(torch.load('model.pt', map_location=device))
    scaler = joblib.load('scaler.pkl')
    model.eval()

    results = []
    for header, seq in sequences:
        # Convert the sequence to a scaled k-mer count vector (1 x 4**k).
        kmer_vector = sequence_to_kmer_vector(seq, k)
        kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))

        with torch.no_grad():
            output = model(torch.FloatTensor(kmer_vector).to(device))
            probs = torch.softmax(output, dim=1)[0]

        # Index 1 is reported as the human class, index 0 as non-human.
        non_human_prob = float(probs[0])
        human_prob = float(probs[1])
        pred_label = 'human' if human_prob > non_human_prob else 'non-human'
        confidence = max(human_prob, non_human_prob)

        results.append(
            "\n"
            f"Sequence: {header}\n"
            f"Prediction: {pred_label}\n"
            f"Confidence: {confidence:0.4f}\n"
            f"Human probability: {human_prob:0.4f}\n"
            f"Non-human probability: {non_human_prob:0.4f}\n"
        )
    return "\n".join(results)
# Build the Gradio UI: a FASTA file upload feeding predict_sequence, with the
# report shown in a text box.
fasta_upload = gr.File(
    label="Upload FASTA file",
    file_types=[".fasta", ".fa", ".txt"],
)
results_box = gr.Textbox(label="Prediction Results")

iface = gr.Interface(
    fn=predict_sequence,
    inputs=fasta_upload,
    outputs=results_box,
    title="Virus Host Classifier",
    description="Upload a FASTA file to predict whether a virus sequence is likely to infect human or non-human hosts.",
    examples=[["example.fasta"]],
    cache_examples=True,
)

# Start serving the interface.
iface.launch()