|
import os |
|
from functools import partial |
|
from typing import Any, Optional |
|
|
|
import torch |
|
|
|
import videosys |
|
|
|
from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port |
|
|
|
|
|
class VideoSysEngine:
    """Multi-GPU video-generation engine (partly inspired by vLLM).

    Rank 0 (the "driver") runs the pipeline in the current process; ranks
    1..N-1 run in spawned processes managed by ``ProcessWorkerWrapper`` and
    supervised by a ``WorkerMonitor``.  Method calls are fanned out to all
    workers via :meth:`_run_workers`.
    """

    def __init__(self, config):
        # ``config`` must expose ``num_gpus`` and ``pipeline_cls``
        # (duck-typed; see the videosys pipeline configs).
        self.config = config
        # Futures for any in-flight async worker tasks (see
        # stop_remote_worker_execution_loop); None when nothing is pending.
        self.parallel_worker_tasks = None
        self._init_worker(config.pipeline_cls)

    def _init_worker(self, pipeline_cls):
        """Spawn worker processes for ranks 1..N-1 and build the rank-0 pipeline.

        Sets ``self.workers``, ``self.worker_monitor`` and
        ``self.driver_worker``.
        """
        world_size = self.config.num_gpus

        # Respect a caller-provided CUDA_VISIBLE_DEVICES; otherwise expose
        # exactly the first ``world_size`` devices to this process tree.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))

        # Force single-threaded Inductor compilation.
        # NOTE(review): presumably needed because multi-threaded compilation
        # misbehaves alongside the spawned worker processes — confirm.
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Avoid CPU oversubscription across the per-GPU processes, but keep
        # any value the user set explicitly.
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = "1"

        assert world_size <= torch.cuda.device_count(), (
            f"num_gpus={world_size} exceeds available CUDA devices "
            f"({torch.cuda.device_count()})"
        )

        # Rendezvous address for torch.distributed initialization.
        distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())

        if world_size == 1:
            # Single GPU: no subprocess workers, driver does everything.
            self.workers = []
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_pipeline,
                        pipeline_cls=pipeline_cls,
                        rank=rank,
                        local_rank=rank,
                        distributed_init_method=distributed_init_method,
                    ),
                )
                for rank in range(1, world_size)
            ]

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        # Rank 0 joins the same process group and acts as the driver.
        self.driver_worker = self._create_pipeline(
            pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
        )

    def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
        """Initialize distributed state for this rank and construct the pipeline.

        Runs in the worker subprocess for ranks >= 1, and in the current
        process for rank 0.
        """
        videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)

        pipeline = pipeline_cls(self.config)
        return pipeline

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on all workers and the driver.

        Returns ``[driver_result, *worker_results]``, or the list of
        un-awaited worker futures when
        ``async_run_tensor_parallel_workers_only`` is set.
        ``max_concurrent_workers`` is accepted for API compatibility but not
        currently used.
        """

        # Dispatch asynchronously to subprocess workers first so they run
        # concurrently with the driver's own execution below.
        worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]

        if async_run_tensor_parallel_workers_only:
            # Caller will collect the futures itself (see
            # _wait_for_tasks_completion).
            return worker_outputs

        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*args, **kwargs)

        # Block until every worker future resolves.
        return [driver_worker_output] + [output.get() for output in worker_outputs]

    def _driver_execute_model(self, *args, **kwargs):
        """Run generation on the driver (rank 0) only."""
        return self.driver_worker.generate(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Run ``generate`` on all ranks and return the driver's result."""
        return self._run_workers("generate", *args, **kwargs)[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Drain any pending async worker tasks; no-op when none are pending."""
        if self.parallel_worker_tasks is None:
            return

        # Clear the attribute before blocking so a re-entrant call is a no-op.
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None

        self._wait_for_tasks_completion(parallel_worker_tasks)

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()

    def save_video(self, video, output_path):
        """Save ``video`` to ``output_path`` via the driver pipeline."""
        return self.driver_worker.save_video(video, output_path)

    def shutdown(self):
        """Stop the worker monitor and tear down the process group.

        Safe to call more than once and from ``__del__``: attributes may be
        missing if ``__init__`` failed early, and the process group may never
        have been (or may already have been) destroyed.
        """
        if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
            worker_monitor.close()
        # BUGFIX: destroy_process_group() raises RuntimeError when no group
        # is initialized (early __init__ failure, double shutdown, interpreter
        # teardown) — guard instead of letting __del__ surface the error.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    def __del__(self):
        self.shutdown()
|
|