from __future__ import annotations

import multiprocessing as mp
import threading
from typing import Dict, Optional, Sequence, Union

import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import BloomConfig, declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

class Server(threading.Thread):
    """Serves one or more bloom layers for inference, forward, and backward; announces itself to the DHT"""

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        device: torch.device,
        num_connection_handlers: int = 8,
        update_period: float = 30,
        expiration: Optional[float] = None,
        start: bool,
        **kwargs,
    ):
        threading.Thread.__init__(self)
        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
        self.conn_handlers = [
            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
        ]
        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
        self.dht_handler_thread = ModuleAnnouncerThread(
            self.module_backends, dht, update_period, expiration, daemon=True
        )
        self.checkpoint_saver = None
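        # NOTE: no checkpoint saver is created here; the attribute is a placeholder, and run() and
        # shutdown() only interact with a saver if one is attached externally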

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Starts the server in the current thread. Initializes the DHT if necessary, starts connection handlers,
        and runs the Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Serving {len(self.module_backends)} blocks:")
        for expert_name, backend in self.module_backends.items():
            # count all parameters: hosted blocks are frozen in create(), so filtering by
            # requires_grad would always report zero
            num_parameters = sum(p.numel() for p in backend.module.parameters())
            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.module_backends:
            self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.result()
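        # each connection handler runs in its own process; ready.result() above blocks until that
        # handler finishes initialization, surfacing any error that occurred during its startup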

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    @classmethod
    def create(
        cls,
        prefix: Optional[str],
        converted_model_name_or_path: str,
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: Optional[int] = None,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        torch_dtype: str = "auto",
        cache_size_bytes: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        use_auth_token: Optional[str] = None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)
        if prefix is None:
            prefix = converted_model_name_or_path
            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                f"please specify --prefix manually when starting a server"
            )
            logger.info(f"Automatic DHT prefix: {prefix}")
        assert (block_indices is None) != (num_blocks is None), "Please specify exactly one of num_blocks or block_indices"
        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        memory_cache = MemoryCache(device, cache_size_bytes)

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"

        if block_indices is not None:
            try:
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                logger.error(f"Failed to parse --block_indices ({e}); it must have the form start:end (e.g. 0:18)")
                raise
            block_indices = range(first_block_index, last_block_index)
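            # NB: Python ranges are end-exclusive, so "0:18" serves blocks 0 through 17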
        else:
            assert num_blocks is not None
            block_indices = range(num_blocks)

        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)

        blocks = {}
        for block_index in block_indices:
            module_uid = f"{prefix}.{block_index}"
            block = load_pretrained_block(
                converted_model_name_or_path,
                block_index,
                block_config,
                torch_dtype=torch_dtype,
                use_auth_token=use_auth_token,
            )
            for param in block.parameters():
                param.requires_grad = False
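            # the hosted weights are frozen: backward passes compute gradients w.r.t. activations only,
            # never w.r.t. the block parameters themselves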

            blocks[module_uid] = TransformerBackend(
                module_uid,
                block,
                memory_cache=memory_cache,
                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                kwargs_schema={},
                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )
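            # the args/outputs schemas describe example activation tensors of shape
            # (batch, up to 2048 tokens, hidden_size); hivemind uses these descriptors for
            # compression-aware (de)serialization of requests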

        num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
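        # default heuristic: four connection handler processes per hosted block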

        return cls(
            dht,
            blocks,
            num_connection_handlers=num_handlers,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            start=start,
        )

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts the server in a background thread. If await_ready is True, this method waits until the background
        server is ready to process incoming requests, or for at most `timeout` seconds.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server in any other way (e.g. by killing its processes) may leave zombies.
        If you have already caused a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.module_backends:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        self.dht.shutdown()
        self.dht.join()

        logger.debug("Shutting down runtime")

        self.runtime.shutdown()
        logger.info("Server shut down successfully")


class ModuleAnnouncerThread(threading.Thread):
    """Periodically announces that this server hosts the specified modules, making them visible to all DHT peers"""

    def __init__(
        self,
        module_backends: Dict[str, TransformerBackend],
        dht: DHT,
        update_period: float = 30,
        expiration: Optional[float] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
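            # default: announcements outlive two update periods, but never less than the maximum
            # clock discrepancy the DHT tolerates between peers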
        self.module_backends = module_backends
        self.dht = dht
        self.update_period = update_period
        self.expiration = expiration
        self.stop = threading.Event()

    def run(self) -> None:
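        # announce all hosted modules once at startup, then refresh the announcement every
        # update_period seconds until stop is set; each announcement expires after self.expiration seconds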
        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
        while not self.stop.wait(self.update_period):
            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)