IACC-ranker-small / modeling_rankingprompter.py
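"""RankingPrompter pre-training model.

A UMT5 encoder-decoder with a linear ranking head: the encoder embeds a batch
of candidate documents, the decoder reads the question while cross-attending
to each document, and the mean-pooled decoder states are scored so the
candidates can be ranked against the question.
"""
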
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import UMT5Model

from .configuration_rankingprompter import RankingPrompterConfig


@dataclass
class RankingPrompterForPreTrainingOutput:
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class RankingPrompterForPreTraining(UMT5Model):
    config_class = RankingPrompterConfig
    _tied_weights_keys = [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, config):
        # encoder, decoder and shared embedding are inherited from UMT5Model
        super().__init__(config)
        # add ranking head: one relevance score per pooled decoder state
        self.ranking_head = nn.Linear(config.d_model, 1)
        # Initialize weights and apply final processing
        self.post_init()
        # autocast context for mixed precision training (no-op by default)
        self.ctx = nullcontext()

    def enable_amp_ctx(self, device_type="cuda", dtype=torch.bfloat16):
        # wrap encoder/decoder calls in torch.amp autocast
        self.ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)

    def disable_amp_ctx(self):
        # back to full-precision execution
        self.ctx = nullcontext()

    def forward(
        self,
        document_input_ids: Optional[torch.LongTensor] = None,
        document_attention_mask: Optional[torch.FloatTensor] = None,
        question_input_ids: Optional[torch.LongTensor] = None,
        question_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], RankingPrompterForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the document-ranking loss. Each label is the
            index of the relevant document among the `num_doc` candidates,
            i.e. in `[0, ..., num_doc - 1]`. Labels set to `-100` are ignored
            (masked).

        Returns:
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # document_input_ids: [batch_size, num_doc, doc_seq_len]
        batch_size, num_doc, doc_seq_len = document_input_ids.shape

        # flatten the documents to [batch_size * num_doc, doc_seq_len]
        document_input_ids = document_input_ids.view(-1, doc_seq_len)
        document_attention_mask = document_attention_mask.view(-1, doc_seq_len)

        # encode all documents in a single pass
        with self.ctx:
            encoder_outputs = self.encoder(
                input_ids=document_input_ids,
                attention_mask=document_attention_mask,
                return_dict=return_dict,
            )
        document_embeds = encoder_outputs[0]
        # repeat question inputs for each document
        # question_input_ids: [batch_size, question_seq_len]
        question_seq_len = question_input_ids.shape[1]
        question_input_ids = (
            question_input_ids.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )  # [batch_size * num_doc, question_seq_len]
        question_attention_mask = (
            question_attention_mask.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )  # [batch_size * num_doc, question_seq_len]

        # Decode: the question attends to the document encodings
        with self.ctx:
            decoder_outputs = self.decoder(
                input_ids=question_input_ids,
                attention_mask=question_attention_mask,
                past_key_values=past_key_values,
                encoder_hidden_states=document_embeds,
                encoder_attention_mask=document_attention_mask,
                use_cache=use_cache,
                return_dict=return_dict,
            )
        # decoder hidden states: [batch_size * num_doc, question_seq_len, hidden_size]
        sequence_output = decoder_outputs[0]
        question_seq_len = sequence_output.size(1)
        # regroup per question: [batch_size, num_doc, question_seq_len, hidden_size]
        question_output = sequence_output.view(
            batch_size, num_doc, question_seq_len, -1
        )
        # mean-pool over question tokens, then score each document:
        # [batch_size, num_doc, hidden_size] -> [batch_size, num_doc]
        ranking_logits = self.ranking_head(question_output.mean(dim=2)).squeeze(-1)

        # ranking loss: cross-entropy over the num_doc candidates per question
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(ranking_logits, labels)
        if not return_dict:
            output = (ranking_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return RankingPrompterForPreTrainingOutput(
            loss=loss,
            logits=ranking_logits,
        )
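

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the uploaded checkpoint). It shows
# how the forward pass above is shaped. The concrete sizes are made up, and
# it assumes RankingPrompterConfig() exposes UMT5-style defaults such as
# `vocab_size` and `d_model`.
#
#   config = RankingPrompterConfig()
#   model = RankingPrompterForPreTraining(config)
#   model.enable_amp_ctx(device_type="cuda", dtype=torch.bfloat16)  # optional
#
#   batch_size, num_doc, doc_seq_len, q_seq_len = 2, 4, 128, 32
#   document_input_ids = torch.randint(0, config.vocab_size, (batch_size, num_doc, doc_seq_len))
#   document_attention_mask = torch.ones(batch_size, num_doc, doc_seq_len, dtype=torch.long)
#   question_input_ids = torch.randint(0, config.vocab_size, (batch_size, q_seq_len))
#   question_attention_mask = torch.ones(batch_size, q_seq_len, dtype=torch.long)
#   labels = torch.zeros(batch_size, dtype=torch.long)  # index of the relevant document
#
#   outputs = model(
#       document_input_ids=document_input_ids,
#       document_attention_mask=document_attention_mask,
#       question_input_ids=question_input_ids,
#       question_attention_mask=question_attention_mask,
#       labels=labels,
#   )
#   outputs.logits  # [batch_size, num_doc] relevance scores per question
#   outputs.loss    # cross-entropy over the num_doc candidates
# ---------------------------------------------------------------------------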