File size: 635 Bytes
5e08178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from transformers import PreTrainedTokenizerBase

from .sensor_loc_finder import SensorLocFinder


class SensorLocFinderFromToken(SensorLocFinder):
    """Locate sensors by searching token ids for a dedicated sensor token.

    The id of the sensor token is resolved once at construction time; each
    call to :meth:`find_sensor_locs` then reports the positions at which that
    id occurs in every row of a batch of token ids.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, sensor_token: str, n_sensors: int):
        """Resolve the sensor token id and remember the expected sensor count.

        Args:
            tokenizer: Tokenizer used to map ``sensor_token`` to a token id.
            sensor_token: Text of the token that marks a sensor position. If
                it encodes to several tokens, only the first one is used.
            n_sensors: Number of sensor tokens expected in each sequence.
        """
        # Encode without special tokens: with the default settings, tokenizers
        # that prepend BOS/CLS would make element [0] the special token's id
        # instead of the sensor token's id.
        self.sensor_token_id = tokenizer.encode(sensor_token, add_special_tokens=False)[0]
        self.n_sensors = n_sensors

    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the positions of the sensor token in each row of ``input_ids``.

        Args:
            input_ids: Token ids; indexed here as a 2-D ``(batch, seq_len)``
                tensor (positions are taken along dimension 1).

        Returns:
            Integer tensor of shape ``(batch, n_sensors)`` whose row ``i`` holds
            the positions, in increasing order, at which the sensor token
            occurs in ``input_ids[i]``.

        Raises:
            ValueError: If any row does not contain exactly ``n_sensors``
                occurrences of the sensor token.
        """
        is_sensor = input_ids == self.sensor_token_id

        # Reshaping the flattened positions is only meaningful when every row
        # contributes exactly n_sensors entries; otherwise positions from one
        # row would silently spill into another row of the result.
        counts_per_row = is_sensor.sum(dim=1)
        if not bool((counts_per_row == self.n_sensors).all()):
            raise ValueError(
                f"Expected exactly {self.n_sensors} sensor token(s) "
                f"(id={self.sensor_token_id}) per sequence, "
                f"but found per-row counts {counts_per_row.tolist()}."
            )

        # nonzero() enumerates matches in row-major order, so the column
        # indices are already grouped by row and sorted within each row.
        positions = is_sensor.nonzero(as_tuple=True)[1]
        # Use the explicit batch size rather than -1 so that an empty batch
        # reshapes to (0, n_sensors) instead of being ambiguous.
        return positions.view(input_ids.shape[0], self.n_sensors)