File size: 4,729 Bytes
b06adb9 a66b84c b06adb9 ba25241 b06adb9 ba25241 b06adb9 ba25241 b06adb9 a66b84c b06adb9 ba25241 a66b84c b06adb9 ba25241 a66b84c ba25241 b06adb9 a66b84c b06adb9 ba25241 b06adb9 a66b84c b06adb9 a66b84c b06adb9 a66b84c b06adb9 a66b84c b06adb9 a66b84c b06adb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from typing import Optional, Tuple, Union
from abc import abstractmethod
import torch
from torch.nn import BCEWithLogitsLoss
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
from .sensor_loc_reg import SENSOR_LOC_REGISTRY
from .sensor_loc_finder import SensorLocFinder
class MeasurementPredictorMixin(PreTrainedModel):
    """Attach per-sensor and aggregate linear probes to a base transformer model.

    For every input sequence a ``SensorLocFinder`` returns ``n_sensors + 1`` token
    positions. The base model's final hidden state at the first ``n_sensors``
    positions is scored by one ``Linear(emb_dim, 1)`` probe each, and the hidden
    state at the last position is scored by a separate aggregate probe, giving
    ``n_sensors + 1`` logits per example.

    Config attributes read: ``sensor_loc_type``, ``sensor_token``, ``n_sensors``,
    ``emb_dim``, ``sensors_weight``, ``aggregate_weight`` and (in ``forward``)
    ``use_return_dict``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.sensor_loc_type = config.sensor_loc_type
        self.sensor_token = config.sensor_token
        self.n_sensors = config.n_sensors
        # One scalar-output probe per sensor position.
        self.sensor_probes = torch.nn.ModuleList(
            [torch.nn.Linear(config.emb_dim, 1) for _ in range(config.n_sensors)]
        )
        # Separate probe applied to the last located position (the aggregate).
        self.aggregate_probe = torch.nn.Linear(config.emb_dim, 1)
        # Multipliers for the sensor and aggregate BCE terms of the loss.
        self.sensors_weight = config.sensors_weight
        self.aggregate_weight = config.aggregate_weight
        # Created by init_sensor_loc_finder(), which needs a tokenizer; None until then.
        self.find_sensor_locs: Optional[SensorLocFinder] = None

    @abstractmethod
    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase):
        """Subclass hook taking ``tokenizer``; the base implementation does nothing.

        NOTE(review): presumably meant to configure the pad token for the concrete
        model. ``@abstractmethod`` is only enforced if the class's metaclass is
        ``ABCMeta`` — confirm subclasses actually override this.
        """
        pass

    def init_sensor_loc_finder(self, tokenizer: PreTrainedTokenizerBase):
        """Build the sensor-location finder used by ``forward``.

        Looks up ``self.sensor_loc_type`` in ``SENSOR_LOC_REGISTRY`` and
        instantiates it with ``tokenizer``, the configured sensor token and the
        number of sensors. Must be called before ``forward``.
        """
        self.find_sensor_locs = SENSOR_LOC_REGISTRY[self.sensor_loc_type](
            tokenizer, sensor_token=self.sensor_token, n_sensors=self.n_sensors
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Run the base model and score the located sensor positions.

        input_ids (`torch.LongTensor`, batch-first):
            Required despite the `None` default: sensor positions are located from the token ids.
        labels (`torch.Tensor` of shape `(batch_size, n_sensors + 1)`, *optional*):
            Binary targets. Columns `[:n_sensors]` are the per-sensor targets and the last column is
            the aggregate target. Non-floating labels (e.g. `LongTensor` or `bool`) are cast to the
            logits' dtype, as `BCEWithLogitsLoss` requires floating-point targets.

        Returns:
            `SequenceClassifierOutputWithPast`, or a tuple `(loss?, logits, *base_outputs[1:])` when
            `return_dict` is False. `logits` has shape `(batch_size, n_sensors + 1)`: the per-sensor
            logits followed by the aggregate logit. When `labels` is given, `loss` is
            `sensors_weight * BCE(sensor logits) + aggregate_weight * BCE(aggregate logit)`, each BCE
            term mean-reduced.

        Raises:
            RuntimeError: if `init_sensor_loc_finder` has not been called.
            ValueError: if `input_ids` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Fail fast, before the expensive base-model pass.
        if self.find_sensor_locs is None:
            raise RuntimeError(
                "Sensor location finder is not initialized; call init_sensor_loc_finder(tokenizer) first."
            )
        if input_ids is None:
            raise ValueError("input_ids is required to locate sensor tokens.")

        base_model_output = self.base_model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state both for a ModelOutput and for the plain
        # tuple returned when return_dict is False (attribute access breaks the latter).
        last_hidden_state = base_model_output[0]

        # Gather the hidden state at each sensor position (the last one is the aggregate).
        sensor_locs = self.find_sensor_locs(input_ids)
        sensor_embs = last_hidden_state.gather(
            1, sensor_locs.unsqueeze(-1).expand(-1, -1, self.config.emb_dim)
        )
        assert sensor_embs.shape == (input_ids.shape[0], self.n_sensors + 1, self.config.emb_dim), sensor_embs.shape

        # Each sensor probe scores its own position; the aggregate probe scores the last one.
        sensor_logits = torch.concat(
            [probe(sensor_embs[:, i, :]) for i, probe in enumerate(self.sensor_probes)],
            dim=-1,
        )
        aggregate_logits = self.aggregate_probe(sensor_embs[:, -1, :])
        logits = torch.concat([sensor_logits, aggregate_logits], dim=-1)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss rejects integer targets; floating labels are used as-is.
            targets = labels if labels.is_floating_point() else labels.to(logits.dtype)
            loss_fct = BCEWithLogitsLoss()
            sensor_loss = loss_fct(sensor_logits, targets[:, : self.n_sensors])
            aggregate_loss = loss_fct(aggregate_logits, targets[:, -1:])
            loss = sensor_loss * self.sensors_weight + aggregate_loss * self.aggregate_weight

        if not return_dict:
            output = (logits,) + base_model_output[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_model_output.past_key_values,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )
|