import torch

from . import BaseWrapperDataset, data_utils
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel


class AddTargetDataset(BaseWrapperDataset):
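    """Dataset wrapper that attaches a label (e.g. a transcription) to each
    item of the wrapped dataset, for tasks such as speech recognition.

    Labels may be stored compressed (see ``text_compression_level``) and are
    decompressed on access. When ``batch_targets`` is set, :meth:`collater`
    pads the targets into a single tensor and records per-sample lengths;
    when ``add_to_input`` is set, it additionally builds EOS-delimited
    ``prev_output_tokens`` for teacher forcing.

    Example (illustrative sketch; ``speech_dataset``, ``transcripts`` and
    ``dictionary`` are hypothetical stand-ins for a wrapped audio dataset,
    its labels, and a fairseq ``Dictionary``)::

        dataset = AddTargetDataset(
            speech_dataset,
            transcripts,
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            batch_targets=True,
            process_label=lambda lbl: dictionary.encode_line(
                lbl, append_eos=False
            ),
        )
    """
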
    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        label_len_fn=None,
        add_to_input=False,
        text_compression_level=TextCompressionLevel.none,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        # size() calls label_len_fn unconditionally, so fall back to plain
        # string length when no custom length function is provided.
        self.label_len_fn = label_len_fn if label_len_fn is not None else len
        self.add_to_input = add_to_input
        self.text_compressor = TextCompressor(level=text_compression_level)

    def get_label(self, index, process_fn=None):
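        """Return the decompressed label at ``index``, passed through
        ``process_fn`` (e.g. tokenization into a tensor) when given."""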
        lbl = self.labels[index]
        lbl = self.text_compressor.decompress(lbl)
        return lbl if process_fn is None else process_fn(lbl)

    def __getitem__(self, index):
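        """Return the wrapped item with its processed label attached under
        the ``"label"`` key."""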
        item = self.dataset[index]
        item["label"] = self.get_label(index, process_fn=self.process_label)
        return item

    def size(self, index):
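        """Return a ``(source size, label length)`` pair, used for
        length-based filtering and batching."""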
        sz = self.dataset.size(index)
        own_sz = self.label_len_fn(self.get_label(index))
        return sz, own_sz

    def collater(self, samples):
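        """Merge ``samples`` into a mini-batch.

        Extends the batch produced by the wrapped dataset's collater with
        ``target`` and ``ntokens``; when ``batch_targets`` is set, targets
        are padded into one tensor and ``target_lengths`` is recorded, and
        any ``prev_output_tokens`` built for ``add_to_input`` are padded as
        well."""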
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        # The wrapped collater may have dropped samples (e.g. invalid audio),
        # so only keep labels for the ids that survived.
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.add_to_input:
            # Build teacher-forcing inputs: prev_output_tokens is the target
            # shifted right with EOS prepended, while the target itself gets
            # EOS appended.
            eos = torch.LongTensor([self.eos])
            prev_output_tokens = [torch.cat([eos, t], dim=-1) for t in target]
            target = [torch.cat([t, eos], dim=-1) for t in target]
            collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            # net_input is a dict, so look the key up with .get() rather than
            # getattr(), which would always fall through to the default here.
            if collated["net_input"].get("prev_output_tokens") is not None:
                collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
                    collated["net_input"]["prev_output_tokens"],
                    pad_idx=self.pad,
                    left_pad=False,
                )
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["target"] = target
        return collated

    def filter_indices_by_size(self, indices, max_sizes):
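        """Filter out indices whose source or label size exceeds
        ``max_sizes``; returns the kept indices and the ignored ones."""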
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored