import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch.distributions.categorical import Categorical
import numpy as np
import pandas as pd
# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("ChatterjeeLab/PepMLM-650M")
model = AutoModelForMaskedLM.from_pretrained("ChatterjeeLab/PepMLM-650M")
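# Note: for inference it is common to disable dropout and move the model to a GPU
# when one is available, e.g.:
#     model.eval()
#     if torch.cuda.is_available():
#         model = model.cuda()
# (These lines are an assumption, not part of the original Space code; the
# functions below resolve placement through model.device either way.)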
def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)

    total_loss = 0
    # Loop through each token in the binder sequence. The ESM tokenizer maps each
    # residue to one token and appends <eos>, so the binder occupies the negative
    # positions -len(binder_seq)-1 through -2 of the encoded input.
    for i in range(-len(binder_seq) - 1, -1):
        # Create a copy of the original tensor
        masked_input = tensor_input.clone()

        # Mask one token at a time
        masked_input[0, i] = tokenizer.mask_token_id

        # Create labels: -100 is ignored by the loss at every position except the masked one
        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        # Get model prediction and loss
        with torch.no_grad():
            outputs = model(masked_input, labels=labels)
        total_loss += outputs.loss.item()

    # Calculate the average loss over the binder tokens
    avg_loss = total_loss / len(binder_seq)

    # Calculate pseudo-perplexity
    pseudo_perplexity = np.exp(avg_loss)
    return pseudo_perplexity
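# Example usage (hypothetical sequences, for illustration only):
#     ppl = compute_pseudo_perplexity(model, tokenizer,
#                                     protein_seq="MKTVRQERLKSIVRILERSKEPVSGAQL",
#                                     binder_seq="FLPVLAG")
# Lower pseudo-perplexity means the binder is more plausible under the model
# conditioned on the target sequence.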
def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    binders_with_ppl = []
    for _ in range(num_binders):
        # Generate binder: append one <mask> per desired residue to the target sequence
        masked_peptide = '<mask>' * peptide_length
        input_sequence = protein_seq + masked_peptide
        inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]

        # Apply top-k sampling independently at each masked position
        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Compute pseudo-perplexity for the generated binder
        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)

        # Add the generated binder and its PPL to the results list
        binders_with_ppl.append([generated_binder, ppl_value])

    # Convert the list of lists to a pandas dataframe
    df = pd.DataFrame(binders_with_ppl, columns=["Binder", "Perplexity"])

    # Save the dataframe to a CSV file
    output_filename = "output.csv"
    df.to_csv(output_filename, index=False)

    return binders_with_ppl, output_filename
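# Example usage (hypothetical target sequence, for illustration only):
#     results, csv_path = generate_peptide(
#         "MKTVRQERLKSIVRILERSKEPVSGAQL", peptide_length=15, top_k=3, num_binders=4)
# `results` is a list of [binder, pseudo-perplexity] pairs and `csv_path`
# points to the saved output.csv.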
# Define the Gradio interface
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
        gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info='Default value is 15'),
        gr.Slider(1, 10, value=3, label="Top K Value", step=1, info='Default value is 3'),
        gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=1)
    ],
    outputs=[
        gr.Dataframe(
            headers=["Binder", "Perplexity"],
            datatype=["str", "number"],
            col_count=(2, "fixed")
        ),
        # gr.outputs.File was removed in Gradio 4; gr.File is the current component
        # and a likely fix for the Space's runtime error.
        gr.File(label="Download CSV")
    ],
    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling"
)

interface.launch()