File size: 12,613 Bytes
1ce5e18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 |
import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset
from ..utils.generic import ModelOutput
class PipelineDataset(Dataset):
    """
    Map-style dataset that applies `process` lazily to the items of another dataset.

    Indexing returns `process(dataset[i], **params)`; nothing is cached, so `process` runs again on every access.
    """

    def __init__(self, dataset, process, params):
        # Wrapped indexable collection, the transform applied on access, and its keyword arguments.
        self.dataset, self.process, self.params = dataset, process, params

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, i):
        return self.process(self.dataset[i], **self.params)
class PipelineIterator(IterableDataset):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
            loader_batch_size (`int`, *optional*):
                If specified, the items of `loader` are supposed to come as batches. `infer` is applied to each
                whole batch and its output is then unrolled here item by item, making it roughly behave as

                ```
                for items in loader:
                    processed = infer(items, **params)
                    for i in range(loader_batch_size):
                        yield processed[i]
                ```

                A batch that turns out to be smaller (typically the last one) is only unrolled up to its own
                size; empty batches are skipped.
        """
        self.loader = loader
        self.infer = infer
        self.params = params
        if loader_batch_size == 1:
            # Let's spare some time by deactivating altogether
            loader_batch_size = None
        self.loader_batch_size = loader_batch_size

        # Internal bookkeeping
        self._loader_batch_index = None
        self._loader_batch_data = None
        # Number of items to unroll from the batch currently held in `_loader_batch_data`. Kept apart from
        # `loader_batch_size` so that one short batch does not shrink every later batch (or a later pass).
        self._current_loader_batch_size = None

    def __len__(self):
        # Length of `loader`; when `loader_batch_size` is set this counts batches, not unrolled items.
        return len(self.loader)

    def __iter__(self):
        self.iterator = iter(self.loader)
        # Forget any batch left half-unrolled by a previous (possibly interrupted) pass; otherwise its
        # remaining items would be yielded before the new pass starts.
        self._loader_batch_index = None
        self._loader_batch_data = None
        self._current_loader_batch_size = None
        return self

    def loader_batch_item(self):
        """
        Return item located at `loader_batch_index` within the current `loader_batch_data`.

        If the batch is a plain tensor, the row at the current index is returned as is. Otherwise the batch is
        treated as a dict-like object: each value is sliced at the current index (tensors/arrays keep a
        leading dimension of size 1) and the result is rebuilt with the batch's own class. The index is
        advanced by one.
        """
        index = self._loader_batch_index
        if isinstance(self._loader_batch_data, torch.Tensor):
            # Batch data is simple tensor, just fetch the slice
            result = self._loader_batch_data[index]
        else:
            # Batch data is assumed to be BaseModelOutput (or dict)
            loader_batched = {}
            for k, element in self._loader_batch_data.items():
                if isinstance(element, ModelOutput):
                    # Convert ModelOutput to tuple first, then unbatch every entry of that tuple.
                    element = element.to_tuple()
                    unbatch_as_tuple = True
                else:
                    # Those are stored as tuples of tensors so need specific unbatching.
                    unbatch_as_tuple = k in {"hidden_states", "past_key_values", "attentions"} and isinstance(
                        element, tuple
                    )

                if unbatch_as_tuple:
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[index], 0) for el in element)
                    # NOTE(review): entries of any other type (e.g. nested tuples) leave `k` out of the result.
                    continue

                if element is None:
                    # This can happen for optional data that get passed around
                    loader_batched[k] = None
                    continue

                value = element[index]
                if isinstance(value, torch.Tensor):
                    # Take correct batch data, but make it look like batch_size=1
                    # For compatibility with other methods within transformers
                    loader_batched[k] = value.unsqueeze(0)
                elif isinstance(value, np.ndarray):
                    # Same as above for numpy arrays
                    loader_batched[k] = np.expand_dims(value, 0)
                else:
                    # This is typically a list, so no need to `unsqueeze`.
                    loader_batched[k] = value
            # Recreate the element by reusing the original class to make it look
            # batch_size=1
            result = self._loader_batch_data.__class__(loader_batched)
        self._loader_batch_index += 1
        return result

    def _observed_batch_size(self, processed):
        """
        Infer the batch size of an `infer` output: its first dimension if it is a tensor, otherwise the length
        (list) or first dimension of the value stored under its first key.
        """
        if isinstance(processed, torch.Tensor):
            first_tensor = processed
        else:
            # Deliberately `list(...)[0]`: an empty dict raises IndexError rather than a StopIteration that
            # would silently end the iteration.
            key = list(processed.keys())[0]
            first_tensor = processed[key]
        if isinstance(first_tensor, list):
            return len(first_tensor)
        return first_tensor.shape[0]

    def __next__(self):
        if self._loader_batch_index is not None and self._loader_batch_index < self._current_loader_batch_size:
            # We are currently unrolling a batch so we just need to return
            # the current item within a batch
            return self.loader_batch_item()

        # We're out of items within a batch (or not unrolling at all): pull the next element of `loader`.
        # StopIteration from the underlying iterator propagates and ends iteration.
        while True:
            item = next(self.iterator)
            processed = self.infer(item, **self.params)
            if self.loader_batch_size is None:
                # We're not unrolling batches
                return processed

            # We now have a batch of "inferred things".
            observed_batch_size = self._observed_batch_size(processed)
            if observed_batch_size == 0:
                # Nothing to unroll in an empty batch; move on to the next one.
                continue
            # The batch may be shorter than requested (e.g. the last one); never unroll past either bound.
            self._current_loader_batch_size = min(observed_batch_size, self.loader_batch_size)
            # Setting internal index to unwrap the batch
            self._loader_batch_data = processed
            self._loader_batch_index = 0
            return self.loader_batch_item()
class PipelineChunkIterator(PipelineIterator):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            for processed in infer(item, **params):
                yield processed
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`. It must return an iterator (e.g. a
                generator), since the items it produces are pulled with `next`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
            loader_batch_size (`int`, *optional*):
                Accepted for signature compatibility with `PipelineIterator` but not used: it is not forwarded
                to the parent class and this iterator never unrolls batches.
        """
        super().__init__(loader, infer, params)

    def __iter__(self):
        self.iterator = iter(self.loader)
        # No `infer` sub-iterator is active until the first call to `__next__`.
        self.subiterator = None
        return self

    def __next__(self):
        if self.subiterator is None:
            # Subiterator None means we haven't started a `preprocess` iterator, so start it.
            self.subiterator = self.infer(next(self.iterator), **self.params)
        # We're basically flattening lists of lists into a single list, but with generators: keep feeding
        # items of `loader` to `infer` until one of the resulting sub-iterators yields something.
        while True:
            try:
                # Try to return next item
                return next(self.subiterator)
            except StopIteration:
                # The current sub-iterator is exhausted; start one on the next item of `loader`. We loop
                # instead of calling `next` once because that new sub-iterator may itself be empty, and a
                # bare `next` would then end the whole iteration early. Once `self.iterator` is exhausted,
                # its StopIteration propagates and ends iteration normally.
                self.subiterator = self.infer(next(self.iterator), **self.params)
class PipelinePackIterator(PipelineIterator):
    """
    Roughly equivalent to

    ```
    packed = []
    for item in loader:
        processed = infer(item, **params)
        is_last = processed.pop("is_last")
        packed.append(processed)
        if is_last:
            yield packed
            packed = []
    ```

    but it also handles cases where the outputs of `infer` are batched (meaning they are a dict of tensors with
    first dimension > 1). In that case each batch is unrolled item by item first:

    ```
    packed = []
    for batch in loader:
        processed = infer(batch, **params)
        for item in processed:  # unrolled one item at a time
            is_last = item.pop("is_last")
            packed.append(item)
            if is_last:
                yield packed
                packed = []
    ```

    Arguments:
        loader (`torch.utils.data.DataLoader` or any iterator):
            The iterator that will be used to apply `infer` on.
        infer (any function):
            The function to apply to each element of `loader`. Its outputs must carry an `is_last` entry, which
            is popped before the output is packed.
        params (`dict`):
            The parameters passed to `infer` along with every item.
        loader_batch_size (`int`, *optional*):
            If specified, the items of `loader` are supposed to come as batches and the outputs of `infer` are
            unrolled here, up to `loader_batch_size` items per batch (fewer for a shorter batch; empty batches
            are skipped).
    """

    def __iter__(self):
        self.iterator = iter(self.loader)
        # Forget any batch left half-unrolled by a previous (possibly interrupted) pass.
        self._loader_batch_index = None
        self._loader_batch_data = None
        # Number of items to unroll from the batch in `_loader_batch_data`. Kept apart from
        # `loader_batch_size` so one short batch does not shrink every later batch.
        self._current_loader_batch_size = None
        return self

    def _unroll_until_last(self, accumulator):
        """
        Move items of the current loader batch into `accumulator`, stopping right after the first item whose
        popped `is_last` flag is truthy.

        Returns:
            `bool`: `True` if an `is_last` item was reached, `False` if the batch ran out first.
        """
        while self._loader_batch_index < self._current_loader_batch_size:
            item = self.loader_batch_item()
            is_last = item.pop("is_last")
            accumulator.append(item)
            if is_last:
                return True
        return False

    def __next__(self):
        # Extremely similar to PipelineIterator in its unpacking mechanism
        # BUT, we have an extra required item which is the presence of `is_last`
        # That is because everything is flattened by `PipelineChunkIterator` we
        # need to keep track of how to regroup here in the original `process`
        # boundaries so that `process` and `postprocess` see the same data.
        # This iterator accumulates items (possibly while unbatching) until it
        # hits a `is_last` and then just passes it on to the caller.
        accumulator = []

        # First finish the batch left over from the previous call, if any.
        if self._loader_batch_index is not None and self._unroll_until_last(accumulator):
            return accumulator

        while True:
            # NOTE(review): if `loader` runs out before an `is_last` item, the StopIteration below propagates and
            # the partially accumulated group is discarded.
            processed = self.infer(next(self.iterator), **self.params)

            if self.loader_batch_size is None:
                # Not unrolling: every `infer` output is a single item.
                is_last = processed.pop("is_last")
                accumulator.append(processed)
                if is_last:
                    return accumulator
                continue

            # Try to infer the size of the batch
            if isinstance(processed, torch.Tensor):
                first_tensor = processed
            else:
                key = list(processed.keys())[0]
                first_tensor = processed[key]
            if isinstance(first_tensor, list):
                observed_batch_size = len(first_tensor)
            else:
                observed_batch_size = first_tensor.shape[0]
            if observed_batch_size == 0:
                # Nothing to unroll in an empty batch; move on to the next one.
                continue

            # The batch may be shorter than requested (e.g. the last one); never unroll past either bound.
            self._current_loader_batch_size = min(observed_batch_size, self.loader_batch_size)
            self._loader_batch_data = processed
            self._loader_batch_index = 0
            if self._unroll_until_last(accumulator):
                return accumulator
class KeyDataset(Dataset):
    """
    Dataset view that returns only the `key` field of each row of `dataset`.
    """

    def __init__(self, dataset: Dataset, key: str):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        # One entry per row of the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, i):
        row = self.dataset[i]
        return row[self.key]
class KeyPairDataset(Dataset):
    """
    Dataset view that turns each row of `dataset` into a `{"text": ..., "text_pair": ...}` dict built from the
    fields `key1` and `key2`.
    """

    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        # One entry per row of the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, i):
        # Fetch the row once: indexing a lazy dataset may decode/load the whole row on every access.
        row = self.dataset[i]
        return {"text": row[self.key1], "text_pair": row[self.key2]}
|