"""Utilities for loading the acronym knowledge base and training data,
together with the Dataset and DataLoader helpers used for training AcroBERT."""

import json
import logging
import os
import random
from random import random as rand

import torch
from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)
local_file = os.path.split(__file__)[-1]
logging.basicConfig(
    format='%(asctime)s : %(filename)s : %(funcName)s : %(levelname)s : %(message)s',
    level=logging.INFO)


def load_acronym_kb(kb_path='acronym_kb.json'):
    # The KB maps each acronym to a list of (long_form, score) pairs;
    # only the long forms are kept, in their original order.
    with open(kb_path, encoding='utf8') as f:
        acronym_kb = json.load(f)
    for key, values in acronym_kb.items():
        acronym_kb[key] = [v for v, s in values]
    logger.info('loaded acronym dictionary successfully, in total there are [{a}] acronyms'.format(a=len(acronym_kb)))
    return acronym_kb
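
# Illustrative KB layout (assumed, not shipped with this file): the JSON maps each
# acronym to scored long forms, e.g.
#   {"CNN": [["convolutional neural network", 40.0], ["cable news network", 12.0]]}
# after which load_acronym_kb() returns
#   {"CNN": ["convolutional neural network", "cable news network"]}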


def get_candidate(acronym_kb, short_term, can_num=10):
    # Return at most `can_num` candidate long forms for the given acronym.
    return acronym_kb[short_term][:can_num]


def load_data(path):
    # Read a JSON-lines file: one JSON object per line.
    data = list()
    with open(path, encoding='utf8') as f:
        for line in f:
            data.append(json.loads(line))
    return data


def load_dataset(data_path):
    # Read a JSON-lines file where every object carries 'short_term', 'long_term'
    # and 'tokens'; the tokens are joined back into a plain-text context string.
    all_short_term, all_long_term, all_context = list(), list(), list()
    with open(data_path, encoding='utf8') as f:
        for line in f:
            obj = json.loads(line)
            short_term, long_term, context = obj['short_term'], obj['long_term'], ' '.join(obj['tokens'])
            all_short_term.append(short_term)
            all_long_term.append(long_term)
            all_context.append(context)
    return {'short_term': all_short_term, 'long_term': all_long_term, 'context': all_context}
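
# Illustrative input line for load_dataset() / load_pretrain() (field values are made up):
#   {"short_term": "CNN", "long_term": "convolutional neural network",
#    "tokens": ["the", "CNN", "model", "was", "trained", "on", "ImageNet"]}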


def load_pretrain(data_path):
    # Same format as load_dataset(), but only the first 200 lines are read.
    all_short_term, all_long_term, all_context = list(), list(), list()
    cnt = 0
    with open(data_path, encoding='utf8') as f:
        for line in f:
            cnt += 1
            if cnt > 200:
                break
            obj = json.loads(line)
            short_term, long_term, context = obj['short_term'], obj['long_term'], ' '.join(obj['tokens'])
            all_short_term.append(short_term)
            all_long_term.append(long_term)
            all_context.append(context)
    return {'short_term': all_short_term, 'long_term': all_long_term, 'context': all_context}


class TextData(Dataset):
    # Thin Dataset wrapper over the dict produced by load_dataset()/load_pretrain().
    def __init__(self, data):
        self.all_short_term = data['short_term']
        self.all_long_term = data['long_term']
        self.all_context = data['context']

    def __len__(self):
        return len(self.all_short_term)

    def __getitem__(self, idx):
        return self.all_short_term[idx], self.all_long_term[idx], self.all_context[idx]


def random_negative(target, elements):
    # Sample an element different from `target`; `elements` must contain at least
    # one non-target entry, otherwise this loop would never terminate.
    flag, result = True, ''
    while flag:
        temp = random.choice(elements)
        if temp != target:
            result = temp
            flag = False
    return result


class SimpleLoader():
    def __init__(self, batch_size, tokenizer, kb, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.tokenizer = tokenizer
        self.kb = kb

    def collate_fn(self, batch_data):
        # Build one positive and one negative (context, long_form) pair per example.
        # Label 0 marks the correct expansion, label 1 marks the sampled negative.
        pos_tag, neg_tag = 0, 1
        batch_short_term, batch_long_term, batch_context = list(zip(*batch_data))
        batch_short_term, batch_long_term, batch_context = list(batch_short_term), list(batch_long_term), list(batch_context)
        batch_negative, batch_label, batch_label_neg = list(), list(), list()
        for index in range(len(batch_short_term)):
            short_term, long_term, context = batch_short_term[index], batch_long_term[index], batch_context[index]
            batch_label.append(pos_tag)
            # Each KB entry is expected to be a (long_form, score) pair here; keep the long form.
            candidates = [v[0] for v in self.kb[short_term]]
            if len(candidates) == 1:
                # Only one known expansion: reuse the positive pair instead of a negative.
                batch_negative.append(long_term)
                batch_label_neg.append(pos_tag)
                continue

            negative = random_negative(long_term, candidates)
            batch_negative.append(negative)
            batch_label_neg.append(neg_tag)

        # First half of the batch holds the positives, second half the negatives.
        prompt = batch_context + batch_context
        long_terms = batch_long_term + batch_negative
        label = batch_label + batch_label_neg

        encoding = self.tokenizer(prompt, long_terms, return_tensors="pt", padding=True, truncation=True)
        label = torch.LongTensor(label)

        return encoding, label

    def __call__(self, data_path):
        dataset = load_dataset(data_path=data_path)
        dataset = TextData(dataset)
        # Halve the batch size because collate_fn doubles each batch (positives + negatives).
        train_iterator = DataLoader(dataset=dataset, batch_size=self.batch_size // 2, shuffle=self.shuffle,
                                    collate_fn=self.collate_fn)
        return train_iterator
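
# Illustrative usage of SimpleLoader (a sketch, not part of this module): the tokenizer is
# assumed to be a Hugging Face tokenizer that accepts text pairs, and 'train.jsonl' is a
# placeholder path. Note that collate_fn reads v[0] from each KB entry, so this sketch passes
# the raw scored KB (acronym -> [[long_form, score], ...]) rather than the flattened dict
# returned by load_acronym_kb().
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   kb = json.load(open('acronym_kb.json', encoding='utf8'))
#   loader = SimpleLoader(batch_size=32, tokenizer=tokenizer, kb=kb)
#   for encoding, label in loader('train.jsonl'):
#       pass  # encoding: tokenized (context, long_form) pairs; label: 0 = positive, 1 = negative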


def mask_subword(subword_sequences, prob=0.15, masked_prob=0.8, VOCAB_SIZE=30522):
    # BERT-style masked-language-model corruption of already-tokenized id sequences
    # (the special ids assume the bert-base-uncased vocabulary: PAD=0, CLS=101, SEP=102, MASK=103).
    # Each non-special, non-padding position is selected with probability `prob`; a selected
    # position becomes [MASK] with probability `masked_prob`, otherwise it is replaced by a
    # random token or left unchanged. Labels are -100 (ignored by the loss) everywhere except
    # the selected positions, which keep the original token id.
    PAD, CLS, SEP, MASK, BLANK = 0, 101, 102, 103, -100
    masked_labels = list()
    for sentence in subword_sequences:
        labels = [BLANK for _ in range(len(sentence))]
        original = sentence[:]
        end = len(sentence)
        if PAD in sentence:
            end = sentence.index(PAD)
        for pos in range(end):
            if sentence[pos] in (CLS, SEP):
                continue
            if rand() > prob:
                continue
            if rand() < masked_prob:
                sentence[pos] = MASK
            elif rand() < 0.5:
                sentence[pos] = random.randint(0, VOCAB_SIZE - 1)
            labels[pos] = original[pos]
        masked_labels.append(labels)
    return subword_sequences, masked_labels
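
# Example (illustrative; actual output is random): for an input batch such as
#   [[101, 7592, 2088, 102, 0, 0]]
# mask_subword() returns the same batch with roughly 15% of the non-special, non-padding ids
# replaced (mostly by 103, the [MASK] id), plus a parallel label batch that is -100 everywhere
# except at the corrupted positions, which keep the original ids.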


class AcroBERTLoader():
    def __init__(self, batch_size, tokenizer, kb, shuffle=True, masked_prob=0.15, hard_num=2):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.tokenizer = tokenizer
        self.masked_prob = masked_prob
        self.hard_num = hard_num
        self.kb = kb
        # Flatten every known long form into one pool for random negative sampling.
        self.all_long_terms = list()
        for vs in self.kb.values():
            self.all_long_terms.extend(list(vs))

    def select_negative(self, target):
        # Sample a candidate long form, preferring expansions of the same acronym (hard
        # negatives); fall back to the global pool when the acronym is unknown or has only
        # one expansion, and give up after `max_time` draws.
        selected, flag, max_time = None, True, 10
        if target in self.kb:
            long_term_candidates = self.kb[target]
            if len(long_term_candidates) == 1:
                long_term_candidates = self.all_long_terms
        else:
            long_term_candidates = self.all_long_terms
        attempt = 0
        while flag and attempt < max_time:
            attempt += 1
            selected = random.choice(long_term_candidates)
            if selected != target:
                flag = False
        if attempt == max_time:
            selected = random.choice(self.all_long_terms)
        return selected

    def collate_fn(self, batch_data):
        # For every example build `hard_num` positive, masked-positive and negative strings of
        # the form '<long_form> [SEP] <context>'. The masked copies start out identical to the
        # positives; the actual subword masking (see mask_subword) is presumably applied
        # downstream, after tokenization.
        batch_short_term, batch_long_term, batch_context = list(zip(*batch_data))
        pos_samples, neg_samples, masked_pos_samples = list(), list(), list()
        for _ in range(self.hard_num):
            temp_pos_samples = [batch_long_term[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))]
            neg_long_terms = [self.select_negative(st) for st in batch_short_term]
            temp_neg_samples = [neg_long_terms[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))]
            temp_masked_pos_samples = [batch_long_term[index] + ' [SEP] ' + batch_context[index] for index in range(len(batch_long_term))]

            pos_samples.extend(temp_pos_samples)
            neg_samples.extend(temp_neg_samples)
            masked_pos_samples.extend(temp_masked_pos_samples)
        return pos_samples, masked_pos_samples, neg_samples

    def __call__(self, data_path):
        dataset = load_pretrain(data_path=data_path)
        logger.info('loaded dataset, samples = {a}'.format(a=len(dataset['short_term'])))
        dataset = TextData(dataset)
        # collate_fn expands every example into hard_num positives and hard_num negatives
        # (plus masked copies), so the DataLoader batch size is reduced accordingly.
        train_iterator = DataLoader(dataset=dataset, batch_size=self.batch_size // (2 * self.hard_num), shuffle=self.shuffle,
                                    collate_fn=self.collate_fn)
        return train_iterator
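
# Illustrative end-to-end usage of AcroBERTLoader (a sketch: the tokenizer, model name and
# file paths below are placeholders, not part of this repository):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   kb = load_acronym_kb('acronym_kb.json')
#   loader = AcroBERTLoader(batch_size=32, tokenizer=tokenizer, kb=kb)
#   for pos, masked_pos, neg in loader('pretrain.jsonl'):
#       pass  # each element is a list of '<long_form> [SEP] <context>' strings


if __name__ == '__main__':
    # Minimal smoke test using only toy data defined here (no external files needed).
    toy_ids = [[101, 7592, 2088, 2003, 3376, 102, 0, 0]]
    masked, labels = mask_subword([seq[:] for seq in toy_ids])
    logger.info('masked ids:    %s', masked)
    logger.info('masked labels: %s', labels)
    logger.info('random negative for "a": %s', random_negative('a', ['a', 'b', 'c']))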