from typing import Optional, Tuple, Union
from abc import abstractmethod

import torch
from torch.nn import BCEWithLogitsLoss
from transformers import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast


from .sensor_loc_reg import SENSOR_LOC_REGISTRY
from .sensor_loc_finder import SensorLocFinder

class MeasurementPredictorMixin(PreTrainedModel):
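    """Mixin that adds binary measurement-prediction heads to a PreTrainedModel.

    One linear probe per sensor reads the backbone's hidden state at that
    sensor's token position, and a separate probe produces an aggregate
    prediction. Losses are BCE, weighted by sensors_weight / aggregate_weight.
    """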
    
    def __init__(self, config):
        super().__init__(config)
        self.sensor_loc_type = config.sensor_loc_type
        self.sensor_token = config.sensor_token
        self.n_sensors = config.n_sensors
        self.sensor_probes = torch.nn.ModuleList([
            torch.nn.Linear(config.emb_dim, 1) for _ in range(config.n_sensors)
        ])
        self.aggregate_probe = torch.nn.Linear(config.emb_dim, 1)
        self.sensors_weight = config.sensors_weight
        self.aggregate_weight = config.aggregate_weight

        self.find_sensor_locs: Optional[SensorLocFinder] = None
    
    @abstractmethod
    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase):
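        """Set the tokenizer's pad token as appropriate for the concrete backbone."""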
        pass
    
    def init_sensor_loc_finder(self, tokenizer: PreTrainedTokenizerBase):
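        """Build the sensor-location finder; call once before the first forward()."""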
        self.find_sensor_locs = SENSOR_LOC_REGISTRY[self.sensor_loc_type](
            tokenizer, sensor_token=self.sensor_token, n_sensors=self.n_sensors
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        base_model_output: BaseModelOutputWithPast = self.base_model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # get sensor embeddings (including aggregate)
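        # sensor_locs: (batch, n_sensors + 1) token positions, the n_sensors
        # sensor slots plus one aggregate slot, chosen by the registered
        # SensorLocFinder; gather then pulls out those hidden states.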
        sensor_locs = self.find_sensor_locs(input_ids)
        sensor_embs = base_model_output.last_hidden_state.gather(
            1, sensor_locs.unsqueeze(-1).expand(-1, -1, self.config.emb_dim)
        )
        assert sensor_embs.shape == (input_ids.shape[0], self.n_sensors + 1, self.config.emb_dim), sensor_embs.shape
        
        # get sensor and aggregate logits
        sensor_logits = torch.concat(
            [self.sensor_probes[i](sensor_embs[:, i, :]) for i in range(self.n_sensors)],
            dim=-1,
        )
        aggregate_logits = self.aggregate_probe(sensor_embs[:, -1, :])
        logits = torch.concat([sensor_logits, aggregate_logits], dim=-1)

        # compute loss
        loss = None
        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
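            # weighted combination: per-sensor BCE scaled by sensors_weight,
            # aggregate BCE scaled by aggregate_weight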
            sensor_loss = loss_fct(sensor_logits, labels[:, :self.n_sensors]) * self.sensors_weight
            aggregate_loss = loss_fct(aggregate_logits, labels[:, -1:]) * self.aggregate_weight
            loss = sensor_loss + aggregate_loss

        if not return_dict:
            output = (logits,) + base_model_output[1:]
            return ((loss,) + output) if loss is not None else output
        
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_model_output.past_key_values,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )
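

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of how this mixin is meant to be wired up, assuming a
# GPT-2 backbone and a config object that also carries the fields read in
# __init__ above (sensor_loc_type, sensor_token, n_sensors, emb_dim,
# sensors_weight, aggregate_weight). `MeasurementGPT2` and the variables in
# the call-order example are hypothetical, not part of this module's API.
#
# from transformers import GPT2Model
#
# class MeasurementGPT2(MeasurementPredictorMixin):
#     # `self.base_model` (used in forward) resolves via base_model_prefix,
#     # so point it at the backbone attribute defined in __init__.
#     base_model_prefix = "transformer"
#
#     def __init__(self, config):
#         super().__init__(config)
#         self.transformer = GPT2Model(config)
#
#     def set_pad_token(self, tokenizer):
#         tokenizer.pad_token = tokenizer.eos_token
#
# Expected call order: pad token first, then the sensor-location finder,
# then forward():
#
# model = MeasurementGPT2(config)
# model.set_pad_token(tokenizer)
# model.init_sensor_loc_finder(tokenizer)
# outputs = model(input_ids, attention_mask=attention_mask, labels=labels)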