File size: 804 Bytes
51b0cb7 446b3a0 51b0cb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
from transformers import PreTrainedTokenizerBase
from .sensor_loc_finder import SensorLocFinder
class SensorLocFinderFromToken(SensorLocFinder):
    """Find sensor positions by locating a marker token in the input ids.

    Every occurrence of the sensor token in a row counts as one sensor
    position. One extra "aggregate" position is appended per row; it reuses
    the index of that row's last sensor token.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, sensor_token: str, n_sensors: int):
        """Resolve the id of ``sensor_token`` and store the expected sensor count.

        Args:
            tokenizer: Tokenizer used to map ``sensor_token`` to a token id.
            sensor_token: Text of the marker token. If it encodes to several
                tokens, only the first id is used.
            n_sensors: Number of sensor tokens expected in every input row.

        Raises:
            ValueError: If ``sensor_token`` encodes to no tokens at all.
        """
        # Encode without special tokens: tokenizers that prepend BOS/[CLS]
        # would otherwise put the special token, not the sensor token, at
        # index 0.
        token_ids = tokenizer.encode(sensor_token, add_special_tokens=False)
        if not token_ids:
            raise ValueError(f"sensor_token {sensor_token!r} encodes to no tokens")
        self.sensor_token_id = token_ids[0]
        self.n_sensors = n_sensors

    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the column index of each sensor token, plus an aggregate index.

        Args:
            input_ids: 2-D tensor of token ids, one sequence per row.

        Returns:
            Integer tensor of shape ``(n_rows, n_sensors + 1)``. The first
            ``n_sensors`` columns are the positions of the sensor tokens in
            each row, in ascending order; the last column repeats the final
            sensor position.

        Raises:
            ValueError: If any row does not contain exactly ``n_sensors``
                sensor tokens.
        """
        is_sensor = input_ids == self.sensor_token_id

        # Validate per row. A bare ``view(-1, n_sensors)`` would silently
        # regroup positions across row boundaries whenever the counts are
        # uneven but their total still happens to be divisible.
        per_row_counts = is_sensor.sum(dim=1)
        if not bool(torch.all(per_row_counts == self.n_sensors)):
            raise ValueError(
                f"expected exactly {self.n_sensors} sensor tokens "
                f"(id {self.sensor_token_id}) in every row, "
                f"got per-row counts {per_row_counts.tolist()}"
            )

        # ``nonzero`` walks the mask in row-major order, so the column indices
        # come out grouped by row and ascending within each row.
        flat_positions = is_sensor.nonzero(as_tuple=True)[1]
        # Use the explicit row count rather than -1 so an empty batch reshapes
        # cleanly instead of hitting an ambiguous-dimension error.
        sensor_token_idxs = flat_positions.view(is_sensor.shape[0], self.n_sensors)

        # The aggregate position shares the index of the last sensor token.
        aggregate_sensor_token_idx = sensor_token_idxs[:, -1:]
        return torch.cat([sensor_token_idxs, aggregate_sensor_token_idx], dim=1)
|