import numpy as np
import torch
from torch.utils.data import Dataset
from collections import defaultdict
class event_detection_data(Dataset):
    """Event-detection dataset with an optional whole-word-masking mode for domain adaptation."""

    def __init__(self, raw_data, tokenizer, max_len, domain_adaption=False, wwm_prob=0.1):
        self.len = len(raw_data)
        self.data = raw_data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.domain_adaption = domain_adaption
        self.wwm_prob = wwm_prob
    def __getitem__(self, index):
        # Tokenize the pre-split word list for this example.
        tokenized_inputs = self.tokenizer(
            self.data[index]["text"],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            is_split_into_words=True
        )
        ids = tokenized_inputs['input_ids']
        mask = tokenized_inputs['attention_mask']
        if self.domain_adaption:
            # Domain-adaptation mode: return a masked-language-modelling sample
            # built with whole-word masking.
            if not self.tokenizer.is_fast:
                raise ValueError("whole-word masking requires a fast tokenizer that provides word_ids")
            input_ids, labels = self._whole_word_masking(self.tokenizer, tokenized_inputs, self.wwm_prob)
            return {
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(mask),
                'labels': torch.tensor(labels, dtype=torch.long)
            }
        else:
            # Classification mode: the target is the first tag id of the example.
            return {
                'input_ids': torch.tensor(ids),
                'attention_mask': torch.tensor(mask),
                'targets': torch.tensor(self.data[index]["text_tag_id"][0], dtype=torch.long)
            }
    def __len__(self):
        return self.len
    def _whole_word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        word_ids = tokenized_inputs.word_ids(0)
        # Map each word index to the positions of its subword tokens.
        mapping = defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)
        # Randomly select whole words to mask with probability wwm_prob.
        mask = np.random.binomial(1, wwm_prob, (len(mapping),))
        input_ids = tokenized_inputs["input_ids"]
        # Labels are -100 (ignored by the loss) everywhere except masked positions,
        # which keep the original token id as the prediction target.
        labels = [-100] * len(input_ids)
        for word_id in np.where(mask == 1)[0]:
            for idx in mapping[word_id]:
                labels[idx] = tokenized_inputs["input_ids"][idx]
                input_ids[idx] = tokenizer.mask_token_id
        return input_ids, labels
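

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how this dataset might be driven, assuming raw_data is a list of
# dicts with a pre-split "text" field (list of words) and a "text_tag_id" list, and that a
# fast Hugging Face tokenizer (e.g. "bert-base-uncased") is available; adjust to the real data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    raw_data = [{"text": ["an", "earthquake", "struck", "the", "city"], "text_tag_id": [3]}]

    # Classification mode: batches contain 'input_ids', 'attention_mask', 'targets'.
    clf_loader = DataLoader(event_detection_data(raw_data, tokenizer, max_len=32), batch_size=2)
    print(next(iter(clf_loader))["targets"])

    # Domain-adaptation mode: batches contain MLM 'labels' built by whole-word masking.
    mlm_loader = DataLoader(
        event_detection_data(raw_data, tokenizer, max_len=32, domain_adaption=True, wwm_prob=0.15),
        batch_size=2,
    )
    print(next(iter(mlm_loader))["labels"].shape)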