File size: 2,165 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import shutil
import tempfile
from time import sleep, time
import pytest
from ding.data.model_loader import FileModelLoader
from ding.data.storage.file import FileModelStorage
from ding.model import DQN
from ding.config import compile_config
from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
from os import path
import torch
@pytest.mark.tmp  # gitlab ci and local test pass, github always fail
def test_model_loader():
    """Round-trip a DQN model through ``FileModelLoader`` and check TTL cleanup.

    Saves the model via ``loader.save`` and waits for the callback to deliver a
    ``FileModelStorage``, then loads the stored state dict back into the model.
    Finally waits for the loader (created with ``ttl=1``) to delete the saved
    file and drop it from its ``_files`` bookkeeping.
    """
    tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    model = DQN(**cfg.policy.model)
    loader = FileModelLoader(model=model, dirname=tempdir, ttl=1)
    try:
        loader.start()
        model_storage = None

        def save_model(storage):
            # Callback handed to loader.save(); captures the storage it receives.
            nonlocal model_storage
            model_storage = storage

        start = time()
        loader.save(save_model)
        save_time = time() - start
        print("Save time: {:.4f}s".format(save_time))
        # save() itself must return quickly; the callback is waited for below,
        # so the actual write presumably happens in the background.
        assert save_time < 0.1

        # Poll (bounded) for the callback instead of one fixed sleep, so a slow
        # CI runner gets more time and a fast one does not wait needlessly.
        deadline = time() + 5
        while model_storage is None and time() < deadline:
            sleep(0.01)
        assert isinstance(model_storage, FileModelStorage)
        assert len(loader._files) > 0

        state_dict = loader.load(model_storage)
        model.load_state_dict(state_dict)

        # With ttl=1 the loader is expected to delete the file and forget it;
        # poll (bounded) for both conditions rather than sleeping a fixed 2s.
        deadline = time() + 5
        while (path.exists(model_storage.path) or len(loader._files) > 0) and time() < deadline:
            sleep(0.05)
        assert not path.exists(model_storage.path)
        assert len(loader._files) == 0
    finally:
        # Stop the loader before deleting its directory so it cannot still be
        # writing there during cleanup; remove the directory even if shutdown fails.
        try:
            loader.shutdown()
        finally:
            if path.exists(tempdir):
                shutil.rmtree(tempdir)
@pytest.mark.benchmark
def test_model_loader_benchmark():
    """Benchmark five spaced-out ``FileModelLoader.save`` calls.

    Issues a save every 0.2s (1.0s of fixed sleeping in total) and requires
    all five save callbacks to have fired within 1.2s of the first save.
    """
    # 1024*1024+1024 + 1024*100+100 = 1,152,100 parameters, ~4.6 MB as float32.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 100))
    tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
    loader = FileModelLoader(model=model, dirname=tempdir)
    budget = 1.2  # seconds allowed from the first save until the fifth callback
    try:
        loader.start()
        count = 0

        def send_callback(_):
            # Counts completed saves; the storage argument is not needed here.
            nonlocal count
            count += 1

        start = time()
        for _ in range(5):
            loader.save(send_callback)
            sleep(0.2)
        # Bound the wait by the same budget so a lost callback fails the test
        # instead of hanging it forever.
        while count < 5 and time() - start < budget:
            sleep(0.001)
        elapsed = time() - start
        assert count == 5, "only {} of 5 save callbacks fired".format(count)
        assert elapsed < budget
    finally:
        # Stop the loader before deleting its directory so it cannot still be
        # writing there during cleanup; remove the directory even if shutdown fails.
        try:
            loader.shutdown()
        finally:
            if path.exists(tempdir):
                shutil.rmtree(tempdir)
|