|
|
|
|
|
|
|
|
|
|
|
from . import BaseWrapperDataset |
|
|
|
|
|
class ReplaceDataset(BaseWrapperDataset):
    """Replace tokens in the items of a wrapped dataset.

    Each item fetched from the wrapped dataset (or each element of it, when
    the item is a tuple) has every token ``k`` that appears as a key of
    ``replace_map`` overwritten with ``replace_map[k]``. All replacements
    are applied simultaneously: masks are computed on the original values,
    so a chained map such as ``{a: b, b: c}`` maps ``a -> b`` and
    ``b -> c`` rather than ``a -> c``.

    Replacement happens in place on the tensors returned by the wrapped
    dataset (via ``masked_fill_``). Whether that also alters data the
    wrapped dataset keeps around depends on that dataset.
    NOTE(review): confirm callers are fine with the in-place write.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in.
        replace_map (Dict[int, int]): maps a token to replace to its
            replacement token. Must be non-empty.
        offsets (List[int]): one offset per object returned by the wrapped
            dataset's ``__getitem__``. For an offset ``o >= 0`` only
            positions ``o:`` are eligible for replacement (the first ``o``
            tokens are left untouched). For ``o < 0`` only positions
            ``:o`` are eligible (the last ``-o`` tokens are left untouched).
            Objects without a matching offset (``zip`` truncates) are
            returned unchanged.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0, "replace_map must not be empty"
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        item = self.dataset[index]
        # Treat a non-tuple item as a single-element sequence so both cases
        # share the same loop below.
        srcs = item if isinstance(item, tuple) else (item,)

        for offset, src in zip(self.offsets, srcs):
            # Restrict replacement to the eligible region. Basic slicing of a
            # tensor yields a view, so masked_fill_ below writes through to
            # `src` (and hence to `item`). Computed once per src, not once
            # per map entry.
            src_off = src[offset:] if offset >= 0 else src[:offset]

            # Build every mask from the ORIGINAL values before filling, so an
            # earlier replacement cannot be re-matched by a later key.
            masks = [(src_off == k, v) for k, v in self.replace_map.items()]
            for mask, v in masks:
                src_off.masked_fill_(mask, v)

        # `item` was modified in place through the views above.
        return item
|
|