File size: 4,802 Bytes
9788df4
 
 
 
 
5e08178
9788df4
 
5e08178
 
 
 
9788df4
 
 
 
5e08178
9788df4
 
 
 
 
 
 
 
 
 
5e08178
 
9788df4
5e08178
 
 
 
9788df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e08178
9788df4
5e08178
9788df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import Optional, Tuple, Union

import torch
from torch.nn import BCEWithLogitsLoss
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast


from .sensor_loc_reg import SENSOR_LOC_REGISTRY
from .sensor_loc_finder import SensorLocFinder

class MeasurementPredictorMixin(PreTrainedModel):
    """Adds per-sensor linear probes (and optionally an aggregate probe) to a base model.

    Each sensor probe maps the base model's final hidden state at one sensor
    position to a single logit. When ``config.use_aggregated`` is set, an extra
    probe maps the hidden state at the last sequence position to one more logit.
    Training uses a weighted sum of binary cross-entropy losses.

    Config fields read: ``sensor_loc_type``, ``sensor_token``, ``n_sensors``,
    ``emb_dim``, ``use_aggregated``, ``sensors_weight``, ``aggregate_weight``
    (and ``use_return_dict`` in ``forward``).

    ``init_sensor_loc_finder`` must be called with a tokenizer before ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.sensor_loc_type = config.sensor_loc_type
        self.sensor_token = config.sensor_token
        self.n_sensors = config.n_sensors
        # One scalar-output probe per sensor position.
        self.sensor_probes = torch.nn.ModuleList([
            torch.nn.Linear(config.emb_dim, 1) for _ in range(config.n_sensors)
        ])
        self.use_aggregated = config.use_aggregated
        if config.use_aggregated:
            # Only created when enabled, so checkpoints without it still load.
            self.aggregate_probe = torch.nn.Linear(config.emb_dim, 1)
        self.sensors_weight = config.sensors_weight
        self.aggregate_weight = config.aggregate_weight

        # Set by init_sensor_loc_finder(); needs a tokenizer, which the config lacks.
        self.get_sensor_locs: Optional[SensorLocFinder] = None

    def init_sensor_loc_finder(self, tokenizer: PreTrainedTokenizerBase):
        """Build the sensor-location finder registered under ``self.sensor_loc_type``.

        Args:
            tokenizer: Tokenizer passed to the registered finder, together with
                ``sensor_token`` and ``n_sensors``.
        """
        self.get_sensor_locs = SENSOR_LOC_REGISTRY[self.sensor_loc_type](
            tokenizer, sensor_token=self.sensor_token, n_sensors=self.n_sensors
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Run the base model and score each sensor position (plus the aggregate, if enabled).

        input_ids (`torch.LongTensor`):
            Required, even though it defaults to `None`. The sensor-location finder
            works on token ids, so `inputs_embeds` alone is not enough.
        labels (`torch.Tensor`, *optional*):
            Binary (or soft) targets. `labels[:, :n_sensors]` are the per-sensor
            targets. When `use_aggregated` is set, `labels[:, -1]` is the aggregate
            target, so labels is presumably `(batch_size, n_sensors + 1)` in that
            case. Non-floating labels are cast to the logits dtype, as required by
            `BCEWithLogitsLoss`.

        Returns:
            A `SequenceClassifierOutputWithPast`, or a tuple if `return_dict` is
            false. `logits` has shape `(batch_size, n_sensors)`, or
            `(batch_size, n_sensors + 1)` with the aggregate logit last.

        Raises:
            ValueError: if `input_ids` is `None`.
            RuntimeError: if `init_sensor_loc_finder` has not been called.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError(
                "input_ids is required: sensor locations are computed from the token ids"
            )
        if self.get_sensor_locs is None:
            raise RuntimeError(
                "sensor location finder is not initialized; "
                "call init_sensor_loc_finder(tokenizer) before forward()"
            )

        base_model_output: Union[BaseModelOutputWithPast, Tuple] = self.base_model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index instead of `.last_hidden_state`: with return_dict=False the base
        # model returns a plain tuple, which has no attribute access.
        last_hidden_state = base_model_output[0]

        # NOTE(review): assumes get_sensor_locs returns int64 indices of shape
        # (batch_size, n_sensors) on the same device as the hidden states; confirm.
        sensor_locs = self.get_sensor_locs(input_ids)
        # For each row, pick the hidden state at each sensor position along dim 1.
        sensor_embs = last_hidden_state.gather(
            1, sensor_locs.unsqueeze(-1).expand(-1, -1, self.config.emb_dim)
        )
        expected_shape = (input_ids.shape[0], self.n_sensors, self.config.emb_dim)
        assert sensor_embs.shape == expected_shape, f"{sensor_embs.shape} != {expected_shape}"

        # One logit per sensor, produced by that sensor's own probe.
        sensor_logits = torch.cat(
            [probe(sensor_embs[:, i, :]) for i, probe in enumerate(self.sensor_probes)],
            dim=-1,
        )
        logits = sensor_logits

        aggregate_logits = None
        if self.use_aggregated:
            # NOTE(review): uses the final sequence position. With right padding
            # this is a pad token for shorter rows; confirm the padding side.
            aggregate_logits = self.aggregate_probe(last_hidden_state[:, -1, :])
            logits = torch.cat([logits, aggregate_logits], dim=-1)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss rejects integer targets; float labels pass through unchanged.
            if not labels.is_floating_point():
                labels = labels.to(sensor_logits.dtype)
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(sensor_logits, labels[:, :self.n_sensors]) * self.sensors_weight
            if self.use_aggregated:
                aggregate_loss = loss_fct(aggregate_logits, labels[:, -1:]) * self.aggregate_weight
                # Out-of-place add: avoids mutating the sensor-loss tensor in place.
                loss = loss + aggregate_loss

        if not return_dict:
            # Mirror HF convention: (loss?, logits, *remaining base-model outputs).
            output = (logits,) + base_model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_model_output.past_key_values,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )