phi-2-merge / tests /test_io.py
Shaleen123's picture
Upload folder using huggingface_hub
a164e13 verified
import os
import tempfile
import torch
from mergekit.io import TensorWriter
class TestTensorWriter:
    """Smoke tests for the files ``TensorWriter`` leaves in its output directory.

    Each test writes into a fresh temporary directory, calls ``finalize()``,
    and then checks only that the expected shard file and index file exist.
    File contents are not inspected.
    """

    def test_safetensors(self):
        """One tensor saved with ``safe_serialization=True`` gives a shard and an index."""
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=True)
            tensor_writer.save_tensor("steve", torch.randn(4))
            tensor_writer.finalize()

            for expected_name in (
                "model-00001-of-00001.safetensors",
                "model.safetensors.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, expected_name))

    def test_pickle(self):
        """One tensor saved with ``safe_serialization=False`` gives a ``.bin`` shard and an index."""
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=False)
            tensor_writer.save_tensor("timothan", torch.randn(4))
            tensor_writer.finalize()

            for expected_name in (
                "pytorch_model-00001-of-00001.bin",
                "pytorch_model.bin.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, expected_name))

    def test_duplicate_tensor(self):
        """The same tensor object saved under two names still finalizes cleanly.

        Only the existence of the shard and index files is asserted; how the
        writer stores the aliased tensor is not checked here.
        """
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=True)

            # A single tensor object registered under two different names.
            shared = torch.randn(4)
            tensor_writer.save_tensor("jim", shared)
            tensor_writer.save_tensor("jimbo", shared)
            tensor_writer.finalize()

            for expected_name in (
                "model-00001-of-00001.safetensors",
                "model.safetensors.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, expected_name))