|
import os |
|
import numpy as np |
|
from time import sleep |
|
from typing import Dict, List, Optional |
|
|
|
|
|
class K8SParser:
    """
    Overview:
        Build the argument dict of the current process from the ``DI_NODES`` / ``DI_RANK``
        environment variables, an optional ``platform_spec`` and the keyword arguments.
    """

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties
        Arguments:
            - platform_spec (:obj:`Optional[Dict]`): Optional spec; when given, per-task settings are \
                read from ``platform_spec["tasks"][procid]``.
            - kwargs: Global arguments; ``parallel_workers`` (default 1), ``topology`` \
                (default "alone") and ``ports`` (default 50515) are read here.
        """
        self.kwargs = kwargs
        self.nodelist = self._parse_node_list()
        self.ntasks = len(self.nodelist)
        self.platform_spec = platform_spec
        self.parallel_workers = kwargs.get("parallel_workers") or 1
        self.topology = kwargs.get("topology") or "alone"
        self.ports = int(kwargs.get("ports") or 50515)
        # Cache of resolved tasks, keyed by procid.
        self.tasks = {}

    def parse(self) -> dict:
        """
        Overview:
            Resolve the task of the current process (``DI_RANK``) and merge it into the kwargs.
        Returns:
            - args (:obj:`dict`): ``kwargs`` unchanged when ``mq_type`` is not "nng", otherwise \
                ``kwargs`` updated with the task properties.
        Raises:
            - AssertionError: The resolved task address differs from the ``DI_NODES`` entry of this rank.
        """
        if self.kwargs.get("mq_type", "nng") != "nng":
            return self.kwargs
        procid = int(os.environ["DI_RANK"])
        nodename = self.nodelist[procid]
        task = self._get_task(procid)

        # Raised explicitly (instead of a bare ``assert``) so the check survives ``python -O``,
        # while callers still see the same exception type.
        if task["address"] != nodename:
            raise AssertionError(
                "Task address {} does not match node name {} of procid {}".format(
                    task["address"], nodename, procid
                )
            )
        return {**self.kwargs, **task}

    def _parse_node_list(self) -> List[str]:
        """
        Overview:
            Split the comma separated ``DI_NODES`` environment variable into a list of node names.
        """
        return os.environ["DI_NODES"].split(",")

    def _get_task(self, procid: int) -> dict:
        """
        Overview:
            Complete node properties, use environment vars in list instead of on current node.
            For example, if you want to set nodename in this function, please derive it from DI_NODES.
        Arguments:
            - procid (:obj:`int`): Proc order, starting from 0, must be set automatically by dijob.
                Note that it is different from node_id.
        Returns:
            - task (:obj:`dict`): The task with ``ports``, ``address``, ``node_ids``, ``attach_to``, \
                ``topology`` and ``parallel_workers`` filled in.
        """
        if procid in self.tasks:
            return self.tasks[procid]

        if self.platform_spec:
            # Shallow copy, so the caller's platform_spec is not modified.
            task = dict(self.platform_spec["tasks"][procid])
        else:
            task = {}
        if "ports" not in task:
            task["ports"] = self.kwargs.get("ports") or self._get_ports()
        if "address" not in task:
            task["address"] = self.kwargs.get("address") or self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self.kwargs.get("node_ids") or self._get_node_id(procid)

        # Register the task before resolving attach_to: resolving a "$node.N" pattern looks up
        # other tasks (possibly this one) through _get_task, and only needs the properties set
        # above. Without this, a pattern pointing at this task or at an unresolved task that
        # points back would recurse without end.
        attach_pattern = task.get("attach_to")
        self.tasks[procid] = task
        try:
            task["attach_to"] = self.kwargs.get("attach_to") or self._get_attach_to(procid, attach_pattern)
        except Exception:
            # Do not leave a half resolved task in the cache.
            self.tasks.pop(procid, None)
            raise
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers
        return task

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Parse from pattern of attach_to. If attach_to is specified in the platform_spec,
            it is formatted as a real address based on the specified address.
            If not, the real addresses will be generated based on the globally specified topology.
        Arguments:
            - procid (:obj:`int`): Proc order.
            - attach_to (:obj:`str`): The attach_to field in platform_spec for the task with current procid.
        Returns
            - attach_to (:obj:`str`): The real addresses for attach_to, joined by ",".
        """
        if attach_to:
            links = [self._get_attach_to_part(part) for part in attach_to.split(",")]
        elif procid == 0:
            # The first process has nothing to attach to.
            links = []
        elif self.topology == "mesh":
            # Attach to every worker of every previous process.
            prev_tasks = [self._get_task(i) for i in range(procid)]
            links = []
            for prev_task in prev_tasks:
                links.extend(self._get_attach_to_from_task(prev_task))
        elif self.topology == "star":
            # Attach to the workers of the head process only.
            links = self._get_attach_to_from_task(self._get_task(0))
        else:
            links = []

        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000. \
                Parts without the "$node." prefix are returned unchanged.
        """
        prefix = "$node."
        if not attach_part.startswith(prefix):
            return attach_part
        attach_node_id = int(attach_part[len(prefix):])
        attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task, note that parallel_workers will affect the connected processes.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns
            - attach_to (:obj:`List[str]`): One real address per parallel worker, on consecutive \
                ports starting from the task's port, e.g. tcp://SH-0:50000
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + i) for i in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Find the first procid whose task has ``node_ids`` equal to ``nodeid``.
        Arguments:
            - nodeid (:obj:`int`): The node id to look for.
        Returns:
            - procid (:obj:`int`): The matching proc order.
        Raises:
            - Exception: No task has the given node id.
        """
        for procid in range(self.ntasks):
            if self._get_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self) -> int:
        """
        Overview:
            Return the global base port.
        """
        return self.ports

    def _get_address(self, procid: int) -> str:
        """
        Overview:
            Return the node name at position ``procid`` of the node list.
        """
        return self.nodelist[procid]

    def _get_tcp_link(self, address: str, port: int) -> str:
        """
        Overview:
            Format an address and a port as a tcp link, e.g. tcp://SH-0:50000
        """
        return "tcp://{}:{}".format(address, port)

    def _get_node_id(self, procid: int) -> int:
        """
        Overview:
            Return the default node id of a process: ``procid * parallel_workers``.
        """
        return procid * self.parallel_workers
|
|
|
|
|
def k8s_parser(platform_spec: Optional[str] = None, **kwargs) -> dict:
    """
    Overview:
        Functional entry: build a ``K8SParser`` from the given arguments and return its parsed result.
    Arguments:
        - platform_spec (:obj:`Optional[str]`): Forwarded unchanged to ``K8SParser``.
        - kwargs: Forwarded unchanged to ``K8SParser``.
    Returns:
        - args (:obj:`dict`): The result of ``K8SParser.parse``.
    """
    # NOTE(review): K8SParser.__init__ annotates platform_spec as Optional[Dict] - confirm which
    # type callers actually pass.
    parser = K8SParser(platform_spec, **kwargs)
    return parser.parse()
|
|