TianlaiChen committed on
Commit
302efca
1 Parent(s): a317320
Files changed (1)
  1. app.py +50 -24
app.py CHANGED
@@ -8,43 +8,69 @@ from torch.distributions.categorical import Categorical
 tokenizer = AutoTokenizer.from_pretrained("TianlaiChen/PepMLM-650M")
 model = AutoModelForMaskedLM.from_pretrained("TianlaiChen/PepMLM-650M")
 
 
-def generate_peptide(protein_seq, peptide_length, top_k):
+def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
+    sequence = protein_seq + binder_seq
+    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
+
+    # Create a mask for the binder sequence
+    binder_mask = torch.zeros(tensor_input.shape).to(model.device)
+    binder_mask[0, -len(binder_seq)-1:-1] = 1
+
+    # Mask the binder sequence in the input and create labels
+    masked_input = tensor_input.clone().masked_fill_(binder_mask.bool(), tokenizer.mask_token_id)
+    labels = tensor_input.clone().masked_fill_(~binder_mask.bool(), -100)
+
+    with torch.no_grad():
+        loss = model(masked_input, labels=labels).loss
+    return np.exp(loss.item())
+
+
+def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
 
     peptide_length = int(peptide_length)
     top_k = int(top_k)
+    num_binders = int(num_binders)
 
-    masked_peptide = '<mask>' * peptide_length
-    input_sequence = protein_seq + masked_peptide
-    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
-
-    with torch.no_grad():
-        logits = model(**inputs).logits
-    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
-    logits_at_masks = logits[0, mask_token_indices]
-
-    # Apply top-k sampling
-    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
-    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
-    predicted_indices = Categorical(probabilities).sample()
-    predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
-
-    generated_peptide = tokenizer.decode(predicted_token_ids, skip_special_tokens=True)
-    return generated_peptide.replace(' ', '')
+    binders_with_ppl = []
+
+    for _ in range(num_binders):
+        # Generate binder
+        masked_peptide = '<mask>' * peptide_length
+        input_sequence = protein_seq + masked_peptide
+        inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
+
+        with torch.no_grad():
+            logits = model(**inputs).logits
+        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
+        logits_at_masks = logits[0, mask_token_indices]
+
+        # Apply top-k sampling
+        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
+        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
+        predicted_indices = Categorical(probabilities).sample()
+        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
+
+        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')
+
+        # Compute PPL for the generated binder
+        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
+        binders_with_ppl.append((generated_binder, ppl_value))
+
+    # Formatting the output
+    output = "\n".join([f"Binder: {binder}, PPL: {ppl:.2f}" for binder, ppl in binders_with_ppl])
+    return output
 
 # Define the Gradio interface
 interface = gr.Interface(
     fn=generate_peptide,
     inputs=[
-        gr.Textbox(label="Protein Sequence", info = "Enter protein sequence here", type="text"),
-        gr.Slider(3, 50, value=15, label="Peptide Length", step=1,
-                  info='Default value is 15'),
-        gr.Slider(1, 10, value=3, label="Top K Value", step=1,
-                  info='Default value is 3')
-    ],
-    outputs=gr.outputs.Textbox(label="Binder"),
+        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
+        gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info='Default value is 15'),
+        gr.Slider(1, 10, value=3, label="Top K Value", step=1, info='Default value is 3'),
+        gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=4)
+    ],
+    outputs=gr.outputs.Textbox(label="Binders (with Perplexity)"),
+    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling"
 )
 
-interface.launch(title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling")
+interface.launch()
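
The new compute_pseudo_perplexity masks only the binder positions and returns exp of the mean masked-LM loss over them, i.e. a pseudo-perplexity of the binder conditioned on the target. The slice -len(binder_seq)-1:-1 stops one position short of the end so the trailing special token added by the tokenizer is never masked. A minimal sketch of just that slice arithmetic, with a made-up token layout (one token per residue plus BOS/EOS-style specials; not part of the commit):

import torch

protein_len, binder_len = 6, 3
# Mimic a tokenized row of ids: [BOS] + protein + binder + [EOS]
tensor_input = torch.arange(1 + protein_len + binder_len + 1).unsqueeze(0)

binder_mask = torch.zeros(tensor_input.shape)
binder_mask[0, -binder_len - 1:-1] = 1  # flag binder positions; final special token stays 0

print(binder_mask)
# tensor([[0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.]])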
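
And a hedged usage sketch of the updated entry point, bypassing the Gradio UI; it assumes model, tokenizer, and generate_peptide from app.py are in scope, and the target sequence below is a placeholder, not from the commit:

# Placeholder target sequence, for illustration only
result = generate_peptide(
    protein_seq="MTEYKLVVVGAGGVGKSALTIQLIQ",
    peptide_length=15,
    top_k=3,
    num_binders=4,
)
print(result)  # one "Binder: <sequence>, PPL: <value>" line per requested binder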