import os
import re
from time import sleep
import numpy as np
from typing import Any, Dict, List, Optional


class SlurmParser():
    """
    Overview:
        Build the per-process task description (address, ports, node id, attach_to links)
        from the SLURM environment variables and an optional platform spec.
    """

    # Matches one node-list item of the form ``<prefix>[<ranges>]``, e.g. ``SH-[8-9,12]``.
    _NODE_GROUP_PATTERN = re.compile(r"(.*)?\[(.*)\]$")

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties
        Arguments:
            - platform_spec (:obj:`Optional[Dict]`): Optional spec; only its ``"tasks"`` entry, \
                indexed by procid, is read here.
            - kwargs: Extra options. ``ports`` (default 15151), ``parallel_workers`` (default 1) and \
                ``topology`` (default "alone") are read here; all of them are merged into the \
                result of ``parse``.
        """
        self.kwargs = kwargs
        self.ntasks = int(os.environ["SLURM_NTASKS"])
        self.platform_spec = platform_spec
        # Cache of fully resolved tasks, keyed by procid.
        self.tasks = {}
        self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
        self.nodelist = self._parse_node_list()
        self.ports = int(kwargs.get("ports") or 15151)
        # Coerced like ``ports`` so that a string value (e.g. from a CLI) works in arithmetic.
        self.parallel_workers = int(kwargs.get("parallel_workers") or 1)
        self.topology = kwargs.get("topology") or "alone"

    def parse(self) -> dict:
        """
        Overview:
            Resolve the task of the current process (``SLURM_PROCID``) and merge it over kwargs.
        Returns:
            - config (:obj:`dict`): The kwargs given at construction, overridden by the task fields.
        Raises:
            - AssertionError: If the computed address differs from ``SLURMD_NODENAME``.
        """
        procid = int(os.environ["SLURM_PROCID"])
        task = self._get_task(procid)
        # Validation. Raised explicitly (instead of a bare ``assert``) so the check survives
        # ``python -O``; the exception type is kept for backward compatibility.
        node_name = os.environ["SLURMD_NODENAME"]
        if task["address"] != node_name:
            raise AssertionError(
                "Computed address {} does not match SLURMD_NODENAME {}".format(task["address"], node_name)
            )
        return {**self.kwargs, **task}

    def _get_task(self, procid: int) -> Dict[str, Any]:
        """
        Overview:
            Return the resolved task of ``procid``, filling in every field the spec did not provide.
            Results are cached in ``self.tasks``.
        Arguments:
            - procid (:obj:`int`): The process id of the task.
        Returns:
            - task (:obj:`Dict[str, Any]`): Task with ports, address, node_ids, attach_to, topology \
                and parallel_workers.
        """
        if procid in self.tasks:
            return self.tasks[procid]
        if self.platform_spec:
            # Shallow copy so the caller's platform_spec is not modified in place.
            task = dict(self.platform_spec["tasks"][procid])
        else:
            task = {}
        if "ports" not in task:
            task["ports"] = self._get_ports(procid)
        if "address" not in task:
            task["address"] = self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self._get_node_id(procid)
        task["attach_to"] = self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers
        self.tasks[procid] = task
        return task

    @staticmethod
    def _split_node_groups(nodelist: str) -> List[str]:
        """
        Overview:
            Split a node list on the commas that are outside of square brackets.
        Arguments:
            - nodelist (:obj:`str`): e.g. ``a[1-2,5],b3``
        Returns:
            - groups (:obj:`List[str]`): e.g. ``["a[1-2,5]", "b3"]``
        """
        groups = []
        current = []
        depth = 0
        for char in nodelist:
            if char == "," and depth == 0:
                groups.append("".join(current))
                current = []
                continue
            if char == "[":
                depth += 1
            elif char == "]":
                depth -= 1
            current.append(char)
        groups.append("".join(current))
        return groups

    def _parse_node_list(self) -> List[str]:
        """
        Overview:
            Expand ``SLURM_NODELIST`` into one host name per task.
        Returns:
            - nodelist (:obj:`List[str]`): Host names, each repeated ``ntasks_per_node`` times, so \
                the list can be indexed by procid.
        """
        nodelist = []
        for group in self._split_node_groups(os.environ["SLURM_NODELIST"]):
            result = self._NODE_GROUP_PATTERN.match(group)
            if not result:
                # A plain host name without a bracket expression.
                nodelist.append(group)
                continue
            prefix, tails = result.groups()
            for tail in tails.split(","):
                if "-" in tail:
                    start, stop = tail.split("-")
                    # Keep zero padding, e.g. node[001-003] -> node001, node002, node003.
                    width = len(start) if start.startswith("0") else 0
                    for number in range(int(start), int(stop) + 1):
                        nodelist.append(prefix + str(number).zfill(width))
                else:
                    nodelist.append(prefix + tail)
        if self.ntasks_per_node > 1:
            # Expand node for each task
            nodelist = [node for node in nodelist for _ in range(self.ntasks_per_node)]
        return nodelist

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Build the comma separated attach_to links of a task.
        Arguments:
            - procid (:obj:`int`): The process id of the task.
            - attach_to (:obj:`Optional[str]`): Explicit comma separated attach_to field; when \
                given, each part is resolved and the topology is ignored.
        Returns:
            - attach_to (:obj:`str`): Comma separated links; empty for procid 0 without an explicit \
                attach_to and for topologies other than "mesh" and "star".
        """
        if attach_to:
            links = [self._get_attach_to_part(part) for part in attach_to.split(",")]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            # Attach to every worker of every task with a smaller procid.
            links = [
                link for prev_procid in range(procid)
                for link in self._get_attach_to_from_task(self._get_task(prev_procid))
            ]
        elif self.topology == "star":
            # Attach only to the head task.
            links = self._get_attach_to_from_task(self._get_task(0))
        else:
            links = []
        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000
        """
        if not attach_part.startswith("$node."):
            return attach_part
        # Everything after the 6 characters of "$node." is the node id.
        attach_node_id = int(attach_part[6:])
        attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task, note that parallel_workers will affect the connected processes.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns
            - attach_to (:obj:`List[str]`): One link per parallel worker, on consecutive ports \
                starting at the task's port, e.g. ["tcp://SH-0:50000"]
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + offset) for offset in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Find the first procid whose task has the given node id.
        Arguments:
            - nodeid (:obj:`int`): The node id to look for.
        Returns:
            - procid (:obj:`int`): The matching process id.
        Raises:
            - Exception: If no task has this node id.
        """
        for procid in range(self.ntasks):
            if self._get_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self, procid) -> int:
        """
        Overview:
            Default port of a task: tasks on the same node are spaced by ``parallel_workers`` ports.
        """
        return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers

    def _get_address(self, procid: int) -> str:
        """
        Overview:
            Default address of a task: the entry of the expanded node list at index ``procid``.
        """
        return self.nodelist[procid]

    def _get_node_id(self, procid: int) -> int:
        """
        Overview:
            Default node id of a task: ``procid * parallel_workers``.
        """
        return procid * self.parallel_workers

    def _get_tcp_link(self, address: str, port: int) -> str:
        """
        Overview:
            Format an address and a port as a tcp link, e.g. tcp://SH-0:50000
        """
        return "tcp://{}:{}".format(address, port)


def slurm_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    """
    Overview:
        Shortcut that builds a ``SlurmParser`` and returns the config of the current process.
    Arguments:
        - platform_spec (:obj:`Optional[Dict]`): Optional spec, see ``SlurmParser``.
        - kwargs: Extra options forwarded to ``SlurmParser``.
    Returns:
        - config (:obj:`dict`): The result of ``SlurmParser.parse``.
    """
    return SlurmParser(platform_spec, **kwargs).parse()