from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import UMT5Model

from .configuration_rankingprompter import RankingPrompterConfig


@dataclass
class RankingPrompterForPreTrainingOutput:
    """Ranking loss and per-document ranking logits of shape `[batch_size, num_doc]`."""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class RankingPrompterForPreTraining(UMT5Model):
    config_class = RankingPrompterConfig

    # encoder and decoder token embeddings are tied to the shared embedding table
    _tied_weights_keys = [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, config):
        super().__init__(config)

        # score each (document, question) pair with a single logit
        self.ranking_head = nn.Linear(config.d_model, 1)

        # initialize weights and apply final processing
        self.post_init()

        # optional mixed-precision context; a no-op until enable_amp_ctx() is called
        self.ctx = nullcontext()

    def enable_amp_ctx(self, device_type="cuda", dtype=torch.bfloat16):
        self.ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)

    def disable_amp_ctx(self):
        self.ctx = nullcontext()

    def forward(
        self,
        document_input_ids: Optional[torch.LongTensor] = None,
        document_attention_mask: Optional[torch.FloatTensor] = None,
        question_input_ids: Optional[torch.LongTensor] = None,
        question_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], RankingPrompterForPreTrainingOutput]:
r""" |
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
|
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., |
|
config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for |
|
labels in `[0, ..., config.vocab_size]` |
|
|
|
Returns: |
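
        Example (an illustrative sketch; the checkpoint path below is a placeholder and the
        tokenizer is assumed to be compatible with this model):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer

        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
        >>> model = RankingPrompterForPreTraining.from_pretrained("path/to/checkpoint")

        >>> documents = [
        ...     "Paris is the capital of France.",
        ...     "Berlin is the capital of Germany.",
        ... ]
        >>> question = ["What is the capital of France?"]

        >>> doc_enc = tokenizer(documents, padding=True, return_tensors="pt")
        >>> q_enc = tokenizer(question, padding=True, return_tensors="pt")

        >>> outputs = model(
        ...     document_input_ids=doc_enc.input_ids.unsqueeze(0),  # [batch, num_doc, doc_seq_len]
        ...     document_attention_mask=doc_enc.attention_mask.unsqueeze(0),
        ...     question_input_ids=q_enc.input_ids,  # [batch, question_seq_len]
        ...     question_attention_mask=q_enc.attention_mask,
        ...     labels=torch.tensor([0]),  # the first document is the relevant one
        ... )
        >>> outputs.logits.shape
        torch.Size([1, 2])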
```""" |
|
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

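        # flatten documents: [batch_size, num_doc, doc_seq_len] -> [batch_size * num_doc, doc_seq_len]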
        batch_size, num_doc, doc_seq_len = document_input_ids.shape
        document_input_ids = document_input_ids.view(-1, doc_seq_len)
        document_attention_mask = document_attention_mask.view(-1, doc_seq_len)

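        # encode every document independently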
        with self.ctx:
            encoder_outputs = self.encoder(
                input_ids=document_input_ids,
                attention_mask=document_attention_mask,
                return_dict=return_dict,
            )

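        # document_embeds: [batch_size * num_doc, doc_seq_len, d_model]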
        document_embeds = encoder_outputs[0]

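        # repeat each question once per document so that every (document, question)
        # pair is scored: [batch_size, q_seq_len] -> [batch_size * num_doc, q_seq_len]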
        question_seq_len = question_input_ids.shape[1]
        question_input_ids = (
            question_input_ids.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )
        question_attention_mask = (
            question_attention_mask.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )

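        # decode the question while cross-attending to the document encodings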
        with self.ctx:
            decoder_outputs = self.decoder(
                input_ids=question_input_ids,
                attention_mask=question_attention_mask,
                past_key_values=past_key_values,
                encoder_hidden_states=document_embeds,
                encoder_attention_mask=document_attention_mask,
                use_cache=use_cache,
                return_dict=return_dict,
            )

        sequence_output = decoder_outputs[0]

        question_seq_len = sequence_output.size(1)
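        # regroup per document: [batch_size, num_doc, question_seq_len, d_model]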
        soft_prompt_output = sequence_output.view(
            batch_size, num_doc, question_seq_len, -1
        )

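        # mean-pool over question tokens, then score each (document, question) pair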
        ranking_logits = self.ranking_head(soft_prompt_output.mean(dim=2))
        # collapse the trailing singleton dim so logits are always [batch_size, num_doc],
        # whether or not labels are provided
        ranking_logits = ranking_logits.view(batch_size, num_doc)

        loss = None
        if labels is not None:
            # labels hold the index of the relevant document for each question
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(ranking_logits, labels)

        if not return_dict:
            output = (ranking_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return RankingPrompterForPreTrainingOutput(
            loss=loss,
            logits=ranking_logits,
        )