File size: 5,656 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import re
from time import sleep
import numpy as np
from typing import Any, Dict, List, Optional
class SlurmParser:
    """
    Overview:
        Build the launch configuration of one process in a SLURM job from the ``SLURM_*``
        environment variables, optionally overridden by a user supplied platform spec.
    Arguments:
        - platform_spec (:obj:`Optional[Dict]`): Optional spec whose ``"tasks"`` list holds one dict
          per SLURM procid. Keys present there (``ports``, ``address``, ``node_ids``, ``attach_to``)
          take precedence over the values computed from the environment.
        - kwargs: Extra options. ``ports`` (base port, default 15151), ``parallel_workers``
          (default 1) and ``topology`` (default ``"alone"``; ``"mesh"`` and ``"star"`` are handled)
          are read here. All kwargs are also merged into the dict returned by ``parse``.
    """

    # Prefix of attach_to entries that reference another task by its node id, e.g. "$node.0".
    _NODE_REF_PREFIX = "$node."

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties
        """
        self.kwargs = kwargs
        self.ntasks = int(os.environ["SLURM_NTASKS"])
        self.platform_spec = platform_spec
        # procid -> task dict, filled lazily by _get_task.
        self.tasks: Dict[int, Dict[str, Any]] = {}
        self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
        # Must run after ntasks_per_node is set: the node list is expanded per task.
        self.nodelist = self._parse_node_list()
        self.ports = int(kwargs.get("ports") or 15151)
        # Coerced like `ports` so string values (e.g. from a CLI) do not break port/node-id arithmetic.
        self.parallel_workers = int(kwargs.get("parallel_workers") or 1)
        self.topology = kwargs.get("topology") or "alone"

    def parse(self) -> dict:
        """
        Overview:
            Resolve the task of the current process (``SLURM_PROCID``) and merge it over ``kwargs``.
        Returns:
            - config (:obj:`dict`): ``kwargs`` updated with the task's keys (``ports``, ``address``,
              ``node_ids``, ``attach_to``, ``topology``, ``parallel_workers`` and any extra keys
              from the platform spec).
        """
        procid = int(os.environ["SLURM_PROCID"])
        task = self._get_task(procid)
        node_name = os.environ["SLURMD_NODENAME"]
        # Validation: the address derived for this procid must be the node we actually run on.
        assert task["address"] == node_name, \
            "Address {} of procid {} does not match SLURMD_NODENAME {}".format(task["address"], procid, node_name)
        return {**self.kwargs, **task}

    def _get_task(self, procid: int) -> Dict[str, Any]:
        """
        Overview:
            Return the (cached) task dict of ``procid``, filling in missing fields.
        Arguments:
            - procid (:obj:`int`): SLURM process id.
        Returns:
            - task (:obj:`Dict[str, Any]`): Task with ``ports``, ``address``, ``node_ids``,
              resolved ``attach_to``, ``topology`` and ``parallel_workers``.
        """
        if procid in self.tasks:
            return self.tasks[procid]
        # Copy so that the caller's platform_spec is not mutated.
        task = dict(self.platform_spec["tasks"][procid]) if self.platform_spec else {}
        if "ports" not in task:
            task["ports"] = self._get_ports(procid)
        if "address" not in task:
            task["address"] = self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self._get_node_id(procid)
        # Cache before resolving attach_to. Resolving "$node.N" walks every task's node_ids,
        # including this one, and would otherwise recurse forever on the uncached procid.
        self.tasks[procid] = task
        task["attach_to"] = self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers
        return task

    def _parse_node_list(self) -> List[str]:
        """
        Overview:
            Expand ``SLURM_NODELIST`` into one hostname per task (each node repeated
            ``ntasks_per_node`` times).
        Returns:
            - nodelist (:obj:`List[str]`): Hostnames indexed by procid.
        """
        nodelist = []
        for item in self._split_hostlist(os.environ["SLURM_NODELIST"]):
            nodelist.extend(self._expand_host(item))
        if self.ntasks_per_node > 1:
            nodelist = [node for node in nodelist for _ in range(self.ntasks_per_node)]
        return nodelist

    @staticmethod
    def _split_hostlist(hostlist: str) -> List[str]:
        """
        Overview:
            Split a hostlist on top-level commas only, e.g. ``"a[1-2],b3"`` -> ``["a[1-2]", "b3"]``.
        """
        # A comma is inside brackets iff a "]" follows it before any "[".
        return [item for item in re.split(r",(?![^\[]*\])", hostlist) if item]

    @staticmethod
    def _expand_host(pattern: str) -> List[str]:
        """
        Overview:
            Expand every bracket group of a single hostlist item, e.g. ``"n[08-10]"`` ->
            ``["n08", "n09", "n10"]``. Zero padding of a range start is preserved.
        """
        match = re.search(r"\[([^\]]*)\]", pattern)
        if match is None:
            return [pattern]
        head = pattern[:match.start()]
        # Expand any later bracket groups once and combine them with each value of this group.
        rests = SlurmParser._expand_host(pattern[match.end():])
        hosts = []
        for part in match.group(1).split(","):
            if "-" in part:
                start, stop = part.split("-")
                width = len(start)  # keep e.g. "08" as two digits
                numbers = [str(number).zfill(width) for number in range(int(start), int(stop) + 1)]
            else:
                numbers = [part]
            hosts.extend(head + number + rest for number in numbers for rest in rests)
        return hosts

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Resolve the comma-separated ``attach_to`` links of ``procid``. Explicit values win.
            Otherwise procid 0 attaches to nothing, ``"mesh"`` attaches to all previous tasks,
            ``"star"`` attaches to task 0, and any other topology attaches to nothing.
        Returns:
            - attach_to (:obj:`str`): Comma-joined links, e.g. ``"tcp://SH-0:50000,tcp://SH-0:50001"``.
        """
        if attach_to:
            # Strip whitespace so entries like "$node.0, $node.1" resolve; skip empty entries.
            links = [self._get_attach_to_part(part.strip()) for part in attach_to.split(",") if part.strip()]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            links = [link for i in range(procid) for link in self._get_attach_to_from_task(self._get_task(i))]
        elif self.topology == "star":
            links = self._get_attach_to_from_task(self._get_task(0))
        else:
            links = []
        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns:
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000 (other values are
              returned unchanged)
        """
        if not attach_part.startswith(self._NODE_REF_PREFIX):
            return attach_part
        attach_node_id = int(attach_part[len(self._NODE_REF_PREFIX):])
        attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task. Note that parallel_workers affects the connected processes:
            one link per worker, on consecutive ports starting at the task's port.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns:
            - attach_to (:obj:`List[str]`): The real addresses, e.g. ["tcp://SH-0:50000", "tcp://SH-0:50001"]
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + i) for i in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Return the first procid whose task has ``node_ids == nodeid``.
        Raises:
            - Exception: If no task in ``range(self.ntasks)`` matches.
        """
        for procid in range(self.ntasks):
            if self._get_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self, procid: int) -> int:
        # Tasks sharing a node get disjoint port ranges of size parallel_workers.
        return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers

    def _get_address(self, procid: int) -> str:
        return self.nodelist[procid]

    def _get_node_id(self, procid: int) -> int:
        # Each task reserves parallel_workers consecutive node ids.
        return procid * self.parallel_workers

    def _get_tcp_link(self, address: str, port: int) -> str:
        return "tcp://{}:{}".format(address, port)
def slurm_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    """
    Overview:
        Parse the launch configuration of the current SLURM process.
    Arguments:
        - platform_spec (:obj:`Optional[Dict]`): Optional spec with a ``"tasks"`` list, forwarded to
          ``SlurmParser`` (it is a dict, not a string).
        - kwargs: Forwarded to ``SlurmParser`` (e.g. ``ports``, ``parallel_workers``, ``topology``).
    Returns:
        - config (:obj:`dict`): The result of ``SlurmParser(platform_spec, **kwargs).parse()``.
    """
    return SlurmParser(platform_spec, **kwargs).parse()
|