File size: 9,567 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
# Copyright (c) Facebook, Inc. and its affiliates.
import io
import numpy as np
import os
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import BinaryIO, Dict, Optional, Tuple
import torch
from detectron2.utils.comm import gather, get_rank
from detectron2.utils.file_io import PathManager
@dataclass
class SizeData:
    """
    Size description of a single tensor field: its element type and its shape.
    """

    # element data type given as a string that `numpy.dtype` can interpret,
    # e.g. "float32"
    dtype: str
    # tensor dimensions, e.g. (112, 112) or (16, 112, 112); the rank is not fixed,
    # hence the variadic tuple annotation
    shape: Tuple[int, ...]
def _calculate_record_field_size_b(data_schema: Dict[str, SizeData], field_name: str) -> int:
    """
    Compute the number of bytes occupied by one field of a record.

    Args:
        data_schema (dict: str -> SizeData): maps field name to its size data
            (shape and data type)
        field_name (str): name of the field to measure, must be a key of `data_schema`
    Return:
        int: number of elements of the field multiplied by the element size in bytes
    """
    schema = data_schema[field_name]
    element_size_b = np.dtype(schema.dtype).itemsize
    # the initial value 1 makes an empty shape `()` (a scalar) count as a single
    # element instead of raising TypeError on an empty sequence
    num_elements = reduce(mul, schema.shape, 1)
    return num_elements * element_size_b
def _calculate_record_size_b(data_schema: Dict[str, SizeData]) -> int:
    """
    Compute the size in bytes of one complete record, i.e. the sum of the
    byte sizes of all the fields listed in the schema.

    Args:
        data_schema (dict: str -> SizeData): maps field name to its size data
    Return:
        int: total record size in bytes (0 for an empty schema)
    """
    return sum(_calculate_record_field_size_b(data_schema, name) for name in data_schema)
def _calculate_record_field_sizes_b(data_schema: Dict[str, SizeData]) -> Dict[str, int]:
    """
    Compute the size in bytes of every field of a record.

    Args:
        data_schema (dict: str -> SizeData): maps field name to its size data
    Return:
        dict: str -> int: field name mapped to the byte size of that field,
            in the iteration order of `data_schema`
    """
    return {name: _calculate_record_field_size_b(data_schema, name) for name in data_schema}
class SingleProcessTensorStorage:
    """
    Compact tensor storage to keep tensor data of predefined size and type.

    A record is the concatenation of the raw bytes of all schema fields, taken
    in sorted field name order. All records have the same size, so `get` locates
    a record at byte offset `record_id * record_size_b`.
    """

    def __init__(self, data_schema: Dict[str, SizeData], storage_impl: BinaryIO):
        """
        Construct tensor storage based on information on data shape and size.
        Internally uses numpy to interpret the type specification.
        The storage must support operations `seek(offset, whence=os.SEEK_SET)` and
        `read(size)` to be able to perform the `get` operation.
        The storage must support operation `write(bytes)` to be able to perform
        the `put` operation.

        Args:
            data_schema (dict: str -> SizeData): dictionary which maps tensor name
                to its size data (shape and data type), e.g.
                ```
                {
                  "coarse_segm": SizeData(dtype="float32", shape=(112, 112)),
                  "embedding": SizeData(dtype="float32", shape=(16, 112, 112)),
                }
                ```
            storage_impl (BinaryIO): io instance that handles file-like seek, read
                and write operations, e.g. a file handle or a memory buffer like io.BytesIO
        """
        self.data_schema = data_schema
        self.record_size_b = _calculate_record_size_b(data_schema)
        self.record_field_sizes_b = _calculate_record_field_sizes_b(data_schema)
        self.storage_impl = storage_impl
        # ID that the next `put` call will return; IDs are handed out sequentially
        self.next_record_id = 0

    def get(self, record_id: int) -> Dict[str, torch.Tensor]:
        """
        Load tensors from the storage by record ID

        Note: this call moves the position of the underlying storage.

        Args:
            record_id (int): Record ID, for which to load the data
        Return:
            dict: str -> tensor: tensor name mapped to tensor data, recorded under the provided ID
        Raises:
            AssertionError: if fewer bytes than one full record could be read
        """
        self.storage_impl.seek(record_id * self.record_size_b, os.SEEK_SET)
        data_bytes = self.storage_impl.read(self.record_size_b)
        assert len(data_bytes) == self.record_size_b, (
            f"Expected data size {self.record_size_b} B could not be read: "
            f"got {len(data_bytes)} B"
        )
        record = {}
        cur_idx = 0
        # it's important to read and write in the same order
        for field_name in sorted(self.data_schema):
            schema = self.data_schema[field_name]
            field_size_b = self.record_field_sizes_b[field_name]
            chunk = data_bytes[cur_idx : cur_idx + field_size_b]
            # initial value 1: an empty shape `()` denotes a scalar, i.e. one element
            num_elements = reduce(mul, schema.shape, 1)
            data_np = np.frombuffer(chunk, dtype=schema.dtype, count=num_elements).reshape(
                schema.shape
            )
            # the tensor shares memory with the read-only numpy view over `chunk`
            record[field_name] = torch.from_numpy(data_np)
            cur_idx += field_size_b
        return record

    def put(self, data: Dict[str, torch.Tensor]) -> int:
        """
        Store tensors in the storage

        The record is written at the current position of the underlying storage.
        All fields are validated and serialized before anything is written, so
        a failed check leaves the storage and the record counter untouched.

        Args:
            data (dict: str -> tensor): data to store, a dictionary which maps
                tensor names into tensors; tensor shapes must match those specified
                in data schema.
        Return:
            int: record ID, under which the data is stored
        Raises:
            AssertionError: if a schema field is missing from `data`, or if the
                shape or the byte size of a tensor does not match the schema
        """
        chunks = []
        # it's important to read and write in the same order
        for field_name in sorted(self.data_schema):
            assert (
                field_name in data
            ), f"Field '{field_name}' not present in data: data keys are {data.keys()}"
            value = data[field_name]
            assert value.shape == self.data_schema[field_name].shape, (
                f"Mismatched tensor shapes for field '{field_name}': "
                f"expected {self.data_schema[field_name].shape}, got {value.shape}"
            )
            # detach, since `numpy()` refuses tensors that require grad
            data_bytes = value.detach().cpu().numpy().tobytes()
            assert len(data_bytes) == self.record_field_sizes_b[field_name], (
                f"Expected field {field_name} to be of size "
                f"{self.record_field_sizes_b[field_name]} B, got {len(data_bytes)} B"
            )
            chunks.append(data_bytes)
        # a single write after all checks passed: a partially written record
        # would shift the offsets of all the records stored after it
        self.storage_impl.write(b"".join(chunks))
        record_id = self.next_record_id
        self.next_record_id += 1
        return record_id
class SingleProcessFileTensorStorage(SingleProcessTensorStorage):
    """
    Single process tensor storage which keeps its data in a binary file
    """

    def __init__(self, data_schema: Dict[str, SizeData], fpath: str, mode: str):
        """
        Open the file and use it as the storage backend.

        Args:
            data_schema (dict: str -> SizeData): maps tensor name to its size data
            fpath (str): path to the storage file, also kept as `self.fpath`
            mode (str): file mode, must contain "b" and one of "w" or "r"
        Raises:
            AssertionError: if the mode is not binary
            ValueError: if the mode contains neither "w" nor "r"
        """
        self.fpath = fpath
        assert "b" in mode, f"Tensor storage should be opened in binary mode, got '{mode}'"
        is_write_mode = "w" in mode
        if not is_write_mode and "r" not in mode:
            raise ValueError(f"Unsupported file mode {mode}, supported modes: rb, wb")
        if is_write_mode:
            # pyre-fixme[6]: For 2nd argument expected `Union[typing_extensions.Liter...
            storage_file = PathManager.open(fpath, mode)
        else:
            # reads use the builtin `open` on the path that
            # `PathManager.get_local_path` returns
            storage_file = open(PathManager.get_local_path(fpath), mode)
        super().__init__(data_schema, storage_file)  # pyre-ignore[6]
class SingleProcessRamTensorStorage(SingleProcessTensorStorage):
    """
    Single process tensor storage which keeps its data in an in-memory buffer
    """

    def __init__(self, data_schema: Dict[str, SizeData], buf: io.BytesIO):
        """
        Args:
            data_schema (dict: str -> SizeData): maps tensor name to its size data
            buf (io.BytesIO): memory buffer used as the storage backend
        """
        super().__init__(data_schema=data_schema, storage_impl=buf)
class MultiProcessTensorStorage:
    """
    Representation of a set of tensor storages created by individual processes,
    allows to access those storages from a single owner process. The storages
    should either be shared or broadcasted to the owner process.
    The processes are identified by their rank, data is uniquely defined by
    the rank of the process and the record ID.
    """

    def __init__(self, rank_to_storage: Dict[int, SingleProcessTensorStorage]):
        """
        Args:
            rank_to_storage (dict: int -> SingleProcessTensorStorage): maps
                process rank to the storage created by that process
        """
        self.rank_to_storage = rank_to_storage

    def get(self, rank: int, record_id: int) -> Dict[str, torch.Tensor]:
        """
        Load the record `record_id` from the storage registered under `rank`.

        Raises:
            KeyError: if no storage is registered for `rank`
        """
        return self.rank_to_storage[rank].get(record_id)

    def put(self, rank: int, data: Dict[str, torch.Tensor]) -> int:
        """
        Store `data` in the storage registered under `rank`.

        Return:
            int: record ID reported by that storage
        Raises:
            KeyError: if no storage is registered for `rank`
        """
        return self.rank_to_storage[rank].put(data)
class MultiProcessFileTensorStorage(MultiProcessTensorStorage):
    """
    Multi process tensor storage in which the data of every rank is kept in a file
    """

    def __init__(self, data_schema: Dict[str, SizeData], rank_to_fpath: Dict[int, str], mode: str):
        """
        Args:
            data_schema (dict: str -> SizeData): schema shared by all the storages
            rank_to_fpath (dict: int -> str): maps process rank to its storage file path
            mode (str): file mode used to open every storage file
        """
        storages = {}
        for rank, fpath in rank_to_fpath.items():
            storages[rank] = SingleProcessFileTensorStorage(data_schema, fpath, mode)
        super().__init__(storages)  # pyre-ignore[6]
class MultiProcessRamTensorStorage(MultiProcessTensorStorage):
    """
    Multi process tensor storage in which the data of every rank is kept in RAM
    """

    def __init__(self, data_schema: Dict[str, SizeData], rank_to_buffer: Dict[int, io.BytesIO]):
        """
        Args:
            data_schema (dict: str -> SizeData): schema shared by all the storages
            rank_to_buffer (dict: int -> io.BytesIO): maps process rank to the
                memory buffer which holds the records of that rank
        """
        storages = {}
        for rank, buf in rank_to_buffer.items():
            storages[rank] = SingleProcessRamTensorStorage(data_schema, buf)
        super().__init__(storages)  # pyre-ignore[6]
def _ram_storage_gather(
    storage: SingleProcessRamTensorStorage, dst_rank: int = 0
) -> Optional[MultiProcessRamTensorStorage]:
    """
    Gather the contents of in-memory storages on the destination process.

    Args:
        storage (SingleProcessRamTensorStorage): storage of the current process;
            its buffer is read from the start to the end
        dst_rank (int): rank of the process which receives the data
    Return:
        MultiProcessRamTensorStorage on the process with rank `dst_rank`, where the
        position of an entry in the gathered list is used as its rank; None on
        all the other processes
    """
    buf = storage.storage_impl
    buf.seek(0, os.SEEK_SET)
    # TODO: overhead, pickling a bytes object, can just pass bytes in a tensor directly
    # see detectron2/utils.comm.py
    gathered = gather(buf.read(), dst=dst_rank)
    if get_rank() != dst_rank:
        return None
    rank_to_buffer = {rank: io.BytesIO(payload) for rank, payload in enumerate(gathered)}
    return MultiProcessRamTensorStorage(storage.data_schema, rank_to_buffer)
def _file_storage_gather(
    storage: SingleProcessFileTensorStorage,
    dst_rank: int = 0,
    mode: str = "rb",
) -> Optional[MultiProcessFileTensorStorage]:
    """
    Gather file based storages on the destination process.

    The file handle of `storage` is closed first; only the file paths are
    exchanged between the processes.

    Args:
        storage (SingleProcessFileTensorStorage): storage of the current process
        dst_rank (int): rank of the process which receives the file paths
        mode (str): file mode used to open the gathered files
    Return:
        MultiProcessFileTensorStorage on the process with rank `dst_rank`, where the
        position of a path in the gathered list is used as its rank; None on
        all the other processes
    """
    storage.storage_impl.close()
    gathered_fpaths = gather(storage.fpath, dst=dst_rank)
    if get_rank() != dst_rank:
        return None
    rank_to_fpath = dict(enumerate(gathered_fpaths))
    return MultiProcessFileTensorStorage(storage.data_schema, rank_to_fpath, mode)
def storage_gather(
    storage: SingleProcessTensorStorage, dst_rank: int = 0
) -> Optional[MultiProcessTensorStorage]:
    """
    Gather single process storages on the destination process, dispatching on
    the storage type.

    Args:
        storage (SingleProcessTensorStorage): storage of the current process,
            either a RAM or a file based one
        dst_rank (int): rank of the process which receives the data
    Return:
        whatever the type specific gather function returns for this process
    Raises:
        Exception: if the storage is neither RAM nor file based
    """
    # checked in order, RAM storage first
    gather_impls = (
        (SingleProcessRamTensorStorage, _ram_storage_gather),
        (SingleProcessFileTensorStorage, _file_storage_gather),
    )
    for storage_type, gather_impl in gather_impls:
        if isinstance(storage, storage_type):
            return gather_impl(storage, dst_rank)
    raise Exception(f"Unsupported storage for gather operation: {storage}")
|