File size: 5,656 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import re
from time import sleep
import numpy as np
from typing import Any, Dict, List, Optional


class SlurmParser():

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties. Reads the ``SLURM_NTASKS``,
            ``SLURM_NTASKS_PER_NODE`` and ``SLURM_NODELIST`` environment variables.
        Arguments:
            - platform_spec (:obj:`Optional[Dict]`): Optional spec whose ``tasks`` list (indexed by \
                procid) may pre-define ``ports``, ``address``, ``node_ids`` and ``attach_to`` per task.
            - kwargs: Extra options. ``ports`` (base port, default 15151), ``parallel_workers`` \
                (default 1) and ``topology`` (default "alone") are read here; all kwargs are also \
                merged into the result of ``parse``.
        """
        self.kwargs = kwargs
        self.ntasks = int(os.environ["SLURM_NTASKS"])
        self.platform_spec = platform_spec
        self.tasks = {}  # procid -> fully resolved task dict (memoization cache)
        self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
        # Must run after ntasks_per_node is set: the node list is expanded once per task.
        self.nodelist = self._parse_node_list()
        self.ports = int(kwargs.get("ports") or 15151)
        # Cast like ``ports``: a string value (e.g. from a CLI) would otherwise break ``range()``
        # and turn the port/node-id arithmetic into string repetition.
        self.parallel_workers = int(kwargs.get("parallel_workers") or 1)
        self.topology = kwargs.get("topology") or "alone"

    def parse(self) -> dict:
        """
        Overview:
            Resolve the task of the current process (``SLURM_PROCID``) and merge it over kwargs.
        Returns:
            - config (:obj:`dict`): The constructor kwargs overridden by the task fields \
                (``ports``, ``address``, ``node_ids``, ``attach_to``, ``topology``, ``parallel_workers``).
        """
        procid = int(os.environ["SLURM_PROCID"])
        task = self._get_task(procid)
        # Validation: the computed address must be the node this process actually runs on.
        assert task["address"] == os.environ["SLURMD_NODENAME"], \
            "Task {} resolved to address {}, but SLURMD_NODENAME is {}".format(
                procid, task["address"], os.environ["SLURMD_NODENAME"]
            )
        return {**self.kwargs, **task}

    def _get_task(self, procid: int) -> Dict[str, Any]:
        """
        Overview:
            Return the fully resolved (and cached) task of ``procid``, including ``attach_to``.
        """
        if procid in self.tasks:
            return self.tasks[procid]
        task = self._get_base_task(procid)
        task["attach_to"] = self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers

        self.tasks[procid] = task
        return task

    def _get_base_task(self, procid: int) -> Dict[str, Any]:
        """
        Overview:
            Build the task of ``procid`` with ``ports``, ``address`` and ``node_ids`` filled in, but
            without resolving ``attach_to``. ``$node.X`` lookups use this, so forward or mutual
            references between tasks cannot recurse forever.
        """
        if self.platform_spec:
            # Copy so the caller's platform_spec is never mutated.
            task = dict(self.platform_spec["tasks"][procid])
        else:
            task = {}
        if "ports" not in task:
            task["ports"] = self._get_ports(procid)
        if "address" not in task:
            task["address"] = self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self._get_node_id(procid)
        return task

    def _parse_node_list(self) -> List[str]:
        """
        Overview:
            Expand ``SLURM_NODELIST`` into one entry per task, e.g. ``node[01-02],gpu5`` with
            2 tasks per node -> ``[node01, node01, node02, node02, gpu5, gpu5]``.
        Returns:
            - nodelist (:obj:`List[str]`): Node name of every task, indexed by procid.
        """
        nodelist = []
        for group in self._split_node_groups(os.environ["SLURM_NODELIST"]):
            nodelist.extend(self._expand_node_group(group))
        if self.ntasks_per_node > 1:
            # Expand node for each task
            nodelist = [node for node in nodelist for _ in range(self.ntasks_per_node)]
        return nodelist

    @staticmethod
    def _split_node_groups(nodelist: str) -> List[str]:
        """
        Overview:
            Split a hostlist on commas that are outside brackets: ``a[1,3],b5`` -> ``[a[1,3], b5]``.
        """
        groups, current, depth = [], [], 0
        for char in nodelist:
            if char == "[":
                depth += 1
            elif char == "]":
                depth -= 1
            if char == "," and depth == 0:
                groups.append("".join(current))
                current = []
            else:
                current.append(char)
        groups.append("".join(current))
        return [group for group in groups if group]

    @staticmethod
    def _expand_node_group(group: str) -> List[str]:
        """
        Overview:
            Expand one hostlist group such as ``node[01-03,7]``; zero padding of range bounds is kept.
        """
        result = re.match(r"(.*)\[(.*)\]$", group)
        if not result:
            return [group]
        prefix, tails = result.groups()
        nodes = []
        for tail in tails.split(","):
            if "-" in tail:
                start, stop = tail.split("-")
                width = len(start)  # "01-10" must yield node01 ... node10, not node1
                nodes.extend(prefix + str(number).zfill(width) for number in range(int(start), int(stop) + 1))
            else:
                nodes.append(prefix + tail)
        return nodes

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Build the comma-separated ``attach_to`` string of a task. An explicit value is resolved
            part by part; otherwise it is derived from the topology ("mesh": every earlier task,
            "star": task 0, anything else: nothing). Task 0 attaches to nothing by default.
        """
        if attach_to:
            parts = (part.strip() for part in attach_to.split(","))
            links = [self._get_attach_to_part(part) for part in parts if part]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            links = [
                link for i in range(procid)
                for link in self._get_attach_to_from_task(self._get_task(i))
            ]
        elif self.topology == "star":
            links = self._get_attach_to_from_task(self._get_task(0))
        else:
            links = []
        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns:
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000; parts that do not \
                start with ``$node.`` are returned unchanged.
        """
        prefix = "$node."
        if not attach_part.startswith(prefix):
            return attach_part
        attach_node_id = int(attach_part[len(prefix):])
        attach_task = self._get_base_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task, note that parallel_workers will affect the connected processes:
            one link per worker, on consecutive ports starting at the task's port.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns:
            - attach_to (:obj:`List[str]`): The real addresses, e.g. [tcp://SH-0:50000, tcp://SH-0:50001]
        """
        port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, port + i) for i in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Find the first procid whose task has ``node_ids == nodeid``.
        Raises:
            - Exception: If no task has that node id.
        """
        for procid in range(self.ntasks):
            # Base task only: resolving attach_to here could re-enter the task being built.
            if self._get_base_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self, procid: int) -> int:
        # Tasks sharing a node get disjoint port ranges of size parallel_workers.
        return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers

    def _get_address(self, procid: int) -> str:
        return self.nodelist[procid]

    def _get_node_id(self, procid: int) -> int:
        return procid * self.parallel_workers

    def _get_tcp_link(self, address: str, port: int) -> str:
        return "tcp://{}:{}".format(address, port)


def slurm_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    """
    Overview:
        Convenience wrapper: build a ``SlurmParser`` and return the parsed config of the current process.
    Arguments:
        - platform_spec (:obj:`Optional[Dict]`): Optional spec with a per-procid ``tasks`` list.
        - kwargs: Options forwarded to ``SlurmParser`` (e.g. ``ports``, ``parallel_workers``, ``topology``).
    Returns:
        - config (:obj:`dict`): The result of ``SlurmParser.parse()``.
    """
    return SlurmParser(platform_spec, **kwargs).parse()