import os
import numpy as np
from time import sleep
from typing import Dict, List, Optional


class K8SParser():

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties.
        Arguments:
            - platform_spec (:obj:`Optional[Dict]`): Optional spec whose ``"tasks"`` list is indexed by procid; \
                each entry may pre-set ``ports``, ``address``, ``node_ids`` and an ``attach_to`` pattern.
            - kwargs: Global options. ``parallel_workers``, ``topology``, ``ports``, ``address``, ``node_ids``, \
                ``attach_to`` and ``mq_type`` are read by this parser; all kwargs are merged into the result.
        """
        self.kwargs = kwargs
        self.nodelist = self._parse_node_list()
        self.ntasks = len(self.nodelist)
        self.platform_spec = platform_spec
        self.parallel_workers = kwargs.get("parallel_workers") or 1
        self.topology = kwargs.get("topology") or "alone"
        self.ports = int(kwargs.get("ports") or 50515)
        # Fully resolved tasks (including attach_to), keyed by procid.
        self.tasks = {}
        # Tasks with only ports / address / node_ids resolved, keyed by procid. Kept separate so that looking
        # up another task's address never requires resolving that task's attach_to (which could recurse back
        # into the task currently being resolved).
        self._base_tasks = {}

    def parse(self) -> dict:
        """
        Overview:
            Build the config of the current process. If ``mq_type`` is not ``"nng"`` the kwargs are returned
            unchanged; otherwise the task of ``DI_RANK`` is resolved and merged over the kwargs.
        Returns:
            - config (:obj:`dict`): kwargs updated with ports, address, node_ids, attach_to, topology and \
                parallel_workers of the current task.
        """
        if self.kwargs.get("mq_type", "nng") != "nng":
            return self.kwargs
        procid = int(os.environ["DI_RANK"])
        nodename = self.nodelist[procid]
        task = self._get_task(procid)
        # Validation
        assert task["address"] == nodename, \
            "Address of task {} ({}) does not match node {} in DI_NODES".format(procid, task["address"], nodename)
        return {**self.kwargs, **task}

    def _parse_node_list(self) -> List[str]:
        """Return the comma-separated node names from the ``DI_NODES`` environment variable."""
        return os.environ["DI_NODES"].split(",")

    def _get_base_task(self, procid: int) -> dict:
        """
        Overview:
            Resolve the fields of a task that do not depend on other tasks: ports, address and node_ids.
            Values from platform_spec take precedence, then kwargs, then generated defaults.
        Arguments:
            - procid (:obj:`int`): Proc order, starting from 0.
        Returns:
            - task (:obj:`dict`): A cached dict; it is a copy, so platform_spec is never modified.
        """
        if procid in self._base_tasks:
            return self._base_tasks[procid]
        if self.platform_spec:
            # Copy so that filling in fields below does not mutate the caller's platform_spec.
            task = dict(self.platform_spec["tasks"][procid])
        else:
            task = {}
        # NOTE(review): kwargs ports / address / node_ids are applied to every task, not only the current
        # process — presumably intended for single-task use; confirm with callers.
        if "ports" not in task:
            task["ports"] = self.kwargs.get("ports") or self._get_ports()
        if "address" not in task:
            task["address"] = self.kwargs.get("address") or self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self.kwargs.get("node_ids") or self._get_node_id(procid)
        self._base_tasks[procid] = task
        return task

    def _get_task(self, procid: int) -> dict:
        """
        Overview:
            Complete node properties, use environment vars in list instead of on current node.
            For example, if you want to set nodename in this function, please derive it from DI_NODES.
        Arguments:
            - procid (:obj:`int`): Proc order, starting from 0, must be set automatically by dijob. \
                Note that it is different from node_id.
        Returns:
            - task (:obj:`dict`): ports, address, node_ids, attach_to, topology and parallel_workers of the task.
        """
        if procid in self.tasks:
            return self.tasks[procid]
        # Copy the base task so the cached base entry keeps only its connection fields.
        task = dict(self._get_base_task(procid))
        task["attach_to"] = self.kwargs.get("attach_to") or self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers
        self.tasks[procid] = task
        return task

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Parse from pattern of attach_to. If attach_to is specified in the platform_spec,
            it is formatted as a real address based on the specified address.
            If not, the real addresses will be generated based on the globally specified topology.
        Arguments:
            - procid (:obj:`int`): Proc order.
            - attach_to (:obj:`str`): The attach_to field in platform_spec for the task with current procid.
        Returns:
            - attach_to (:obj:`str`): The real addresses for attach_to, joined by commas.
        """
        if attach_to:
            # Tolerate whitespace around commas and skip empty parts (e.g. a trailing comma).
            links = [self._get_attach_to_part(part.strip()) for part in attach_to.split(",") if part.strip()]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            # Connect to every worker port of every earlier task. Only connection fields are needed,
            # so base tasks are used to avoid resolving the earlier tasks' own attach_to.
            links = [
                link for prev in range(procid) for link in self._get_attach_to_from_task(self._get_base_task(prev))
            ]
        elif self.topology == "star":
            links = self._get_attach_to_from_task(self._get_base_task(0))
        else:
            links = []
        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0; \
                any part without the ``$node.`` prefix is returned unchanged.
        Returns:
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000
        """
        prefix = "$node."
        if not attach_part.startswith(prefix):
            return attach_part
        attach_node_id = int(attach_part[len(prefix):])
        # Base task is sufficient (address + ports) and avoids re-entering attach_to resolution.
        attach_task = self._get_base_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task, note that parallel_workers will affect the connected processes:
            one link is produced per worker, on consecutive ports starting at the task's ports.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns:
            - attach_to (:obj:`List[str]`): The real addresses, e.g. ["tcp://SH-0:50000", "tcp://SH-0:50001"]
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + offset) for offset in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Find the procid of the first task whose node_ids equals ``nodeid``.
        Raises:
            - Exception: If no task matches.
        """
        for procid in range(self.ntasks):
            if self._get_base_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self) -> int:
        """Return the global base port."""
        return self.ports

    def _get_address(self, procid: int) -> str:
        """Return the node name at position ``procid`` of DI_NODES."""
        address = self.nodelist[procid]
        return address

    def _get_tcp_link(self, address: str, port: int) -> str:
        """Format ``address`` and ``port`` as a ``tcp://address:port`` link."""
        return "tcp://{}:{}".format(address, port)

    def _get_node_id(self, procid: int) -> int:
        """Default node id: each task reserves ``parallel_workers`` consecutive ids."""
        return procid * self.parallel_workers


def k8s_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    """
    Overview:
        Convenience wrapper: build a K8SParser and return the parsed config of the current process.
    """
    return K8SParser(platform_spec, **kwargs).parse()