zjowowen's picture
init space
079c32c
raw
history blame
3.02 kB
import pytest
import os
from ding.entry.cli_parsers import PLATFORM_PARSERS
from ding.entry.cli_parsers.slurm_parser import SlurmParser
# Parser callable registered under the "slurm" key of PLATFORM_PARSERS, resolved once at
# import time; the test calls it as slurm_parser(platform_spec, **kwargs).
slurm_parser = PLATFORM_PARSERS["slurm"]
@pytest.fixture
def set_slurm_env(monkeypatch):
    """Install a fake two-node SLURM job environment for the duration of one test.

    The fake job has 6 tasks spread 3 per node over nodes
    ``SH-IDC1-10-5-38-190`` and ``SH-IDC1-10-5-38-215``. The current process
    is proc 3, running on the second node.

    Variables are set through ``monkeypatch.setenv``. It records each
    variable's previous value, or the fact that it was unset, and restores
    that state at teardown. This avoids two problems with unconditional
    deletion: running the suite inside a real SLURM allocation no longer
    wipes the job's own variables, and teardown cannot fail with
    ``KeyError`` when a test removed or overwrote one of them.
    """
    fake_env = {
        "SLURM_NTASKS": '6',  # Parameter n: process count / task count
        "SLURM_NTASKS_PER_NODE": '3',  # Parameter ntasks-per-node: process count of each node
        "SLURM_NODELIST": 'SH-IDC1-10-5-38-[190,215]',  # All the nodes
        # Proc order, starting from 0; the real proc order may differ from the nominal order.
        "SLURM_PROCID": '3',
        "SLURMD_NODENAME": 'SH-IDC1-10-5-38-215',  # Name of the current node
    }
    for name, value in fake_env.items():
        monkeypatch.setenv(name, value)
@pytest.mark.unittest
@pytest.mark.usefixtures('set_slurm_env')
def test_slurm_parser():
    """Exercise the SLURM parser against the environment from ``set_slurm_env``.

    Three scenarios are checked in order:

    1. Parsing with an explicit platform spec.
    2. Parsing with no spec, using a mesh topology.
    3. Expanding a bracketed ``SLURM_NODELIST`` range through
       ``SlurmParser._parse_node_list``.
    """
    # In both parser scenarios the expected attach_to is the three ports of node 190.
    first_node_attach = ",".join(f"tcp://SH-IDC1-10-5-38-190:{port}" for port in (15151, 15152, 15153))

    learner_attach = "$node.10,$node.11,$node.20"
    tasks = [
        {"labels": "league,collect", "node_ids": 10},
        {"labels": "league,collect", "node_ids": 11},
        {"labels": "evaluate", "node_ids": 20, "attach_to": "$node.10,$node.11"},
    ]
    tasks.extend({"labels": "learn", "node_ids": node_id, "attach_to": learner_attach} for node_id in (31, 32, 33))
    platform_spec = {"tasks": tasks}

    # Scenario 1: with a platform spec.
    spec_args = slurm_parser(platform_spec)
    assert spec_args["labels"] == "learn"
    assert spec_args["address"] == "SH-IDC1-10-5-38-215"
    assert spec_args["ports"] == 15151  # Start from 15151
    assert spec_args["node_ids"] == 31
    assert spec_args["attach_to"] == first_node_attach

    # Scenario 2: without a platform spec.
    mesh_args = slurm_parser(None, topology="mesh", mq_type="nng")
    assert mesh_args["address"] == "SH-IDC1-10-5-38-215"
    assert mesh_args["node_ids"] == 3
    assert mesh_args["parallel_workers"] == 1
    assert mesh_args["attach_to"] == first_node_attach

    # Scenario 3: _parse_node_list. The expected list has each node once per
    # task on that node (Nodes * parallel_workers).
    parser = SlurmParser(platform_spec)
    os.environ["SLURM_NODELIST"] = 'SH-IDC1-10-5-[38-40]'
    expected_nodes = [host for host in ('SH-IDC1-10-5-38', 'SH-IDC1-10-5-39', 'SH-IDC1-10-5-40') for _ in range(3)]
    assert parser._parse_node_list() == expected_nodes