|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import codecs |
|
import collections |
|
import contextlib |
|
import operator |
|
import os |
|
import pickle |
|
import zipfile |
|
from functools import reduce |
|
from typing import Any, Optional, Tuple, Union |
|
|
|
import accelerate |
|
import numpy |
|
import torch |
|
from pydantic import BaseModel, PrivateAttr |
|
|
|
# Whitelist of (module, name) globals the restricted unpickler may resolve.
# Anything not listed here is rejected, which protects checkpoint loading
# from arbitrary-code-execution pickle payloads.
ACCEPTABLE_TYPES = {
    ("torch._utils", "_rebuild_tensor_v2"): torch._utils._rebuild_tensor_v2,
    ("collections", "OrderedDict"): collections.OrderedDict,
    ("numpy.core.multiarray", "scalar"): numpy.core.multiarray.scalar,
    # Fixed: this entry previously mapped to `numpy.core.multiarray.scalar`,
    # so pickled numpy dtype objects were reconstructed via the wrong callable.
    ("numpy", "dtype"): numpy.dtype,
    ("_codecs", "encode"): codecs.encode,
    # Legacy typed-storage classes that torch checkpoints reference by name.
    **{
        ("torch", name): getattr(torch, name)
        for name in [
            "DoubleStorage",
            "FloatStorage",
            "HalfStorage",
            "LongStorage",
            "IntStorage",
            "ShortStorage",
            "CharStorage",
            "ByteStorage",
            "BoolStorage",
            "BFloat16Storage",
        ]
    },
}
|
|
|
|
|
class DeferredLoad(BaseModel, arbitrary_types_allowed=True):
    """
    Deferred description of a tensor stored inside a torch ZIP archive.

    Instances are produced in place of real tensors while unpickling under
    `torch_lazy_load`; calling `execute` performs the actual file read and
    materializes the tensor.
    """

    # Name of the storage entry inside the archive's data/ directory.
    name: str
    # Storage location tag recorded by torch (e.g. "cpu") -- passed to the
    # restore-location callback in `execute`.
    location: str
    dtype: torch.dtype

    # Filled in by `rebuild` once tensor metadata is known.
    file_offset: Optional[int] = None
    shape: Optional[Union[torch.Size, Tuple[int, ...]]] = None
    stride: Optional[Tuple[int, ...]] = None

    requires_grad: bool = False
    _backward_hooks: Any = PrivateAttr(None)

    @staticmethod
    def rebuild(
        load: "DeferredLoad",
        offset: int,
        shape: Union[torch.Size, Tuple[int, ...]],
        stride: Tuple[int, ...],
    ) -> "DeferredLoad":
        """Stand-in for `torch._utils._rebuild_tensor`; records metadata.

        `offset` is given in elements, so it is converted to a byte offset
        using the element size of the load's dtype.
        """
        load.shape = shape
        load.stride = stride
        load.file_offset = offset * dtype_bytes(load.dtype)
        return load

    def execute(
        self,
        reader: "TorchArchiveReader",
        map_location: Any = None,
    ) -> torch.Tensor:
        """Read the tensor's bytes from `reader` and materialize the tensor.

        Must only be called after `rebuild` has populated `shape`, `stride`,
        and `file_offset`.
        """
        # Initializer of 1 handles zero-dim (scalar) tensors whose shape is ();
        # without it, reduce() raises TypeError on an empty sequence.
        total_params = reduce(operator.mul, self.shape, 1)
        total_bytes = total_params * dtype_bytes(self.dtype)

        f = reader.open_file(file_name=self.name, offset=self.file_offset)
        storage = torch.UntypedStorage.from_buffer(
            f.read(total_bytes), "little", dtype=self.dtype
        )
        # Relocate the raw storage per `map_location` / the recorded location.
        storage = torch.serialization._get_restore_location(map_location)(
            storage, self.location
        )

        tensor = torch.tensor([], dtype=self.dtype, device=storage.device)
        tensor.set_(storage, 0, self.shape, self.stride)
        tensor.requires_grad = self.requires_grad
        tensor._backward_hooks = self._backward_hooks
        return tensor
|
|
|
|
|
class LazyTorchUnpickler(pickle.Unpickler):
    """Restricted unpickler that resolves only whitelisted globals and yields
    `DeferredLoad` placeholders for persistent storage references."""

    def find_class(self, module: str, name: str) -> Any:
        """Resolve `(module, name)` against the whitelist or refuse to load."""
        resolved = ACCEPTABLE_TYPES.get((module, name))
        if resolved is None:
            raise pickle.UnpicklingError(f"Unsupported type {module}.{name}")
        return resolved

    def persistent_load(self, pid: Any) -> Any:
        """Translate a torch storage PID into a `DeferredLoad` placeholder."""
        if not (isinstance(pid, tuple) and pid[0] == "storage"):
            raise RuntimeError(f"Unpickling object with unexpected PID: {repr(pid)}")

        storage_type, key, location, _ = pid[1:]
        return DeferredLoad(name=key, location=location, dtype=get_dtype(storage_type))
|
|
|
|
|
class TorchArchiveReader:
    """
    Lazily reads (sections of) files from a torch ZIP archive.

    The most recently opened entry handle is kept around so that consecutive
    reads from the same file can simply skip forward instead of reopening.
    """

    archive: zipfile.ZipFile
    archive_name: str
    file_name: Optional[str] = None
    file: Optional[zipfile.ZipExtFile] = None

    def __init__(self, path: str):
        self.archive = zipfile.ZipFile(path, mode="r")
        # e.g. "/tmp/model.pt" -> "model"; used as the fallback entry prefix.
        base_name = os.path.basename(os.path.normpath(path))
        self.archive_name = base_name.split(".")[0]

    def _open_entry(self, file_name: str) -> zipfile.ZipExtFile:
        """Open a data entry, preferring the "archive/" prefix and falling
        back to the archive's own base name as the prefix."""
        try:
            return self.archive.open(f"archive/data/{file_name}", mode="r")
        except Exception:
            return self.archive.open(
                f"{self.archive_name}/data/{file_name}", mode="r"
            )

    def open_file(self, file_name: str, offset: int = 0) -> zipfile.ZipExtFile:
        """Return a handle to `file_name` positioned exactly at `offset` bytes.

        Reuses the cached handle when it points at the same entry and has not
        yet read past `offset`; otherwise the entry is reopened from scratch.
        """
        must_reopen = self.file_name != file_name or (
            self.file is not None and self.file.tell() > offset
        )
        if must_reopen:
            if self.file is not None:
                self.file.close()
            self.file = self._open_entry(file_name)
            self.file_name = file_name

        skip = offset - self.file.tell()
        assert skip >= 0
        self.file.seek(skip, os.SEEK_CUR)
        return self.file
|
|
|
|
|
@contextlib.contextmanager
def torch_lazy_load():
    """
    Context manager under which `torch.load` will return a `DeferredLoad`
    instead of `torch.Tensor.`
    """
    saved = (pickle.Unpickler, pickle.load, torch._utils._rebuild_tensor)
    try:

        def _patched_load(*args, **kwargs):
            # Look up pickle.Unpickler at call time so the patched class
            # (LazyTorchUnpickler) is the one actually used.
            return pickle.Unpickler(*args, **kwargs).load()

        pickle.Unpickler = LazyTorchUnpickler
        pickle.load = _patched_load
        torch._utils._rebuild_tensor = DeferredLoad.rebuild

        # Empty weights so module construction allocates no real storage.
        with accelerate.init_empty_weights():
            yield

    finally:
        # Restore everything even if the body raised.
        pickle.Unpickler, pickle.load, torch._utils._rebuild_tensor = saved
|
|
|
|
|
def dtype_bytes(dtype: torch.dtype) -> int:
    """Return the number of bytes used to store a single instance of `dtype`.

    Determined via `Tensor.element_size()` on an empty tensor, which works for
    every dtype. The previous finfo/iinfo approach raised TypeError for
    `torch.bool` (even though BoolStorage is whitelisted in ACCEPTABLE_TYPES)
    and under-counted complex dtypes, whose finfo bits describe only one
    real component.
    """
    return torch.empty(0, dtype=dtype).element_size()
|
|
|
|
|
def get_dtype(storage_type: Any):
    """Resolve the `torch.dtype` for `storage_type`.

    `storage_type` may already be a dtype, or a torch storage class whose
    element dtype must be looked up.
    """
    if isinstance(storage_type, torch.dtype):
        return storage_type
    resolved = storage_type.dtype
    if isinstance(resolved, torch.dtype):
        return resolved
    # Some storage classes expose `.dtype` as a descriptor rather than an
    # actual dtype; instantiate an empty storage to read the real one.
    return storage_type(0).dtype
|
|