import gc
import os
from abc import ABC
from typing import Any

import numpy as np
import torch

from .t5_text_encoder import CosmosT5TextEncoder
from .presets import presets as guardrail_presets


class BaseWorldGenerationPipeline(ABC):
    def __init__(
        self,
        inference_type: str | None = None,
        checkpoint_dir: str | None = None,
        checkpoint_name: str | None = None,
        enable_text_guardrail: bool = False,
        enable_video_guardrail: bool = False,
        offload_network: bool = False,
        offload_tokenizer: bool = False,
        offload_text_encoder_model: bool = False,
        offload_guardrail_models: bool = False,
    ):
        """Initialize base world generation pipeline.

        This abstract base class provides core functionality for world generation models including:
        - Model loading and initialization
        - Text encoding and embedding
        - Safety checks and content filtering
        - Memory management through model offloading

        Args:
            inference_type: The type of inference pipeline ("text2world" or "video2world")
            checkpoint_dir: Root directory containing model checkpoints
            checkpoint_name: Name of the specific checkpoint file to load
            enable_text_guardrail: If True, validates input prompts for safety
            enable_video_guardrail: If True, validates generated videos for safety
            offload_network: If True, frees the main model from GPU memory after inference
            offload_tokenizer: If True, frees the tokenizer from GPU memory after use
            offload_text_encoder_model: If True, frees the T5 encoder from GPU memory after encoding
            offload_guardrail_models: If True, frees the safety models from GPU memory after checks
        """
        self.inference_type = inference_type
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_name = checkpoint_name
        self.guardrail_dir = "Cosmos-1.0-Guardrail"
        self.enable_text_guardrail = enable_text_guardrail
        self.enable_video_guardrail = enable_video_guardrail

        self.offload_network = offload_network
        self.offload_tokenizer = offload_tokenizer
        self.offload_text_encoder_model = offload_text_encoder_model
        self.offload_guardrail_models = offload_guardrail_models

        self.text_guardrail = None
        self.video_guardrail = None
        self.text_encoder = None
        self.model = None

        self._load_model()

        if not self.offload_text_encoder_model:
            self._load_text_encoder_model()
        if not self.offload_guardrail_models:
            if self.enable_text_guardrail:
                self._load_text_guardrail()
            if self.enable_video_guardrail:
                self._load_video_guardrail()
        if not self.offload_network:
            self._load_network()
        if not self.offload_tokenizer:
            self._load_tokenizer()
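    # NOTE: components whose offload flag is set are intentionally not loaded in the
    # constructor; the corresponding *_with_offload helpers below load them on demand
    # and release them again after use, trading extra load latency for lower peak
    # GPU memory usage.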
    def _load_tokenizer(self):
        """Load the tokenizer. Subclasses override this to load their specific tokenizer."""
        pass

    def _load_network(self):
        """Load the generation network. Subclasses override this to load their specific network."""
        pass

    def _load_model(self, checkpoint_name: str | None = None) -> Any:
        """Load the world generation model from a checkpoint.

        This abstract method must be implemented by subclasses to load their specific
        model architecture and weights.

        Args:
            checkpoint_name: Optional path to the model checkpoint file. The constructor
                calls this method without arguments, so implementations typically fall
                back to ``self.checkpoint_name``.

        Returns:
            The loaded model instance

        Raises:
            NotImplementedError: Must be implemented by subclasses
        """
        pass

    def _load_text_encoder_model(self):
        """Load the T5 text encoder model.

        Initializes the T5 encoder used for converting text prompts into embeddings
        that condition the world generation model, and stores it on ``self.text_encoder``.
        """
        self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir)

    def _load_text_guardrail(self):
        """Load text safety classifier models.

        Initializes models used for checking input prompts against safety policies.
        Models are loaded from the specified guardrail directory.
        """
        self.text_guardrail = guardrail_presets.create_text_guardrail_runner(
            checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
        )

    def _load_video_guardrail(self):
        """Load video safety classifier models.

        Initializes models used for validating generated video content against
        safety policies. Models are loaded from the specified guardrail directory.
        """
        self.video_guardrail = guardrail_presets.create_video_guardrail_runner(
            checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
        )

    def _offload_network(self):
        """Delete the main generation network and clear the associated GPU memory."""
        if self.model.model:
            del self.model.model
            self.model.model = None
            gc.collect()
            torch.cuda.empty_cache()

    def _offload_tokenizer(self):
        """Delete the tokenizer and clear the associated GPU memory."""
        if self.model.tokenizer:
            del self.model.tokenizer
            self.model.tokenizer = None
            gc.collect()
            torch.cuda.empty_cache()

    def _offload_guardrail_models(self):
        """Offload safety classifier models to reduce memory usage.

        Deletes the loaded guardrail models and clears GPU memory once they are no
        longer needed. This helps manage memory when processing multiple inputs sequentially.
        """
        if self.text_guardrail:
            del self.text_guardrail
            self.text_guardrail = None
        if self.video_guardrail:
            del self.video_guardrail
            self.video_guardrail = None
        gc.collect()
        torch.cuda.empty_cache()

    def _offload_text_encoder_model(self):
        """Offload the T5 text encoder to reduce memory usage.

        Deletes the T5 encoder and clears GPU memory after text encoding is complete.
        This helps manage memory when processing multiple inputs sequentially.
        """
        if self.text_encoder:
            del self.text_encoder
            self.text_encoder = None
            gc.collect()
            torch.cuda.empty_cache()

    def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Generate world latents using the model.

        This abstract method must be implemented by subclasses to define their specific
        generation process.

        Args:
            *args: Variable positional arguments for model inference
            **kwargs: Variable keyword arguments for model inference

        Returns:
            torch.Tensor: Generated world representation tensor
        """
        pass

    def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Generate world representation with memory management.

        Loads the required model components before inference and offloads them afterward
        when the corresponding offload flags are enabled, minimizing peak GPU memory usage.

        Args:
            *args: Arguments passed to _run_model
            **kwargs: Keyword arguments passed to _run_model

        Returns:
            The generated world representation
        """
        pass

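    # A concrete _run_model_with_offload typically mirrors the guardrail and
    # text-encoder helpers below: load what the constructor skipped, run the model,
    # then offload again. Hypothetical sketch (not part of the base class):
    #
    #     if self.offload_tokenizer:
    #         self._load_tokenizer()
    #     if self.offload_network:
    #         self._load_network()
    #     samples = self._run_model(*args, **kwargs)
    #     if self.offload_network:
    #         self._offload_network()
    #     if self.offload_tokenizer:
    #         self._offload_tokenizer()
    #     return samples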
    def _run_guardrail_on_prompt(self, prompt: str) -> bool:
        """Check if prompt meets safety requirements.

        Validates the input prompt against safety policies using loaded guardrail models.

        Args:
            prompt: Raw text prompt to validate

        Returns:
            bool: True if prompt passes all safety checks, False otherwise
        """
        return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail)

    def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool:
        """Check prompt safety with memory management.

        Validates prompt safety while handling model loading/offloading to manage memory.

        Args:
            prompt: Raw text prompt to validate

        Returns:
            bool: True if prompt passes all safety checks, False otherwise
        """
        if self.offload_guardrail_models:
            self._load_text_guardrail()

        is_safe = self._run_guardrail_on_prompt(prompt)

        if self.offload_guardrail_models:
            self._offload_guardrail_models()

        return is_safe

    def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None:
        """Check if video meets safety requirements.

        Validates generated video content against safety policies using guardrail models.

        Args:
            video: Video frames to validate

        Returns:
            np.ndarray: Processed video if safe, None if unsafe
        """
        return guardrail_presets.run_video_guardrail(video, self.video_guardrail)

    def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None:
        """Check if generated video meets safety requirements.

        Args:
            video: Video frames to validate

        Returns:
            np.ndarray: Processed video frames if safe, None otherwise

        Note:
            Guardrail models are offloaded after checks if enabled.
        """
        if self.offload_guardrail_models:
            self._load_video_guardrail()

        video = self._run_guardrail_on_video(video)

        if self.offload_guardrail_models:
            self._offload_guardrail_models()
        return video

    def _run_text_embedding_on_prompt(
        self, prompts: list[str], **kwargs: Any
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Convert text prompts to embeddings.

        Processes text prompts into embedding tensors that condition the generation model.

        Args:
            prompts: List of text prompts to encode
            **kwargs: Additional arguments for text encoding

        Returns:
            tuple containing:
                - List of text embedding tensors for each prompt
                - List of attention masks for each embedding
        """
        embeddings = []
        masks = []
        for prompt in prompts:
            embedding, mask = self.text_encoder.encode_prompts(
                [prompt],
                **kwargs,
            )
            embeddings.append(embedding)
            masks.append(mask)

        return embeddings, masks

    def _run_text_embedding_on_prompt_with_offload(
        self, prompts: list[str], **kwargs: Any
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Convert text prompts into embeddings using the T5 encoder, with memory management.

        Args:
            prompts: List of processed and validated text prompts
            **kwargs: Additional arguments for text encoding

        Returns:
            tuple containing:
                - List of text embedding tensors for each prompt
                - List of attention masks for each embedding

        Note:
            The T5 encoder is offloaded after encoding if enabled.
        """
        if self.offload_text_encoder_model:
            self._load_text_encoder_model()

        embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs)

        if self.offload_text_encoder_model:
            self._offload_text_encoder_model()
        return embeddings, masks

    def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray:
        """Decode model outputs into final world representation.

        This abstract method must be implemented by subclasses to convert raw model
        outputs into their specific world representation format.

        Args:
            samples: Raw output tensor from the generation model

        Returns:
            np.ndarray: Decoded world representation
        """
        pass

    def generate(self, *args: Any, **kwargs: Any):
        """Generate a world representation from the given inputs.

        This abstract method must be implemented by subclasses to define the end-to-end
        generation process, typically combining the guardrail, text-embedding, inference,
        and decoding helpers above.

        Args:
            *args: Variable positional arguments for model inference
            **kwargs: Variable keyword arguments for model inference
        """
        pass
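
# Illustrative usage sketch. The subclass name, its method bodies, and the checkpoint
# values are hypothetical; a real pipeline must implement the abstract hooks above
# (_load_model, _run_model, _run_tokenizer_decoding, generate, ...).
#
#     class MyText2WorldPipeline(BaseWorldGenerationPipeline):
#         def generate(self, prompt: str) -> np.ndarray | None:
#             if self.enable_text_guardrail and not self._run_guardrail_on_prompt_with_offload(prompt):
#                 return None  # prompt rejected by the text guardrail
#             embeddings, _ = self._run_text_embedding_on_prompt_with_offload([prompt])
#             latents = self._run_model_with_offload(embeddings[0])
#             video = self._run_tokenizer_decoding(latents)
#             if self.enable_video_guardrail:
#                 video = self._run_guardrail_on_video_with_offload(video)
#             return video
#
#     pipeline = MyText2WorldPipeline(
#         inference_type="text2world",
#         checkpoint_dir="checkpoints",
#         checkpoint_name="<model checkpoint>",
#         enable_text_guardrail=True,
#         enable_video_guardrail=True,
#         offload_text_encoder_model=True,
#         offload_guardrail_models=True,
#     )
#     frames = pipeline.generate("A robot arm stacking wooden blocks on a table.")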
|