# Esm2Text-Base-v1-1 / modeling_prot2text.py
# Uploaded by habdine ("Upload code", commit 818ca8a); original file size 10.1 kB.
from transformers import GPT2Config, AutoTokenizer, GPT2Config
from transformers import PretrainedConfig, PreTrainedModel
import transformers
from typing import Optional, Tuple, Callable, List
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
from .utils import CABlock, _GPT2LMHeadModel
from .configuration_prot2text import Prot2TextConfig
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
class Prot2TextModel(PreTrainedModel):
    """Protein-to-text model: an (optional) ESM encoder feeding a GPT-2 decoder.

    When ``config.esm`` is set, the amino-acid sequence is encoded by an
    ``EsmModel``. Its last hidden state is projected to the GPT-2 embedding
    width by ``to_embedding`` and passed to the decoder as
    ``encoder_hidden_states``, which the decoder uses for cross-attention.
    """

    config_class = Prot2TextConfig
    # Regex patterns of checkpoint weights whose absence is tolerated on load
    # (Hugging Face ``PreTrainedModel`` convention).
    _keys_to_ignore_on_load_missing = [r"transformer"]
    base_model_prefix = "decoder"

    # Keys that generate() strips from its kwargs before calling the
    # decoder's generate(). They are encoder inputs or data-loader/graph-batch
    # fields. Note that 'max_length' is also stripped, so a max_length passed
    # to generate() does not reach the decoder's generate().
    _GENERATE_IGNORED_KWARGS = (
        'edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids',
        'decoder_attention_mask', 'batch', 'attention_mask', 'max_length',
        '_num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix',
        'distance', 'coordinates', 'ptr', 'num_nodes',
    )

    def __init__(self, config):
        super().__init__(config)
        self.gpt_config = GPT2Config.from_dict(config.gpt_config)
        # GPT-2 language-model decoder that produces the description.
        self.decoder = _GPT2LMHeadModel(self.gpt_config)
        # When the protein sequence is encoded with ESM, build the ESM encoder
        # and the projection from the ESM hidden size to the GPT-2 width.
        if config.esm:
            self.esm_config = PretrainedConfig.from_dict(config.esm_config)
            self.esm = transformers.EsmModel(self.esm_config)
            self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd)
            # Fusion layers for the ESM + graph (RGCN) variant. forward() in
            # this file does not use them. They are presumably kept so that
            # checkpoints of that variant load cleanly (TODO confirm).
            if config.cross_esm_graph and config.rgcn:
                self.h = nn.ModuleList([CABlock(self.gpt_config, layer_idx=i) for i in range(4)])
                self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon)
        self.config = config

    def get_encoder(self):
        """Return the ESM encoder, or ``None`` if the model was built without one.

        (The previous implementation returned ``self.encoder``, an attribute
        that is never assigned, so the call always raised ``AttributeError``.)
        """
        return getattr(self, "esm", None)

    def get_decoder(self):
        """Return the GPT-2 language-model decoder."""
        return self.decoder

    def get_input_embeddings(self):
        """Return the token-embedding module (``wte``) used for decoder input ids."""
        if hasattr(self, "transformer"):
            return self.transformer.wte
        return self.decoder.transformer.wte

    def warm_up(self, gpt_model=None, esm_model=None):
        """Replace randomly initialised sub-models with pretrained weights.

        Args:
            gpt_model: name/path passed to ``_GPT2LMHeadModel.from_pretrained``.
                The decoder is loaded with cross-attention enabled and the cache
                disabled. Its token embeddings are resized to
                ``gpt_config.vocab_size``, and its config is replaced by
                ``self.gpt_config``.
            esm_model: name/path passed to ``EsmModel.from_pretrained``.
        """
        if esm_model is not None:
            self.esm = transformers.EsmModel.from_pretrained(esm_model)
        if gpt_model is not None:
            self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False)
            self.decoder.resize_token_embeddings(self.gpt_config.vocab_size)
            self.decoder.config = self.gpt_config

    def forward(self,
                encoder_input_ids: Optional[torch.LongTensor] = None,
                edge_index: Optional[torch.LongTensor] = None,
                batch: Optional[torch.LongTensor] = None,
                x: Optional[torch.FloatTensor] = None,
                edge_type: Optional[torch.LongTensor] = None,
                decoder_input_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                decoder_attention_mask: Optional[torch.FloatTensor] = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                token_type_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.Tensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                get_graph_emb: Optional[bool] = False,
                **delete_args,
                ):
        """Encode the protein, then run the GPT-2 decoder with cross-attention.

        Args:
            encoder_input_ids: ESM token ids of the protein sequence. Required
                when ``config.esm`` is set. For ``prot2text_version == '1.0'``,
                ``encoder_input_ids.size(1)`` must be exactly 1021.
            attention_mask: ESM attention mask. It is also passed to the decoder
                as the cross-attention mask, except for version '1.0'.
            decoder_input_ids: decoder token ids. A 3-D tensor has dim 0
                squeezed.
            encoder_hidden_states, encoder_attention_mask: cross-attention
                inputs, used only when ``config.esm`` is not set.
            get_graph_emb: if True, return the encoder embedding (the projected
                ESM states) without running the decoder.
            edge_index, batch, x, edge_type, past_key_values_graph_esm:
                accepted for interface compatibility. They are not used here.
            **delete_args: any extra keyword arguments are accepted and ignored.

        Returns:
            The encoder embedding if ``get_graph_emb`` is True; otherwise
            whatever ``self.decoder(...)`` returns.

        Raises:
            ValueError: if ESM is enabled and ``encoder_input_ids`` is missing,
                or if it has the wrong length for version '1.0'.
        """
        use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache
        return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict
        if decoder_input_ids is not None and decoder_input_ids.dim() == 3:
            decoder_input_ids = decoder_input_ids.squeeze(0)

        if self.config.esm:
            if encoder_input_ids is None:
                raise ValueError("encoder_input_ids is required when the model uses an ESM encoder")
            if self.config.prot2text_version == '1.0' and encoder_input_ids.size(1) != 1021:
                raise ValueError("For this version of the model you need to PAD/Truncate the amino acid sequence for the ESM model to 1021")
            # Use [0] rather than .last_hidden_state. It works for both a
            # ModelOutput (return_dict=True) and a plain tuple (return_dict=False).
            esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=return_dict)[0]
            # The name "graph_emb" is historical; here it holds the projected
            # ESM states that the decoder cross-attends to.
            graph_emb = self.to_embedding(esm_emb)
        else:
            # No encoder in this model: fall back to caller-supplied
            # cross-attention inputs. Previously graph_emb was left unbound here.
            graph_emb = encoder_hidden_states
            attention_mask = encoder_attention_mask

        # Version 1.0 does not mask the cross-attention.
        if self.config.prot2text_version == '1.0':
            attention_mask = None

        if get_graph_emb:
            return graph_emb

        return self.decoder(input_ids=decoder_input_ids,
                            past_key_values=past_key_values,
                            attention_mask=decoder_attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            encoder_hidden_states=graph_emb,
                            encoder_attention_mask=attention_mask,
                            labels=labels,
                            use_cache=use_cache,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict,
                            )

    @torch.no_grad()
    def generate_protein_description(self,
                                     protein_sequence=None,
                                     tokenizer=None,
                                     device='cpu'
                                     ):
        """Generate a natural-language description of one amino-acid sequence.

        Args:
            protein_sequence: amino-acid string to describe.
            tokenizer: tokenizer of the GPT-2 decoder. Its ``bos_token_id``
                seeds generation, and it decodes the generated ids.
            device: device the inputs are moved to. The model itself is moved
                there in place via ``self.to(device)``.

        Returns:
            The first decoded description, with ``<|stop_token|>`` and
            ``<|graph_token|>`` removed.

        Raises:
            ValueError: if the model has no ESM encoder, the sequence is
                missing, or no tokenizer is given.
        """
        if self.config.esm and not self.config.rgcn and protein_sequence is None:
            raise ValueError(
                "The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
            )
        if not self.config.esm:
            raise ValueError("generate_protein_description requires a model built with an ESM encoder (config.esm=True)")
        if protein_sequence is None:
            raise ValueError("protein_sequence is required: this implementation only encodes amino-acid sequences")
        if tokenizer is None:
            raise ValueError("a decoder tokenizer is required to build the prompt and decode the output")

        # The ESM tokenizer is reloaded on each call.
        esm_tokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
        seq = esm_tokenizer([protein_sequence], add_special_tokens=True, truncation=True,
                            max_length=1021, padding='max_length', return_tensors="pt")
        encoder_input_ids = seq['input_ids']
        # Decoder prompt: a (batch, 1) tensor holding only the decoder BOS id.
        decoder_input_ids = encoder_input_ids[:, 0:1].clone()
        decoder_input_ids[:, 0] = tokenizer.bos_token_id
        inputs = {
            'encoder_input_ids': encoder_input_ids,
            'attention_mask': seq['attention_mask'],
            'decoder_input_ids': decoder_input_ids,
        }

        self.to(device)
        inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
        encoder_state = {'hidden_states': self(**inputs, get_graph_emb=True, output_attentions=True)}
        # NOTE(review): unlike generate(), no encoder_attention_mask is passed
        # here, so the decoder may attend to ESM padding positions. Confirm
        # this is intended for non-'1.0' model versions.
        generated_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
                                              encoder_outputs=encoder_state, use_cache=True)
        generated = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')

    @torch.no_grad()
    def generate(self,
                 inputs: Optional[torch.Tensor] = None,
                 generation_config: Optional[GenerationConfig] = None,
                 logits_processor: Optional[LogitsProcessorList] = None,
                 stopping_criteria: Optional[StoppingCriteriaList] = None,
                 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
                 synced_gpus: Optional[bool] = None,
                 assistant_model: Optional["PreTrainedModel"] = None,
                 streamer: Optional["BaseStreamer"] = None,
                 **kwargs,
                 ):
        """Encode the protein once, then delegate to ``self.decoder.generate``.

        ``kwargs`` must contain ``decoder_input_ids`` (the decoder prompt). It
        must also contain whatever ``forward`` needs to build the encoder
        embedding, e.g. ``encoder_input_ids`` and ``attention_mask``.
        ``attention_mask`` becomes the decoder's ``encoder_attention_mask``.
        Encoder-side and data-loader keys (``_GENERATE_IGNORED_KWARGS``) are
        removed, and the remaining kwargs are forwarded to the decoder's
        ``generate``. ``inputs`` is accepted for signature compatibility but
        is not used.
        """
        encoder_state = self(**kwargs, get_graph_emb=True)
        input_ids = kwargs['decoder_input_ids']

        encoder_attention_mask = kwargs.get('attention_mask')
        # NOTE(review): forward() drops the cross-attention mask when
        # prot2text_version == '1.0', but it is forwarded here unconditionally.
        # Confirm this is intended.
        if (encoder_attention_mask is not None and not self.config.cross_esm_graph
                and self.config.rgcn and self.config.esm):
            # Prepend one always-attended encoder position. This is presumably
            # an extra leading (graph) position in front of the ESM states
            # (TODO confirm). It is built directly on the mask's device and
            # dtype: Tensor.get_device() returns -1 for CPU tensors, which
            # .to() rejects.
            leading = torch.ones((encoder_attention_mask.size(0), 1),
                                 dtype=encoder_attention_mask.dtype,
                                 device=encoder_attention_mask.device)
            encoder_attention_mask = torch.cat((leading, encoder_attention_mask), dim=1)
        kwargs['encoder_attention_mask'] = encoder_attention_mask

        for key in self._GENERATE_IGNORED_KWARGS:
            kwargs.pop(key, None)

        return self.decoder.generate(input_ids=input_ids,
                                     generation_config=generation_config,
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                     synced_gpus=synced_gpus,
                                     assistant_model=assistant_model,
                                     streamer=streamer,
                                     encoder_outputs={'hidden_states': encoder_state, 'attentions': 0},
                                     **kwargs
                                     )