|
|
|
|
|
|
|
import pickle |
|
|
|
import torch |
|
|
|
from data.parser.from_mrp.node_centric_parser import NodeCentricParser |
|
from data.parser.from_mrp.labeled_edge_parser import LabeledEdgeParser |
|
from data.parser.from_mrp.sequential_parser import SequentialParser |
|
from data.parser.from_mrp.evaluation_parser import EvaluationParser |
|
from data.parser.from_mrp.request_parser import RequestParser |
|
from data.field.edge_field import EdgeField |
|
from data.field.edge_label_field import EdgeLabelField |
|
from data.field.field import Field |
|
from data.field.mini_torchtext.field import Field as TorchTextField |
|
from data.field.label_field import LabelField |
|
from data.field.anchored_label_field import AnchoredLabelField |
|
from data.field.nested_field import NestedField |
|
from data.field.basic_field import BasicField |
|
from data.field.bert_field import BertField |
|
from data.field.anchor_field import AnchorField |
|
from data.batch import Batch |
|
|
|
|
|
def char_tokenize(word):
    """Split ``word`` into a list of its individual characters.

    Used as the ``tokenize`` callable of the character-level nesting field.
    """
    return list(word)
|
|
|
|
|
class Collate:
    """``collate_fn`` that assembles a list of examples into a ``Batch``.

    The incoming list is sorted in place, longest first, using ``size(0)`` of
    the first item stored under ``example["every_input"]``. The sorted list
    is then passed to ``Batch.build``.
    """

    @staticmethod
    def _input_length(example):
        """Sort key: ``size(0)`` of the first element under ``"every_input"``."""
        tokens = example["every_input"][0]
        return tokens.size(0)

    def __call__(self, batch):
        batch.sort(reverse=True, key=self._input_length)
        return Batch.build(batch)
|
|
|
|
|
class Dataset:
    """Parse MRP splits, build vocabularies, and expose train/val/test ``DataLoader``s.

    Construction creates every field, then calls :meth:`load_dataset`. That
    method parses the training, validation and test splits, gathers
    node/edge/anchor statistics from the training split, builds the
    vocabularies and wraps each split in a ``DataLoader``.

    Args:
        args: Configuration namespace. This class reads ``graph_mode``,
            ``batch_size``, ``workers`` and ``query_length`` from it. The
            parser classes receive it too and may read further attributes.
        verbose: When ``False``, :meth:`log` output and the sample-batch
            printout are suppressed.
    """

    # Training examples whose ``input`` attribute (filled from the "bert input"
    # key) is longer than this are dropped by ``load_dataset``.
    MAX_TRAIN_INPUT_LENGTH = 256

    def __init__(self, args, verbose=True):
        self.verbose = verbose

        self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"

        self.bert_input_field = BertField()
        self.scatter_field = BasicField()
        self.every_word_input_field = Field(lower=True, init_token=self.sos, eos_token=self.eos, batch_first=True, include_lengths=True)

        char_form_nesting = TorchTextField(tokenize=char_tokenize, init_token=self.sos, eos_token=self.eos, batch_first=True)
        self.char_form_field = NestedField(char_form_nesting, include_lengths=True)

        self.label_field = LabelField(preprocessing=lambda nodes: [n["label"] for n in nodes])
        self.anchored_label_field = AnchoredLabelField()

        self.id_field = Field(batch_first=True, tokenize=lambda x: [x])
        self.edge_presence_field = EdgeField()
        self.edge_label_field = EdgeLabelField()
        self.anchor_field = AnchorField()
        self.source_anchor_field = AnchorField()
        self.target_anchor_field = AnchorField()
        self.token_interval_field = BasicField()

        self.load_dataset(args)

    def log(self, text):
        """Print ``text`` (flushed) unless the dataset was created with ``verbose=False``."""
        if self.verbose:
            print(text, flush=True)

    def load_state_dict(self, args, d):
        """Restore the vocabularies saved by :meth:`state_dict`.

        Args:
            args: Unused; kept for interface compatibility.
            d: Mapping whose ``"vocabs"`` entry maps an attribute name of this
                object to a pickled vocab.

        Warning:
            Vocabularies are restored with ``pickle.loads``, which can execute
            arbitrary code. Only load checkpoints from trusted sources.
        """
        for key, value in d["vocabs"].items():
            getattr(self, key).vocab = pickle.loads(value)

        # load_dataset makes the nested character field share its parent's
        # vocab object. Re-establish that link (and the derived size) so the
        # nested field does not keep pointing at the vocab built from the
        # currently parsed data.
        if "char_form_field" in d["vocabs"]:
            self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
            self.char_form_vocab_size = len(self.char_form_field.vocab)

    def state_dict(self):
        """Return ``{"vocabs": {attribute name: pickled vocab}}``.

        One entry is produced for every instance attribute that has a ``vocab``.
        """
        vocabs = {key: pickle.dumps(value.vocab) for key, value in self.__dict__.items() if hasattr(value, "vocab")}
        return {"vocabs": vocabs}

    def _fields(self, include_targets=True):
        """Build a parser ``fields`` mapping: raw key -> (example attribute, field).

        The key order is fixed: inputs first, then (when ``include_targets``)
        the graph targets, then token anchors and id. A fresh dict is returned
        on every call.
        """
        fields = {
            "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
            "bert input": ("input", self.bert_input_field),
            "to scatter": ("input_scatter", self.scatter_field),
        }
        if include_targets:
            fields.update({
                "nodes": ("labels", self.label_field),
                "anchored labels": ("anchored_labels", self.anchored_label_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "anchor edges": ("anchor", self.anchor_field),
                "source anchor edges": ("source_anchor", self.source_anchor_field),
                "target anchor edges": ("target_anchor", self.target_anchor_field),
            })
        fields["token anchors"] = ("token_intervals", self.token_interval_field)
        fields["id"] = ("id", self.id_field)
        return fields

    @staticmethod
    def _make_loader(dataset, args, shuffle, drop_last=False):
        """Wrap ``dataset`` in a pinned-memory ``DataLoader`` that batches with ``Collate``."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
            drop_last=drop_last,
        )

    def load_sentences(self, sentences, args):
        """Wrap raw ``sentences`` in an unshuffled, batch-size-1 ``DataLoader``.

        The word-input and id vocabularies are rebuilt from these sentences.
        """
        dataset = RequestParser(sentences, args, fields=self._fields(include_targets=False))

        self.every_word_input_field.build_vocab(dataset, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.id_field.build_vocab(dataset, min_freq=1, specials=[])

        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=Collate())

    def load_dataset(self, args):
        """Parse all splits, collect statistics, build vocabs and create the loaders.

        Raises:
            KeyError: if ``args.graph_mode`` is not a known graph mode.
        """
        parser = {
            "sequential": SequentialParser,
            "node-centric": NodeCentricParser,
            "labeled-edge": LabeledEdgeParser,
        }[args.graph_mode]

        # Bind the limit to a local so the filter closure does not capture self.
        max_length = self.MAX_TRAIN_INPUT_LENGTH
        train = parser(args, "training", fields=self._fields(), filter_pred=lambda example: len(example.input) <= max_length)
        val = parser(args, "validation", fields=self._fields())
        test = EvaluationParser(args, fields=self._fields(include_targets=False))

        # Free the parsers' raw `data` attribute; nothing below in this method reads it.
        del train.data, val.data, test.data
        # NOTE(review): preprocessing callables are removed here. Presumably this
        # lets the fields (label_field's preprocessing is a lambda) be pickled
        # for DataLoader worker processes. Confirm.
        for f in [*train.fields.values(), *val.fields.values(), *test.fields.values()]:
            if hasattr(f, "preprocessing"):
                del f.preprocessing

        self.train_size = len(train)
        self.val_size = len(val)
        self.test_size = len(test)

        self.log(f"\n{self.train_size} sentences in the train split")
        self.log(f"{self.val_size} sentences in the validation split")
        self.log(f"{self.test_size} sentences in the test split")

        self.node_count = train.node_counter
        self.token_count = train.input_count
        self.edge_count = train.edge_counter
        self.no_edge_count = train.no_edge_counter
        self.anchor_freq = train.anchor_freq

        # Parsers that do not expose source/target anchor statistics fall back to 0.5.
        self.source_anchor_freq = getattr(train, "source_anchor_freq", 0.5)
        self.target_anchor_freq = getattr(train, "target_anchor_freq", 0.5)
        self.log(f"{self.node_count} nodes in the train split")

        # NOTE(review): the word-input vocab is built from the validation and
        # test splits only, not from train. This looks deliberate; confirm.
        self.every_word_input_field.build_vocab(val, test, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.char_form_field.build_vocab(train, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        # The nested character field shares the parent's vocab object.
        self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
        self.id_field.build_vocab(train, val, test, min_freq=1, specials=[])
        self.label_field.build_vocab(train)
        self.anchored_label_field.vocab = self.label_field.vocab
        self.edge_label_field.build_vocab(train)
        self.log(list(self.edge_label_field.vocab.freqs.keys()))

        self.char_form_vocab_size = len(self.char_form_field.vocab)
        self.create_label_freqs(args)
        self.create_edge_freqs(args)

        self.log(f"Edge frequency: {self.edge_presence_freq*100:.2f} %")
        self.log(f"{len(self.label_field.vocab)} words in the label vocabulary")
        self.log(f"{len(self.anchored_label_field.vocab)} words in the anchored label vocabulary")
        self.log(f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary")
        self.log(f"{len(self.char_form_field.vocab)} characters in the vocabulary")

        self.log(self.label_field.vocab.freqs)
        self.log(self.anchored_label_field.vocab.freqs)

        self.train = self._make_loader(train, args, shuffle=True, drop_last=True)
        self.val = self._make_loader(val, args, shuffle=False)
        self.test = self._make_loader(test, args, shuffle=False)

        if self.verbose:
            batch = next(iter(self.train))
            print(f"\nBatch content: {Batch.to_str(batch)}\n")
            print(flush=True)

    def create_label_freqs(self, args):
        """Set ``self.label_freqs``: relative frequency of a "blank" slot and of each label.

        The blank count is ``args.query_length * self.token_count - self.node_count``.
        Index 0 of the resulting float tensor holds the blank frequency, and
        index ``i + 1`` holds the frequency of ``label_field.vocab.itos[i]``.
        Every count is divided by ``self.node_count + blank_count``.
        """
        vocab = self.label_field.vocab
        blank_count = args.query_length * self.token_count - self.node_count
        label_counts = [blank_count] + [vocab.freqs[vocab.itos[i]] for i in range(len(vocab))]
        self.label_freqs = torch.FloatTensor(label_counts) / (self.node_count + blank_count)
        self.log(f"Label frequency: {self.label_freqs}")

    def create_edge_freqs(self, args):
        """Set ``self.edge_label_freqs`` and ``self.edge_presence_freq``.

        ``edge_label_freqs[i]`` is the count of ``edge_label_field.vocab.itos[i]``
        divided by ``self.edge_count``. ``edge_presence_freq`` is
        ``edge_count / (edge_count + no_edge_count)``.

        Args:
            args: Unused; kept for interface symmetry with ``create_label_freqs``.
        """
        vocab = self.edge_label_field.vocab
        edge_counter = torch.FloatTensor([vocab.freqs[vocab.itos[i]] for i in range(len(vocab))])
        self.edge_label_freqs = edge_counter / self.edge_count
        self.edge_presence_freq = self.edge_count / (self.edge_count + self.no_edge_count)
|
|