|
from typing import Optional, Tuple, Union |
|
from abc import abstractmethod |
|
|
|
import torch |
|
from torch.nn import BCEWithLogitsLoss |
|
from transformers import PreTrainedModel, PreTrainedTokenizer |
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase |
|
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast |
|
|
|
|
|
from .sensor_loc_reg import SENSOR_LOC_REGISTRY |
|
from .sensor_loc_finder import SensorLocFinder |
|
|
|
class MeasurementPredictorMixin(PreTrainedModel):
    """Adds per-sensor and aggregate linear probes on top of a base model.

    One ``Linear(emb_dim, 1)`` probe is created per sensor, plus one extra
    probe for the aggregate prediction. In ``forward`` the hidden state at
    each location returned by ``self.find_sensor_locs`` is fed to the
    matching probe; the last location is fed to the aggregate probe.

    ``init_sensor_loc_finder`` must be called with a tokenizer before the
    first ``forward`` call, because the location finder is not built in
    ``__init__``.
    """

    def __init__(self, config):
        """Build the probes and copy the sensor settings from ``config``.

        Reads ``sensor_loc_type``, ``sensor_token``, ``n_sensors``,
        ``emb_dim``, ``sensors_weight`` and ``aggregate_weight`` from
        ``config``.
        """
        super().__init__(config)
        self.sensor_loc_type = config.sensor_loc_type
        self.sensor_token = config.sensor_token
        self.n_sensors = config.n_sensors
        # One scalar-output probe per sensor.
        self.sensor_probes = torch.nn.ModuleList(
            [torch.nn.Linear(config.emb_dim, 1) for _ in range(config.n_sensors)]
        )
        # Separate probe for the aggregate prediction.
        self.aggregate_probe = torch.nn.Linear(config.emb_dim, 1)
        # Multipliers applied to the two loss terms in ``forward``.
        self.sensors_weight = config.sensors_weight
        self.aggregate_weight = config.aggregate_weight

        # Populated later by ``init_sensor_loc_finder`` (needs a tokenizer).
        self.find_sensor_locs: Optional[SensorLocFinder] = None

    @abstractmethod
    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase):
        """Hook for subclasses to configure the tokenizer's pad token.

        The base implementation does nothing.

        NOTE(review): ``@abstractmethod`` is only enforced when the class's
        metaclass derives from ``ABCMeta`` -- confirm whether that holds
        here, otherwise subclasses are not forced to override this.
        """
        pass

    def init_sensor_loc_finder(self, tokenizer: PreTrainedTokenizerBase):
        """Instantiate the sensor-location finder selected by the config.

        Looks up ``self.sensor_loc_type`` in ``SENSOR_LOC_REGISTRY`` and
        stores the constructed finder on ``self.find_sensor_locs``.
        """
        finder_cls = SENSOR_LOC_REGISTRY[self.sensor_loc_type]
        self.find_sensor_locs = finder_cls(
            tokenizer, sensor_token=self.sensor_token, n_sensors=self.n_sensors
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Run the base model and apply the sensor / aggregate probes.

        labels (`torch.Tensor` of shape `(batch_size, n_sensors + 1)`, *optional*):
            Binary targets for the probes. The first `n_sensors` columns are the per-sensor targets and the
            last column is the aggregate target. Integer labels are cast to the logits' dtype before the
            binary cross-entropy loss is computed.

        Returns a `SequenceClassifierOutputWithPast` whose `logits` have shape `(batch_size, n_sensors + 1)`
        (sensor logits followed by the aggregate logit), or the equivalent tuple when `return_dict` is false.

        Raises `RuntimeError` if `init_sensor_loc_finder` has not been called, and `ValueError` if
        `input_ids` is `None` (sensor locations are derived from the token ids).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.find_sensor_locs is None:
            raise RuntimeError(
                "Sensor location finder is not initialized; "
                "call init_sensor_loc_finder(tokenizer) before forward()."
            )
        if input_ids is None:
            raise ValueError("input_ids is required to locate the sensor positions.")

        # Always request a ModelOutput: the attribute accesses below
        # (last_hidden_state, past_key_values, ...) do not work on the plain
        # tuple the base model returns when return_dict is false.
        base_model_output: BaseModelOutputWithPast = self.base_model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Pick, for every sequence, the hidden state at each sensor location.
        sensor_locs = self.find_sensor_locs(input_ids)
        gather_index = sensor_locs.unsqueeze(-1).expand(-1, -1, self.config.emb_dim)
        sensor_embs = base_model_output.last_hidden_state.gather(1, gather_index)
        # n_sensors sensor positions plus one aggregate position are expected.
        expected_shape = (input_ids.shape[0], self.n_sensors + 1, self.config.emb_dim)
        assert sensor_embs.shape == expected_shape, sensor_embs.shape

        # Each probe yields (batch, 1); concatenating gives (batch, n_sensors).
        sensor_logits = torch.cat(
            [probe(sensor_embs[:, i, :]) for i, probe in enumerate(self.sensor_probes)],
            dim=-1,
        )
        # The last gathered position feeds the aggregate probe.
        aggregate_logits = self.aggregate_probe(sensor_embs[:, -1, :])
        logits = torch.cat([sensor_logits, aggregate_logits], dim=-1)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets; integer labels
            # are converted, floating-point labels are left untouched.
            if not labels.is_floating_point():
                labels = labels.to(logits.dtype)
            loss_fct = BCEWithLogitsLoss()
            sensor_loss = loss_fct(sensor_logits, labels[:, :self.n_sensors]) * self.sensors_weight
            aggregate_loss = loss_fct(aggregate_logits, labels[:, -1:]) * self.aggregate_weight
            loss = sensor_loss + aggregate_loss

        if not return_dict:
            # Drop the base model's first element (the last hidden state) and
            # put the probe logits in its place.
            output = (logits,) + base_model_output.to_tuple()[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_model_output.past_key_values,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )
|
|
|
|