|
from __future__ import annotations |
|
|
|
import dataclasses |
|
import threading |
|
from functools import partial |
|
from typing import List, NamedTuple, Optional, Sequence, Tuple |
|
|
|
from hivemind import DHT, PeerID |
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler |
|
|
|
from src.data_structures import ModuleUID, RemoteModuleInfo |
|
from src.dht_utils import _get_remote_module_infos |
|
|
|
use_hivemind_log_handler("in_root_logger") |
|
logger = get_logger(__file__) |
|
|
|
|
|
Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)]) |
|
|
|
|
|
@dataclasses.dataclass(frozen=False, init=False) |
|
class RemoteSequenceInfo: |
|
"""Keeps and updates the meta-information about which peers host which blocks""" |
|
|
|
dht: DHT |
|
block_uids: List[ModuleUID, ...] |
|
block_infos: List[Optional[RemoteModuleInfo], ...] |
|
spans_by_priority: List[Span] |
|
spans_containing_block: Tuple[List[Span], ...] |
|
lock_changes: threading.Lock |
|
|
|
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]): |
|
self.dht = dht |
|
self.block_uids = list(block_uids) |
|
self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids) |
|
self.spans_by_priority = [] |
|
self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids))) |
|
self.lock_changes = threading.Lock() |
|
self.update_() |
|
|
|
for uid, info in zip(self.block_uids, self.block_infos): |
|
assert info is not None, f"Found no remote peers for block {uid}" |
|
assert self.spans_by_priority and self.spans_containing_block |
|
|
|
def update_(self): |
|
with self.lock_changes: |
|
self.update_block_infos_() |
|
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos) |
|
|
|
def update_block_infos_(self): |
|
new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine( |
|
partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False |
|
) |
|
assert len(new_block_infos) == len(self.block_uids) |
|
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): |
|
if info is None: |
|
logger.warning(f"Found no block info for block {uid}") |
|
if not isinstance(info, RemoteModuleInfo): |
|
logger.warning(f"Unexpected dht entry type for {uid}: {info}") |
|
if not info.peer_ids: |
|
logger.warning(f"Found no active peers for block {uid}") |
|
if info.uid != uid: |
|
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}") |
|
if not isinstance(info.peer_ids, set): |
|
logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}") |
|
self.block_infos[block_index] = info |
|
|
|
@staticmethod |
|
def compute_spans(block_infos: Sequence[RemoteModuleInfo]): |
|
closed_spans = [] |
|
active_spans = {} |
|
for block_index, info in enumerate(block_infos): |
|
for peer_id in info.peer_ids: |
|
if peer_id not in active_spans: |
|
active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id) |
|
else: |
|
active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1) |
|
|
|
for peer_id in list(active_spans.keys()): |
|
if peer_id not in info.peer_ids or block_index == len(block_infos) - 1: |
|
closed_spans.append(active_spans.pop(peer_id)) |
|
assert not active_spans |
|
|
|
closed_spans.sort(key=lambda span: span.end - span.start, reverse=True) |
|
|
|
spans_containing_block = tuple(list() for _ in range(len(block_infos))) |
|
for span in closed_spans: |
|
for block_index in range(span.start, span.end): |
|
spans_containing_block[block_index].append(span) |
|
|
|
return closed_spans, spans_containing_block |
|
|
|
def __len__(self): |
|
return len(self.block_uids) |
|
|