|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
from typing import Dict, Optional, Sequence |
|
|
|
import safetensors |
|
import torch |
|
|
|
from mergekit.io.lazy_unpickle import DeferredLoad, TorchArchiveReader, torch_lazy_load |
|
|
|
|
|
class TensorLoader(ABC):
    """Abstract interface for objects that hand out tensors from one shard file.

    Concrete loaders implement `get_tensor` and `keys`; the `get` factory
    picks a loader based on the shard's file name and the caller's flags.
    """

    @abstractmethod
    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the tensor stored under `key`."""

    @abstractmethod
    def keys(self) -> Sequence[str]:
        """Return the tensor names available from this loader."""

    @classmethod
    def get(
        cls,
        shard_path: str,
        use_lazy_unpickle: bool = False,
        device: Optional[str] = None,
    ) -> "TensorLoader":
        """Build a loader for the shard at `shard_path`.

        Args:
            shard_path: Path of the shard file. A name ending in
                ".safetensors" (compared case-insensitively) selects the
                safetensors branch.
            use_lazy_unpickle: For non-safetensors shards, choose
                `LazyPickleLoader` instead of `DumbPytorchLoader`.
            device: Device string forwarded to the chosen loader. The
                safetensors branch substitutes "cpu" when this is falsy.

        Returns:
            For safetensors shards, the object returned by
            `safetensors.safe_open` (returned directly, not wrapped in a
            `TensorLoader` subclass here); otherwise a `LazyPickleLoader`
            or `DumbPytorchLoader` instance.
        """
        is_safetensors = shard_path.lower().endswith(".safetensors")
        if is_safetensors:
            target_device = device or "cpu"
            return safetensors.safe_open(
                shard_path, framework="pt", device=target_device
            )

        if use_lazy_unpickle:
            return LazyPickleLoader(shard_path, device=device)

        return DumbPytorchLoader(shard_path, device=device)
|
|
|
|
|
class LazyPickleLoader(TensorLoader):
    """Loader for pytorch checkpoint files that defers reading tensor data.

    The checkpoint is passed through `torch.load` inside the
    `torch_lazy_load()` context, and the result is kept as an index of
    `DeferredLoad` entries. Each entry is executed against the archive
    reader only when its tensor is requested.
    """

    # Reader opened on the checkpoint path; handed to each deferred load.
    zip_reader: TorchArchiveReader
    # Result of `torch.load` under `torch_lazy_load()`, keyed by tensor name.
    index: Dict[str, DeferredLoad]
    # Passed as `map_location` when a deferred load is executed.
    device: Optional[str] = None

    def __init__(self, path: str, device: Optional[str] = None):
        """Open the archive at `path` and build the deferred-load index.

        Args:
            path: Path of the checkpoint file.
            device: Value later passed as `map_location` to each load.
        """
        self.device = device
        self.zip_reader = TorchArchiveReader(path)
        # The load must happen inside the context manager so that
        # `torch.load` produces the deferred entries stored in `index`.
        with torch_lazy_load():
            self.index = torch.load(path)

    def get_tensor(self, key: str) -> torch.Tensor:
        """Execute the deferred load registered under `key`.

        Raises:
            KeyError: If `key` is not present in the index.
        """
        if key in self.index:
            deferred = self.index[key]
            return deferred.execute(self.zip_reader, map_location=self.device)

        raise KeyError(key)

    def keys(self) -> Sequence[str]:
        """Return the keys of the deferred-load index."""
        return self.index.keys()
|
|
|
|
|
class DumbPytorchLoader(TensorLoader):
    """Naive `torch.load` shard loading.

    The whole shard is deserialized by `torch.load` in the constructor and
    kept in `self.tensors`; lookups afterwards are plain mapping accesses.
    """

    # Object returned by `torch.load` for the shard; annotated as a
    # name -> tensor mapping (presumably a state dict — verify against callers).
    tensors: Dict[str, torch.Tensor]

    def __init__(self, path: str, device: Optional[str] = None):
        """Load every tensor of the shard at `path`.

        Args:
            path: Path of the checkpoint file handed to `torch.load`.
            device: Forwarded to `torch.load` as `map_location`. May be
                `None`, which is what `TensorLoader.get` passes when the
                caller gives no device; the default mirrors
                `LazyPickleLoader.__init__`.
        """
        # `weights_only=True` is passed through to `torch.load` unchanged.
        self.tensors = torch.load(path, map_location=device, weights_only=True)

    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the loaded entry stored under `key`.

        Raises:
            KeyError: If `key` is absent (raised by the mapping lookup).
        """
        return self.tensors[key]

    def keys(self) -> Sequence[str]:
        """Return the keys of the loaded mapping."""
        return self.tensors.keys()
|
|