|
import os |
|
import subprocess |
|
from typing import Optional, Dict, Tuple |
|
|
|
MANAGER_NODE_TABLE = { |
|
'10.198.8': '10.198.8.31', |
|
'10.198.6': '10.198.6.31', |
|
'10.5.38': '10.5.38.31', |
|
'10.5.39': '10.5.38.31', |
|
'10.5.36': '10.5.36.31', |
|
'10.5.37': '10.5.36.31', |
|
'10.10.30': '10.10.30.91', |
|
} |
|
|
|
|
|
def get_ip() -> str: |
|
""" |
|
Overview: |
|
Get the ip of the current node |
|
""" |
|
|
|
assert os.environ.get('SLURMD_NODENAME'), 'not found SLURMD_NODENAME env variable' |
|
|
|
nodename = os.environ.get('SLURMD_NODENAME', '') |
|
myaddr = '.'.join(nodename.split('-')[-4:]) |
|
return myaddr |
|
|
|
|
|
def get_manager_node_ip(node_ip: Optional[str] = None) -> str: |
|
""" |
|
Overview: |
|
Look up the manager node of the slurm cluster and return the node ip |
|
Arguments: |
|
- node_ip (:obj:`Optional[str]`): The ip of the current node |
|
""" |
|
if 'SLURM_JOB_ID' not in os.environ: |
|
from ditk import logging |
|
logging.error( |
|
'We are not running on slurm!, \'auto\' for manager_ip or ' |
|
'coordinator_ip is only intended for running on multiple slurm clusters' |
|
) |
|
return '127.0.0.1' |
|
node_ip = node_ip or get_ip() |
|
learner_manager_ip_prefix = '.'.join(node_ip.split('.')[0:3]) |
|
|
|
if learner_manager_ip_prefix in MANAGER_NODE_TABLE: |
|
return MANAGER_NODE_TABLE[learner_manager_ip_prefix] |
|
else: |
|
raise KeyError("Cluster not found, please add it to the MANAGER_NODE_TABLE in {}".format(__file__)) |
|
|
|
|
|
|
|
def get_cls_info() -> Dict[str, list]: |
|
""" |
|
Overview: |
|
Get the cluster info |
|
""" |
|
|
|
ret_dict = {} |
|
info = subprocess.getoutput('sinfo -Nh').split('\n') |
|
for line in info: |
|
line = line.strip().split() |
|
if len(line) != 4: |
|
continue |
|
node, _, partition, state = line |
|
if partition not in ret_dict: |
|
ret_dict[partition] = [] |
|
assert node not in ret_dict[partition] |
|
if state in ['idle', 'mix']: |
|
ret_dict[partition].append(node) |
|
|
|
return ret_dict |
|
|
|
|
|
def node_to_partition(target_node: str) -> Tuple[str, str]: |
|
""" |
|
Overview: |
|
Get the partition of the target node |
|
Arguments: |
|
- target_node (:obj:`str`): The target node |
|
""" |
|
|
|
info = subprocess.getoutput('sinfo -Nh').split('\n') |
|
for line in info: |
|
line = line.strip().split() |
|
if len(line) != 4: |
|
continue |
|
node, _, partition, state = line |
|
if node == target_node: |
|
return partition |
|
raise RuntimeError("not found target_node: {}".format(target_node)) |
|
|
|
|
|
def node_to_host(node: str) -> str: |
|
""" |
|
Overview: |
|
Get the host of the node |
|
Arguments: |
|
- node (:obj:`str`): The node |
|
""" |
|
|
|
return '.'.join(node.split('-')[-4:]) |
|
|
|
|
|
def find_free_port_slurm(node: str) -> int: |
|
""" |
|
Overview: |
|
Find a free port on the node |
|
Arguments: |
|
- node (:obj:`str`): The node |
|
""" |
|
|
|
partition = node_to_partition(node) |
|
if partition == 'spring_scheduler': |
|
comment = '--comment=spring-submit' |
|
else: |
|
comment = '' |
|
output = subprocess.getoutput( |
|
"srun -p {} -w {} {} python -c \"from ding.utils import find_free_port; print('port' + str(find_free_port(0)))\"" |
|
.format(partition, node, comment) |
|
) |
|
port = output.split('port')[-1] |
|
return int(port) |
|
|