File size: 5,855 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import numpy as np
from time import sleep
from typing import Dict, List, Optional
class K8SParser():
    """
    Overview:
        Resolve the task configuration (ports, address, node_ids, attach_to, topology,
        parallel_workers) of the current process of a job running on k8s.
        Cluster-wide information comes from the environment: ``DI_NODES`` (comma-separated
        node names, one per process) and ``DI_RANK`` (index of the current process).
    """

    # Prefix of an attach_to part that refers to another task by node id, e.g. ``$node.0``.
    _NODE_REF_PREFIX = "$node."

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties.
        Arguments:
            - platform_spec (:obj:`Optional[Dict]`): Optional spec whose ``tasks`` list holds
              per-task overrides, indexed by procid.
            - kwargs: Global options (e.g. ``ports``, ``address``, ``node_ids``, ``attach_to``,
              ``topology``, ``parallel_workers``, ``mq_type``); they are also merged into the
              result of ``parse``.
        """
        self.kwargs = kwargs
        self.nodelist = self._parse_node_list()
        self.ntasks = len(self.nodelist)
        self.platform_spec = platform_spec
        # Cast like ``ports`` so string values work with ``range`` and arithmetic below.
        self.parallel_workers = int(kwargs.get("parallel_workers") or 1)
        self.topology = kwargs.get("topology") or "alone"
        self.ports = int(kwargs.get("ports") or 50515)
        # Cache of fully resolved tasks, keyed by procid.
        self.tasks = {}

    def parse(self) -> dict:
        """
        Overview:
            Build the configuration of the current process, identified by ``DI_RANK``.
        Returns:
            - config (:obj:`dict`): ``kwargs`` merged with the resolved task. When ``mq_type``
              is set to something other than ``"nng"``, ``kwargs`` is returned unchanged.
        Raises:
            - AssertionError: If the resolved address differs from this process's node name.
        """
        if self.kwargs.get("mq_type", "nng") != "nng":
            return self.kwargs
        procid = int(os.environ["DI_RANK"])
        nodename = self.nodelist[procid]
        task = self._get_task(procid)
        # Validation. Raised explicitly (not via ``assert``) so it also runs under ``python -O``.
        if task["address"] != nodename:
            raise AssertionError(
                "Address {} of task {} does not match node name {}".format(task["address"], procid, nodename)
            )
        return {**self.kwargs, **task}

    def _parse_node_list(self) -> List[str]:
        """Return the node names listed in the ``DI_NODES`` environment variable."""
        return os.environ["DI_NODES"].split(",")

    def _get_base_task(self, procid: int) -> dict:
        """
        Overview:
            Resolve the task fields that do not depend on other tasks: ports, address and node_ids.
            Priority: platform_spec entry, then kwargs, then defaults derived from DI_NODES.
            The spec entry is copied, so ``platform_spec`` itself is never mutated.
            Cross-task lookups (node id -> procid, mesh/star neighbours, ``$node.`` references)
            use this instead of ``_get_task``. That avoids re-entering the resolution of a task
            whose attach_to is still being computed, which would otherwise recurse forever.
        Arguments:
            - procid (:obj:`int`): Proc order, starting from 0.
        Returns:
            - task (:obj:`dict`): A new dict containing at least ports, address and node_ids.
        """
        task = dict(self.platform_spec["tasks"][procid]) if self.platform_spec else {}
        if "ports" not in task:
            task["ports"] = self.kwargs.get("ports") or self._get_ports()
        if "address" not in task:
            task["address"] = self.kwargs.get("address") or self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self.kwargs.get("node_ids") or self._get_node_id(procid)
        return task

    def _get_task(self, procid: int) -> dict:
        """
        Overview:
            Complete node properties, use environment vars in list instead of on current node.
            For example, if you want to set nodename in this function, please derive it from DI_NODES.
            Results are cached in ``self.tasks``.
        Arguments:
            - procid (:obj:`int`): Proc order, starting from 0, must be set automatically by dijob.
              Note that it is different from node_id.
        Returns:
            - task (:obj:`dict`): Task with ports, address, node_ids, attach_to, topology and
              parallel_workers set.
        """
        if procid in self.tasks:
            return self.tasks[procid]
        task = self._get_base_task(procid)
        task["attach_to"] = self.kwargs.get("attach_to") or self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers
        self.tasks[procid] = task
        return task

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Parse from pattern of attach_to. If attach_to is specified in the platform_spec,
            each comma-separated part is formatted as a real address.
            If not, the real addresses are generated from the global topology:
            ``mesh`` attaches to every earlier task, ``star`` attaches to task 0, and anything
            else attaches to nothing. Task 0 attaches to nothing unless attach_to is given.
        Arguments:
            - procid (:obj:`int`): Proc order.
            - attach_to (:obj:`str`): The attach_to field in platform_spec for the task with current procid.
        Returns:
            - attach_to (:obj:`str`): Comma-separated real addresses.
        """
        if attach_to:
            parts = (part.strip() for part in attach_to.split(","))
            links = [self._get_attach_to_part(part) for part in parts if part]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            links = [
                link for prev in range(procid)
                for link in self._get_attach_to_from_task(self._get_base_task(prev))
            ]
        elif self.topology == "star":
            links = self._get_attach_to_from_task(self._get_base_task(0))
        else:
            links = []
        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to. Parts not starting with ``$node.`` are returned as-is.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns:
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000
        Raises:
            - ValueError: If the suffix is not an int or no task has that node id.
        """
        if not attach_part.startswith(self._NODE_REF_PREFIX):
            return attach_part
        attach_node_id = int(attach_part[len(self._NODE_REF_PREFIX):])
        attach_task = self._get_base_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task. Note that parallel_workers affects the connected processes:
            one link per worker, on consecutive ports starting at the task's port.
        Arguments:
            - task (:obj:`dict`): The task object (needs ``ports`` and ``address``).
        Returns:
            - attach_to (:obj:`List[str]`): The real addresses, e.g. [tcp://SH-0:50000, tcp://SH-0:50001]
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + i) for i in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Return the procid of the first task whose node_ids equals ``nodeid``.
        Raises:
            - ValueError: If no task matches.
        """
        for procid in range(self.ntasks):
            if self._get_base_task(procid)["node_ids"] == nodeid:
                return procid
        raise ValueError("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self) -> int:
        """Return the global base port."""
        return self.ports

    def _get_address(self, procid: int) -> str:
        """Return the node name of ``procid`` from DI_NODES."""
        address = self.nodelist[procid]
        return address

    def _get_tcp_link(self, address: str, port: int) -> str:
        """Format ``tcp://address:port``."""
        return "tcp://{}:{}".format(address, port)

    def _get_node_id(self, procid: int) -> int:
        """Default node id: each proc reserves ``parallel_workers`` consecutive ids."""
        return procid * self.parallel_workers
def k8s_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    """
    Overview:
        Convenience wrapper: build a ``K8SParser`` and return the configuration of the current process.
    Arguments:
        - platform_spec (:obj:`Optional[Dict]`): Optional per-task spec, forwarded to ``K8SParser``
          (it is indexed as ``platform_spec["tasks"][procid]``, so it must be a mapping, not a string).
        - kwargs: Global options, forwarded to ``K8SParser``.
    Returns:
        - config (:obj:`dict`): The result of ``K8SParser.parse``.
    """
    return K8SParser(platform_spec, **kwargs).parse()
|