File size: 2,014 Bytes
7e59788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7ef0da
 
7e59788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from transformers import PreTrainedTokenizerBase

from .sensor_loc_finder import SensorLocFinder


class StoriesSensorLocFinder(SensorLocFinder):
    """Find sensor positions in token sequences that have a "## Questions" section.

    A sensor position is a token equal to the first token of ``encode("?")``
    or of ``encode(")?")`` that sits at or after the start of the
    "## Questions" header.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Cache the token ids used to spot the header and the question marks.

        Args:
            tokenizer: Tokenizer whose ``encode`` output supplies the ids.
            **kwargs: Accepted but not used by this class.

        Raises:
            AssertionError: If "## Questions" does not encode to exactly two
                tokens; the header search compares adjacent token pairs.
        """
        header_ids = tokenizer.encode("## Questions")
        self.questions_section_toks = header_ids
        # Only the leading id of each encoding is kept.
        self.question_mark_tok = tokenizer.encode("?")[0]
        self.other_question_mark_tok = tokenizer.encode(")?")[0]
        assert len(header_ids) == 2

    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the positions of the three sensor tokens plus an aggregate slot.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).

        Returns:
            Integer tensor with four columns per row: the positions of the
            first, second and third sensor tokens, each capped at
            ``sequence_length - 3``, followed by a repeat of the third
            column, which serves as the aggregate sensor location.

        Raises:
            AssertionError: If any row does not hold exactly three sensor
                tokens, or if the header check in ``_is_sensor_loc`` fails.
        """
        # NOTE(review): the reason for capping at seq_len - 3 is not evident
        # from this class; confirm against the callers.
        last_allowed = input_ids.shape[-1] - 3

        sensor_mask = self._is_sensor_loc(input_ids)
        running_count = sensor_mask.cumsum(dim=-1)
        assert (
            running_count[:, -1] == 3
        ).all(), "can handle different cases, but assuming this is easiest"

        # hits[b, t, k] is True where the running count at position t is k + 1.
        wanted = torch.arange(1, 4, device=input_ids.device)[None, None]
        hits = running_count[:, :, None] == wanted

        # argmax along the sequence axis yields the first position at which
        # each count is reached, i.e. the position of the k-th sensor token.
        first_hit = torch.argmax(hits.to(torch.uint8), dim=-2)
        # Fallback for a count that is never reached; with the assertion
        # above active, every count from 1 to 3 is reached.
        reached = hits.any(dim=-2)
        locs = torch.where(reached, first_hit, last_allowed).clamp(max=last_allowed)

        # The aggregate sensor reuses the location of the last sensor.
        return torch.cat([locs, locs[:, -1:]], dim=1)

    def _is_sensor_loc(self, input_ids: torch.Tensor):
        """Mark question-mark tokens at or after the "## Questions" header.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).

        Returns:
            Boolean mask with the same shape as ``input_ids``.

        Raises:
            AssertionError: If a row does not contain the two-token header
                exactly once.
        """
        header = self.questions_section_toks
        qmark = self.question_mark_tok
        alt_qmark = self.other_question_mark_tok

        # Pair position t matches when tokens t and t + 1 are the header pair.
        starts_header = (input_ids[:, :-1] == header[0]) & (
            input_ids[:, 1:] == header[1]
        )
        assert (starts_header.sum(dim=-1, dtype=torch.int) == 1).all(), "could relax"

        # The pairwise comparison is one column short of input_ids; repeat its
        # final column so the cumulative mask lines up with the input.
        padded = torch.cat([starts_header, starts_header[:, -1:]], dim=-1)
        at_or_after_header = padded.cumsum(dim=-1) > 0

        is_question_mark = (input_ids == qmark) | (input_ids == alt_qmark)
        return at_or_after_header & is_question_mark