habdine committed
Commit 818ca8a
1 Parent(s): 045f65d

Upload code

Files changed (2):
  1. modeling_prot2text.py +17 -228
  2. utils.py +1 -157
modeling_prot2text.py CHANGED
@@ -1,88 +1,16 @@
  from transformers import GPT2Config, AutoTokenizer, GPT2Config
  from transformers import PretrainedConfig, PreTrainedModel
  import transformers
- from typing import Optional, Tuple, Callable
+ from typing import Optional, Tuple, Callable, List
  import torch
  import torch.nn as nn
  from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
  from .utils import CABlock, _GPT2LMHeadModel
  from .configuration_prot2text import Prot2TextConfig
- import os
- import numpy as np
  from transformers.generation.configuration_utils import GenerationConfig
  from transformers.generation.logits_process import LogitsProcessorList
  from transformers.generation.stopping_criteria import StoppingCriteriaList
 
- from .pdb2graph import PDB2Graph, download_alphafold_structure
- from .graphs import *
- from .utils_dataset import *
-
- try:
-     from graphein.protein.config import ProteinGraphConfig, DSSPConfig
-     from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor
-     from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure
-     from graphein.protein.edges.distance import (add_peptide_bonds,
-                                                  add_hydrogen_bond_interactions,
-                                                  add_distance_threshold,
-                                                  )
- except ImportError:
-     raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')
-
- try:
-     from torch_geometric.nn import RGCNConv, global_mean_pool
- except ImportError:
-     raise Exception('You need to install torch geometric and its dependecies to use this model please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html')
-
-
- class EncoderRGCN(PreTrainedModel):
-     '''
-     This class implement the RGCN encoder to encode the protein structure
-     '''
-     def __init__(self, input_dim, hidden_dim=512, n_layers=6, emb_dim=512, dropout=0.2, num_relation=7, prot2text_version='1.0'):
-         super(EncoderRGCN, self).__init__(PretrainedConfig(name='RGCN'))
-         self.n_layers = n_layers
-         self.output_dim = emb_dim
-         self.prot2text_version = prot2text_version
-
-         self.fc0 = nn.Linear(input_dim, hidden_dim)
-         self.batchnorm_final = nn.BatchNorm1d(hidden_dim)
-
-         self.batch_norms = nn.ModuleList()
-         self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
-         lst = list()
-
-         lst.append(RGCNConv(hidden_dim, hidden_dim, num_relations=num_relation))
-
-         for i in range(n_layers-1):
-             lst.append(RGCNConv(hidden_dim, hidden_dim, num_relations=num_relation))
-
-         self.conv = nn.ModuleList(lst)
-
-         self.fc1 = nn.Linear(hidden_dim, hidden_dim)
-         self.fc2 = nn.Linear(hidden_dim, self.output_dim)
-
-         self.dropout = nn.Dropout(p=dropout)
-         self.relu = nn.LeakyReLU()
-         self.batchnorm = nn.BatchNorm1d(hidden_dim)
-         self.main_input_name = 'nothing'
-
-     def forward(self, x: Optional[torch.FloatTensor] = None,
-                 edge_index: Optional[torch.LongTensor] = None,
-                 edge_type: Optional[torch.LongTensor] = None,
-                 batch: Optional[torch.LongTensor] = None,
-                 **kargs):
-         # construct pyg edge index shape (2, num_edges) from edge_list
-         x = self.relu(self.fc0(x))
-
-         for i in range(self.n_layers):
-             x = self.conv[i](x, edge_index, edge_type)
-
-         out = global_mean_pool(x, batch)
-         out = self.relu(self.fc1(out))
-         out = self.relu(self.fc2(out))
-
-         return out.unsqueeze(1)
 
  class Prot2TextModel(PreTrainedModel):
      config_class = Prot2TextConfig
@@ -92,10 +20,6 @@ class Prot2TextModel(PreTrainedModel):
          super().__init__(config)
 
          self.gpt_config = GPT2Config.from_dict(config.gpt_config)
-
-         # if we are using RGCN to encode the protein's structure, define the RGCN encoder
-         if config.rgcn:
-             self.encoder = EncoderRGCN(input_dim=config.rgcn_input_dim, hidden_dim=self.gpt_config.n_embd, n_layers=config.rgcn_n_layers, emb_dim=self.gpt_config.n_embd, prot2text_version=self.config.prot2text_version)
 
          # define the GPT2 decoder
          self.decoder = _GPT2LMHeadModel(self.gpt_config)
@@ -164,10 +88,6 @@ class Prot2TextModel(PreTrainedModel):
          if decoder_input_ids is not None and len(decoder_input_ids.size()) == 3:
              decoder_input_ids = decoder_input_ids.squeeze(0)
 
-         if x is not None and self.config.rgcn:
-             graph_emb = self.encoder(x, edge_index, edge_type, batch)
-             graph_mask = None
-
          if self.config.esm:
              if self.config.prot2text_version=='1.0':
                  if encoder_input_ids.size()[1] != 1021:
@@ -175,38 +95,7 @@ class Prot2TextModel(PreTrainedModel):
 
              esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=return_dict).last_hidden_state
              esm_emb = self.to_embedding(esm_emb)
-             if not self.config.cross_esm_graph and self.config.rgcn:
-                 graph_emb = torch.cat((graph_emb, esm_emb), dim=1)
-                 t_add = torch.ones((attention_mask.size(0), 1)).to(attention_mask.get_device())
-                 attention_mask = torch.cat((t_add, attention_mask), dim=1)
-             elif self.config.cross_esm_graph and self.config.rgcn:
-                 if past_key_values_graph_esm is None:
-                     past_length = 0
-                     past_key_values_graph_esm = tuple([None] * len(self.h))
-                 else:
-                     past_length = past_key_values_graph_esm[0][0].size(-2)
-                 output_shape = esm_emb.size()
-
-                 all_self_attentions = () if output_attentions else None
-                 all_cross_attentions = () if output_attentions and self.gpt_config.add_cross_attention else None
-                 all_hidden_states = () if output_hidden_states else None
-                 for i, (block, layer_past) in enumerate(zip(self.h, past_key_values_graph_esm)):
-                     outputs = block(
-                         esm_emb,
-                         layer_past=layer_past,
-                         attention_mask=attention_mask,
-                         encoder_hidden_states=graph_emb,
-                         encoder_attention_mask=graph_mask,
-                         use_cache=use_cache,
-                         output_attentions=False,
-                     )
-                     esm_emb = outputs[0]
-
-                 esm_emb = self.ln_f(esm_emb)
-                 esm_emb = esm_emb.view(output_shape)
-                 graph_emb = esm_emb
-             else:
-                 graph_emb = esm_emb
+             graph_emb = esm_emb
          else:
              attention_mask = None
              if self.config.prot2text_version=='1.0':
@@ -234,11 +123,7 @@ class Prot2TextModel(PreTrainedModel):
 
      @torch.no_grad()
      def generate_protein_description(self,
-                                      protein_pdbID=None,
                                       protein_sequence=None,
-                                      edge_index: Optional[torch.LongTensor] = None,
-                                      x: Optional[torch.FloatTensor] = None,
-                                      edge_type: Optional[torch.LongTensor] = None,
                                       tokenizer=None,
                                       device='cpu'
                                       ):
@@ -247,120 +132,24 @@ class Prot2TextModel(PreTrainedModel):
              raise ValueError(
                  "The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
              )
-         if self.config.rgcn and protein_pdbID==None and (x==None or edge_index==None or edge_type==None):
-             raise ValueError(
-                 "The model you are trying to use is based on protein structure, please provide a AlphaFold ID (you must have to have internet connection using protein_pdbID, or provide the triplet inputs: x (node features), edge_index and edge_type"
-             )
          if self.config.esm:
              esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
-
-         if protein_pdbID==None and protein_sequence==None:
-             raise ValueError(
-                 "you need to provide either a protein AlphaFold Id or an amino-acid sequence"
-             )
-
-         if protein_pdbID!=None:
-             config = {"node_metadata_functions": [amino_acid_one_hot,
-                                                   expasy_protein_scale,
-                                                   meiler_embedding,
-                                                   hydrogen_bond_acceptor, hydrogen_bond_donor
-                                                   ],
-                       "edge_construction_functions": [add_peptide_bonds,
-                                                       add_hydrogen_bond_interactions,
-                                                       partial(add_distance_threshold, long_interaction_threshold=3, threshold=10.),],
-                       "graph_metadata_functions":[asa, phi, psi, secondary_structure, rsa],
-                       "dssp_config": DSSPConfig()}
-             config = ProteinGraphConfig(**config)
-
-             PATH_TO_DATA = f"~/.tmp/pdb/pdb"
-             OUTPUT_FOLDER = f"~/.tmp/pdb/raw"
-             save_dir = f"~/.tmp/pdb/"
-             isExist = os.path.exists(PATH_TO_DATA)
-             if not isExist:
-                 os.makedirs(PATH_TO_DATA)
-             isExist = os.path.exists(OUTPUT_FOLDER)
-             if not isExist:
-                 os.makedirs(OUTPUT_FOLDER)
-             isExist = os.path.exists(save_dir+'processed')
-             if not isExist:
-                 os.makedirs(save_dir+'processed')
-
-             structure_filename = download_alphafold_structure(uniprot_id=protein_pdbID, out_dir=PATH_TO_DATA)
-             if structure_filename is None:
-                 raise ValueError("Error! the ID does not exist in AlphaFoldDB or you do not have internet connection")
-             graph_filename = structure_filename.split('/')
-             graph_filename[-2] = 'raw'
-             graph_filename[-1] = graph_filename[-1].replace('.pdb', '.pt')
-             graph_filename = '/'.join(graph_filename)
-             process_filename = structure_filename.split('/')
-             process_filename[-2] = 'processed'
-             process_filename[-1] = process_filename[-1].replace('.pdb', '.pt')
-             process_filename = '/'.join(process_filename)
-             try:
-                 gpdb = PDB2Graph(root = PATH_TO_DATA, output_folder = OUTPUT_FOLDER, config=config, n_processors=1).create_pyg_graph(structure_filename)
-                 seq = esmtokenizer(gpdb.sequence, add_special_tokens=True, truncation=True, max_length=1021, padding='max_length',return_tensors="pt") #
-                 torch.save(gpdb, graph_filename)
-                 gpdb.edge_type = [np.array(gpdb.edge_type.transpose(0,1))]
-                 gpdb.encoder_input_ids = seq['input_ids']
-                 gpdb.attention_mask = seq['attention_mask']
-                 torch.save(gpdb, process_filename)
-             except:
-                 os.remove(structure_filename)
-                 raise ValueError('creating graphs did not work, probably the pdb file of alphaFold is damaged')
-
-             self.eval()
-             inputs = gpdb
-             inputs = inputs.to_dict()
-
-             inputs['edge_type'] = torch.cat([torch.tensor(inputs['edge_type'][i]) for i in range(len(inputs['edge_type']))], dim=0)
-             inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1)
-             for key in ['num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates']:
-                 inputs.pop(key)
-             inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
-             inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id
-             inputs["decoder_attention_mask"] = torch.ones(inputs['decoder_input_ids'].shape[0], 1)
-             self.to(device)
-             inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
-             encoder_state = dict()
-             encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
-             encoder_state['attentions'] = inputs['attention_mask']
-             for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']:
-                 inputs.pop(key)
-             tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
-                                             encoder_outputs=encoder_state,
-                                             use_cache=True,
-                                             output_attentions=False,
-                                             output_scores=False,
-                                             return_dict_in_generate=True,
-                                             encoder_attention_mask=inputs['attention_mask'],
-                                             length_penalty=1.0,
-                                             no_repeat_ngram_size=None,
-                                             early_stopping=False,
-                                             num_beams=1)
-
-             generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
-
-             os.remove(structure_filename)
-             os.remove(graph_filename)
-             os.remove(process_filename)
-
-             return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
-
-         else:
-             seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
-             inputs={}
-             inputs['encoder_input_ids'] = seq['input_ids']
-             inputs['attention_mask'] = seq['attention_mask']
-             inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
-             inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id
-
-             self.to(device)
-             inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
-             encoder_state = dict()
-             encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
-             generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
-
-             return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+
+         seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
+         inputs={}
+         inputs['encoder_input_ids'] = seq['input_ids']
+         inputs['attention_mask'] = seq['attention_mask']
+         inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
+         inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id
+
+         self.to(device)
+         inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
+         encoder_state = dict()
+         encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
+         generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
+
+         return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
 
      @torch.no_grad()
      def generate(self,
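For context, the sequence-only path kept by this commit can be exercised as below. This is a minimal, unofficial sketch: the repository id is a placeholder, and loading the custom architecture through AutoModel/AutoTokenizer with trust_remote_code=True assumes the repo exposes it via an auto_map in its config; only the generate_protein_description(protein_sequence, tokenizer, device) signature comes from the diff above.

    from transformers import AutoModel, AutoTokenizer

    repo_id = "<namespace>/<prot2text-repo>"  # placeholder, not a real repo id

    # Assumption: the repo ships modeling_prot2text.py / utils.py as remote code.
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    # Sequence-only generation, matching the signature kept by this commit.
    description = model.generate_protein_description(
        protein_sequence="AEQAERYEEMVEFMEKL",  # example amino-acid string
        tokenizer=tokenizer,
        device="cpu",
    )
    print(description)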
utils.py CHANGED
@@ -1,8 +1,7 @@
  import torch.nn as nn
  from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
  from typing import Optional, Tuple, Union, Any, Dict, List
- from transformers import Seq2SeqTrainer, GPT2LMHeadModel
- from torch.utils.data.distributed import DistributedSampler
+ from transformers import GPT2LMHeadModel
  import torch
  from transformers.deepspeed import is_deepspeed_zero3_enabled
  from transformers.generation.logits_process import LogitsProcessorList
@@ -10,11 +9,6 @@ from transformers.generation.stopping_criteria import StoppingCriteriaList
  from transformers.generation.utils import GreedySearchOutput, GreedySearchEncoderDecoderOutput, BeamSearchOutput, BeamSearchEncoderDecoderOutput
  from transformers.generation.beam_search import BeamScorer
 
- try:
-     from torch_geometric.loader import DataLoader
-     from torch_geometric.data import Dataset
- except ImportError:
-     raise Exception('You need to install torch geometric and its dependecies to use this model please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html')
 
  class _GPT2LMHeadModel(GPT2LMHeadModel):
      def __init__(self, config):
@@ -593,153 +587,3 @@ class CABlock(nn.Module):
 
          return (hidden_states,)
 
- class Prot2TextTrainer(Seq2SeqTrainer):
-     '''
-     This function is an edited version of the Seq2SeqTrainer from HuggingFace's transformers
-     '''
-     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
-         if self.args.world_size > 1:
-             eval_sampler = DistributedSampler(self.eval_dataset, num_replicas=self.args.world_size, rank=self.args.process_index)
-         else:
-             eval_sampler = None
-         return DataLoader(
-             self.eval_dataset,
-             batch_size=self.args.eval_batch_size,
-             collate_fn=None,
-             num_workers=self.args.dataloader_num_workers,
-             pin_memory=self.args.dataloader_pin_memory,
-             sampler=eval_sampler,
-         )
-     def get_train_dataloader(self) -> DataLoader:
-         if self.args.world_size > 1:
-             train_sampler = DistributedSampler(self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index)
-         else:
-             train_sampler = None
-         return DataLoader(
-             self.train_dataset,
-             batch_size=self.args.per_device_train_batch_size,
-             collate_fn=None,
-             num_workers=self.args.dataloader_num_workers,
-             pin_memory=self.args.dataloader_pin_memory,
-             sampler=train_sampler,
-         )
-     def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
-         """
-         Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
-         handling potential state.
-         """
-         inputs = self._prepare_input(inputs)
-         if len(inputs) == 0:
-             raise ValueError(
-                 "The batch received was empty, your model won't be able to train on it. Double-check that your "
-                 f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
-             )
-         if self.args.past_index >= 0 and self._past is not None:
-             inputs["mems"] = self._past
-
-         inputs = inputs.to_dict()
-         inputs['edge_type'] = torch.cat([torch.tensor(inputs['edge_type'][i]) for i in range(len(inputs['edge_type']))], dim=0)
-         inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1)
-         inputs = {k: v.to(device=self.args.device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
-         return inputs
-
-     def prediction_step(
-         self,
-         model: nn.Module,
-         inputs: Dict[str, Union[torch.Tensor, Any]],
-         prediction_loss_only: bool,
-         ignore_keys: Optional[List[str]] = None,
-     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
-         """
-         Perform an evaluation step on `model` using `inputs`.
-
-         Subclass and override to inject custom behavior.
-
-         Args:
-             model (`nn.Module`):
-                 The model to evaluate.
-             inputs (`Dict[str, Union[torch.Tensor, Any]]`):
-                 The inputs and targets of the model.
-
-                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
-                 argument `labels`. Check your model's documentation for all accepted arguments.
-             prediction_loss_only (`bool`):
-                 Whether or not to return the loss only.
-
-         Return:
-             Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
-             labels (each being optional).
-         """
-
-         if not self.args.predict_with_generate or prediction_loss_only:
-             return super().prediction_step(
-                 model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
-             )
-
-         has_labels = "labels" in inputs
-         inputs = self._prepare_inputs(inputs)
-
-         # XXX: adapt synced_gpus for fairscale as well
-         gen_kwargs = self._gen_kwargs.copy()
-         if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
-             gen_kwargs["max_length"] = self.model.config.max_length
-         gen_kwargs["num_beams"] = (
-             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
-         )
-         default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
-         gen_kwargs["synced_gpus"] = (
-             gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
-         )
-
-         if "attention_mask" in inputs:
-             gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
-         if "global_attention_mask" in inputs:
-             gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
-
-         generation_inputs = None
-         gen_kwargs['x'] = inputs.get('x', None)
-         gen_kwargs['edge_index'] = inputs.get('edge_index', None)
-         gen_kwargs['edge_type'] = inputs.get('edge_type', None)
-         gen_kwargs['batch'] = inputs.get('batch', None)
-         gen_kwargs['encoder_input_ids'] = inputs.get('encoder_input_ids', None)
-         gen_kwargs['decoder_input_ids'] = inputs.get('decoder_input_ids', None)[:,0:1]
-         gen_kwargs["decoder_attention_mask"] = torch.ones(gen_kwargs['decoder_input_ids'].shape[0], 1).to(self.args.device)
-
-         generated_tokens = self.model.generate(
-             generation_inputs,
-             **gen_kwargs,
-         )
-         # in case the batch is shorter than max length, the output should be padded
-         if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
-             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
-         elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
-             gen_kwargs["max_new_tokens"] + 1
-         ):
-             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
-
-         with torch.no_grad():
-             if has_labels:
-                 with self.compute_loss_context_manager():
-                     outputs = model(**inputs)
-                 if self.label_smoother is not None:
-                     loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
-                 else:
-                     loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
-             else:
-                 loss = None
-
-         if self.args.prediction_loss_only:
-             return (loss, None, None)
-
-         if has_labels:
-             labels = inputs["labels"]
-             if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
-                 labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
-             elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
-                 gen_kwargs["max_new_tokens"] + 1
-             ):
-                 labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
-         else:
-             labels = None
-
-         return (loss, generated_tokens, labels)