File size: 3,333 Bytes
bbcc985 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import json
import struct
import torch
import threading
import warnings
###
# Code from ljleb/sd-mecha/sd_mecha/streaming.py
DTYPE_MAPPING = {
'F64': (torch.float64, 8),
'F32': (torch.float32, 4),
'F16': (torch.float16, 2),
'BF16': (torch.bfloat16, 2),
'I8': (torch.int8, 1),
'I64': (torch.int64, 8),
'I32': (torch.int32, 4),
'I16': (torch.int16, 2),
"F8_E4M3": (torch.float8_e4m3fn, 1),
"F8_E5M2": (torch.float8_e5m2, 1),
}
class InSafetensorsDict:
    """Read-only, lazily loading mapping over a safetensors byte stream.

    Tensors are read on demand through a read-ahead buffer. The header is
    re-ordered by data offset, so iterating in key order mostly hits data that
    is already buffered instead of seeking back and forth.

    Args:
        f: Binary file-like object supporting ``read``, ``readinto``, ``seek``
            and ``close``. It should be positioned at offset 0: the header is
            read from the current position, while tensor data is located by
            absolute seeks. ``close()`` (also called from ``__del__``) closes it.
        buffer_size: Minimum number of bytes fetched on each buffer refill.
    """

    def __init__(self, f, buffer_size):
        self.default_buffer_size = buffer_size
        self.file = f
        self.header_size, self.header = self._read_header()
        # Empty buffer placed right after the header, so the first tensor
        # lookup always triggers a read.
        self.buffer = bytearray()
        self.buffer_start_offset = 8 + self.header_size
        # Serializes buffer refills with creation of tensor views onto the buffer.
        self.lock = threading.Lock()

    def __del__(self):
        self.close()

    def __getitem__(self, key):
        if key not in self.header or key == "__metadata__":
            raise KeyError(key)
        return self._load_tensor(key)

    def __iter__(self):
        return iter(self.keys())

    def __len__(self):
        # "__metadata__" is not a tensor and is hidden by keys()/__iter__/
        # __getitem__, so it must not be counted either.
        return len(self.header) - (1 if "__metadata__" in self.header else 0)

    def close(self):
        """Close the underlying file and drop the header and read buffer."""
        self.file.close()
        self.buffer = None
        self.header = None

    def keys(self):
        """Return a generator over tensor names (excluding "__metadata__")."""
        return (
            key
            for key in self.header.keys()
            if key != "__metadata__"
        )

    def values(self):
        for key in self.keys():
            yield self[key]

    def items(self):
        for key in self.keys():
            yield key, self[key]

    def _read_header(self):
        """Parse the safetensors header from the current stream position.

        Returns:
            ``(header_size, header)``: the JSON header's byte length (from the
            little-endian u64 prefix) and the parsed header dict, re-ordered by
            each entry's data start offset.
        """
        header_size_bytes = self.file.read(8)
        header_size = struct.unpack('<Q', header_size_bytes)[0]
        header_json = self.file.read(header_size).decode('utf-8').strip()
        header = json.loads(header_json)
        # sort by memory order to reduce seek time; entries without
        # "data_offsets" (e.g. "__metadata__") sort as offset 0
        sorted_header = dict(sorted(header.items(), key=lambda item: item[1].get('data_offsets', [0])[0]))
        return header_size, sorted_header

    def _ensure_buffer(self, start_pos, length):
        """Make absolute file bytes [start_pos, start_pos + length) available in self.buffer.

        On a miss, reads at least ``default_buffer_size`` bytes starting at
        ``start_pos`` into a new buffer and sets ``buffer_start_offset`` to it.

        Raises:
            EOFError: if the stream ends before ``length`` bytes were read.
        """
        buffer_end = self.buffer_start_offset + len(self.buffer)
        if self.buffer_start_offset <= start_pos and start_pos + length <= buffer_end:
            return
        self.file.seek(start_pos)
        # Always use a fresh bytearray: tensors returned earlier by
        # torch.frombuffer share memory with the current one, so it must never
        # be overwritten in place.
        buffer = bytearray(max(self.default_buffer_size, length))
        bytes_read = self._read_into(buffer)
        if bytes_read < length:
            raise EOFError(
                f"expected {length} bytes at offset {start_pos}, "
                f"but the stream ended after {bytes_read}"
            )
        if bytes_read < len(buffer):
            # Short read near EOF: keep only real data so the hit test above
            # never treats the zero-filled tail as file contents.
            buffer = buffer[:bytes_read]
        self.buffer = buffer
        self.buffer_start_offset = start_pos

    def _read_into(self, buffer):
        """Fill ``buffer`` from the current file position; return the number of bytes read.

        Loops because ``readinto`` may return fewer bytes than requested
        (e.g. on raw, unbuffered streams); stops early only at EOF.
        """
        filled = 0
        with memoryview(buffer) as view:
            while filled < len(view):
                count = self.file.readinto(view[filled:])
                if not count:
                    break
                filled += count
        return filled

    def _load_tensor(self, tensor_name):
        """Return tensor ``tensor_name`` as a view onto the read buffer.

        The result shares memory with an internal bytearray (via
        torch.frombuffer); that bytearray is replaced, never overwritten in
        place, when the buffer is refilled.
        """
        tensor_info = self.header[tensor_name]
        offsets = tensor_info['data_offsets']
        dtype, dtype_bytes = DTYPE_MAPPING[tensor_info['dtype']]
        shape = tensor_info['shape']
        total_bytes = offsets[1] - offsets[0]
        if total_bytes == 0:
            # Zero-element tensors carry no data, and torch.frombuffer rejects count=0.
            return torch.empty(shape, dtype=dtype)
        absolute_start_pos = 8 + self.header_size + offsets[0]
        # NOTE(review): silences warnings raised while building the view,
        # presumably from torch.frombuffer — confirm which warning is intended.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            with self.lock:
                self._ensure_buffer(absolute_start_pos, total_bytes)
                buffer_offset = absolute_start_pos - self.buffer_start_offset
                return torch.frombuffer(self.buffer, count=total_bytes // dtype_bytes, offset=buffer_offset, dtype=dtype).reshape(shape)
|