import os
import random
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union, cast

import numpy as np
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from safetensors.torch import load_file
from torch import nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
import genmo.mochi_preview.vae.cp_conv as cp_conv
from genmo.mochi_preview.vae.model import Decoder, apply_tiled
from genmo.lib.progress import get_new_progress_bar, progress_bar
from genmo.lib.utils import Timer


def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    """Build a sigma schedule with a linear first part and a quadratic second part.

    Args:
        num_steps: Total number of steps; the result has ``num_steps + 1`` entries.
        threshold_noise: Slope numerator of the linear part
            (``i * threshold_noise / linear_steps``).
        linear_steps: Number of steps in the linear part. Defaults to
            ``num_steps // 2``.

    Returns:
        A list of ``num_steps + 1`` floats, each computed as ``1.0 - x``; the
        first entry is 1.0 and the last is 0.0.

    Raises:
        ZeroDivisionError: If ``linear_steps`` is 0 or equals ``num_steps``
            (e.g. ``num_steps < 2`` with the default ``linear_steps``).
    """
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    # Flip so that the schedule runs from 1.0 down to 0.0.
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule


T5_MODEL = "google/t5-v1_1-xxl"
MAX_T5_TOKEN_LENGTH = 256


def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
    """Wrap ``model`` in FSDP (full sharding) and wait for CUDA work to finish.

    Args:
        model: Module to wrap.
        device_id: Device passed to FSDP as ``device_id``.
        param_dtype: Parameter dtype for mixed precision; reductions and
            buffers are kept in float32.
        auto_wrap_policy: FSDP auto-wrap policy deciding which submodules get
            their own FSDP unit.

    Returns:
        The FSDP-wrapped model.
    """
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        ),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        device_id=device_id,
        sync_module_states=True,
        use_orig_params=True,
    )
    torch.cuda.synchronize()
    return model


class ModelFactory(ABC):
    """Base class for factories that build a model on a given device.

    Keyword arguments given to the constructor are stored in ``self.kwargs``
    for subclasses to read in ``get_model``.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @abstractmethod
    def get_model(self, *, local_rank: int, device_id: Union[int, Literal["cpu"]], world_size: int) -> Any:
        """Build the model; this base body only validates the CPU case.

        Subclasses may call ``super().get_model(...)`` to run the check below.
        """
        if device_id == "cpu":
            assert world_size == 1, "CPU offload only supports single-GPU inference"


class T5ModelFactory(ModelFactory):
    """Factory for the T5 text encoder identified by ``T5_MODEL``."""

    def __init__(self):
        super().__init__()

    def get_model(self, *, local_rank, device_id, world_size):
        """Load the T5 encoder, FSDP-wrapping it when ``world_size > 1``.

        With ``world_size == 1`` the model is moved to ``cuda:{device_id}`` if
        ``device_id`` is an int, and otherwise left where ``from_pretrained``
        put it. The model is returned in eval mode.
        """
        super().get_model(local_rank=local_rank, device_id=device_id, world_size=world_size)

        model = T5EncoderModel.from_pretrained(T5_MODEL)

        if world_size > 1:
            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.float32,
                auto_wrap_policy=partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls={
                        T5Block,
                    },
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))  # type: ignore
        return model.eval()


class DitModelFactory(ModelFactory):
    """Factory for the ``AsymmDiTJoint`` diffusion transformer."""

    def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optional[str] = None):
        """Store the model configuration.

        When ``attention_mode`` is None it is set to ``"flash"`` if
        ``flash_varlen_qkvpacked_attn`` is available (not None), else ``"sdpa"``.
        """
        if attention_mode is None:
            from genmo.lib.attn_imports import flash_varlen_qkvpacked_attn  # type: ignore

            attention_mode = "sdpa" if flash_varlen_qkvpacked_attn is None else "flash"
        print(f"Attention mode: {attention_mode}")
        super().__init__(model_path=model_path, model_dtype=model_dtype, attention_mode=attention_mode)

    def get_model(self, *, local_rank, device_id, world_size):
        """Construct the DiT, load its weights on rank 0, and place or wrap it.

        The module is created with ``torch.nn.utils.skip_init``. Only
        ``local_rank == 0`` loads the state dict; with ``world_size > 1`` the
        model is FSDP-wrapped (which requires ``model_dtype == "bf16"``).
        Otherwise it is moved to ``cuda:{device_id}`` when ``device_id`` is an
        int. The model is returned in eval mode.
        """
        # TODO(ved): Set flag for torch.compile
        from genmo.mochi_preview.dit.joint_model.asymm_models_joint import (
            AsymmDiTJoint,
        )

        model: nn.Module = torch.nn.utils.skip_init(
            AsymmDiTJoint,
            depth=48,
            patch_size=2,
            num_heads=24,
            hidden_size_x=3072,
            hidden_size_y=1536,
            mlp_ratio_x=4.0,
            mlp_ratio_y=4.0,
            in_channels=12,
            qk_norm=True,
            qkv_bias=False,
            out_bias=True,
            patch_embed_bias=True,
            timestep_mlp_bias=True,
            timestep_scale=1000.0,
            t5_feat_dim=4096,
            t5_token_length=256,
            rope_theta=10000.0,
            attention_mode=self.kwargs["attention_mode"],
        )

        if local_rank == 0:
            # FSDP syncs weights from rank 0 to all other ranks
            model.load_state_dict(load_file(self.kwargs["model_path"]))

        if world_size > 1:
            assert self.kwargs["model_dtype"] == "bf16", "FP8 is not supported for multi-GPU inference"
            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.bfloat16,
                auto_wrap_policy=partial(
                    lambda_auto_wrap_policy,
                    # `model` is still the unwrapped module while FSDP evaluates
                    # this policy; it is rebound only after setup_fsdp_sync returns.
                    lambda_fn=lambda m: m in model.blocks,
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))
        return model.eval()


class DecoderModelFactory(ModelFactory):
    """Factory for the VAE ``Decoder`` and its latent statistics."""

    def __init__(self, *, model_path: str, model_stats_path: str):
        super().__init__(model_path=model_path, model_stats_path=model_stats_path)

    def get_model(self, *, local_rank, device_id, world_size):
        """Build the decoder, load its weights, and attach VAE mean/std buffers.

        The decoder is placed on ``cuda:{device_id}`` when ``device_id`` is an
        int, otherwise on CPU. ``vae_mean`` and ``vae_std`` are read from the
        JSON file at ``model_stats_path`` (keys ``"mean"`` and ``"std"``).
        """
        # TODO(ved): Set flag for torch.compile
        # TODO(ved): Use skip_init
        import json

        decoder = Decoder(
            out_channels=3,
            base_channels=128,
            channel_multipliers=[1, 2, 4, 6],
            temporal_expansions=[1, 2, 3],
            spatial_expansions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            has_attention=[False, False, False, False, False],
            padding_mode="replicate",
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )
        # VAE is not FSDP-wrapped
        state_dict = load_file(self.kwargs["model_path"])
        decoder.load_state_dict(state_dict, strict=True)
        device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else "cpu"
        decoder.eval().to(device)
        # Use a context manager so the stats file handle is closed promptly.
        with open(self.kwargs["model_stats_path"]) as stats_file:
            vae_stats = json.load(stats_file)
        decoder.register_buffer("vae_mean", torch.tensor(vae_stats["mean"], device=device))
        decoder.register_buffer("vae_std", torch.tensor(vae_stats["std"], device=device))
        return decoder


def get_conditioning(tokenizer, encoder, device, batch_inputs, *, prompt: str, negative_prompt: str):
    """Encode the prompt and negative prompt into conditioning dicts.

    Returns:
        ``{"batched": ...}`` when ``batch_inputs`` is truthy (both prompts
        encoded in one call), otherwise ``{"cond": ..., "null": ...}`` with the
        two prompts encoded separately.
    """
    if batch_inputs:
        return dict(batched=get_conditioning_for_prompts(tokenizer, encoder, device, [prompt, negative_prompt]))
    else:
        cond_input = get_conditioning_for_prompts(tokenizer, encoder, device, [prompt])
        null_input = get_conditioning_for_prompts(tokenizer, encoder, device, [negative_prompt])
        return dict(cond=cond_input, null=null_input)


def get_conditioning_for_prompts(tokenizer, encoder, device, prompts: List[str]):
    """Tokenize one or two prompts and run them through the text encoder.

    Args:
        tokenizer: Tokenizer called with padding/truncation to
            ``MAX_T5_TOKEN_LENGTH``.
        encoder: Text encoder whose output exposes ``last_hidden_state``.
        device: Device the token ids and mask are moved to before encoding.
        prompts: ``[neg]``, ``[pos]`` or ``[pos, neg]``.

    Returns:
        Dict with ``y_mask`` (single-element list holding the boolean
        attention mask) and ``y_feat`` (single-element list holding the
        detached encoder features).
    """
    assert len(prompts) in [1, 2]  # [neg] or [pos] or [pos, neg]
    B = len(prompts)
    t5_toks = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=MAX_T5_TOKEN_LENGTH,
        return_tensors="pt",
        return_attention_mask=True,
    )
    caption_input_ids_t5 = t5_toks["input_ids"]
    caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
    del t5_toks

    assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
    assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)

    # Special-case empty negative prompt by zero-ing it
    if prompts[-1] == "":
        caption_input_ids_t5[-1] = 0
        caption_attention_mask_t5[-1] = False

    caption_input_ids_t5 = caption_input_ids_t5.to(device, non_blocking=True)
    caption_attention_mask_t5 = caption_attention_mask_t5.to(device, non_blocking=True)

    y_mask = [caption_attention_mask_t5]
    y_feat = [encoder(caption_input_ids_t5, caption_attention_mask_t5).last_hidden_state.detach()]
    # Sometimes returns a tensor, other times a tuple, not sure why
    # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
    assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
    assert y_feat[-1].dtype == torch.float32

    return dict(y_mask=y_mask, y_feat=y_feat)


def compute_packed_indices(
    device: torch.device, text_mask: torch.Tensor, num_latents: int
) -> Dict[str, Union[torch.Tensor, int]]:
    """
    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80

    Args:
        device: Device the returned index tensors are moved to.
        num_latents: Number of latent tokens
        text_mask: (B, L) boolean tensor indicating which text tokens are not padding.

    Returns:
        packed_indices: Dict with keys for Flash Attention:
            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
                                   in the packed sequence.
            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
    """
    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
    PATCH_SIZE = 2
    num_visual_tokens = num_latents // (PATCH_SIZE**2)
    assert num_visual_tokens > 0

    # Visual tokens are prepended and always valid.
    mask = F.pad(text_mask, (num_visual_tokens, 0), value=True)  # (B, N + L)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
    valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()  # up to (B * (N + L),)

    assert valid_token_indices.size(0) >= text_mask.size(0) * num_visual_tokens  # At least (B * N,)
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    return {
        "cu_seqlens_kv": cu_seqlens.to(device, non_blocking=True),
        "max_seqlen_in_batch_kv": cast(int, max_seqlen_in_batch),
        "valid_token_indices_kv": valid_token_indices.to(device, non_blocking=True),
    }


def assert_eq(x, y, msg=None):
    """Assert ``x == y``, reporting both values (and ``msg``) on failure."""
    assert x == y, f"{msg or 'Assertion failed'}: {x} != {y}"


def sample_model(device, dit, conditioning, **args):
    """Run the sampling loop and return the final latents.

    Args:
        device: Device on which the initial noise is created.
        dit: Callable model invoked as ``dit(z, sigma, **cond)``.
        conditioning: Either ``{"cond": ..., "null": ...}`` or
            ``{"batched": ...}``, as produced by ``get_conditioning``. A
            ``packed_indices`` entry is added to the contained dicts in place.
        **args: Must contain ``seed``, ``width``, ``height``, ``num_frames``,
            ``num_inference_steps``, ``cfg_schedule`` (one value per step) and
            ``sigma_schedule`` (``num_inference_steps + 1`` values). Other
            keys are ignored.

    Returns:
        Latent tensor of shape ``(1, 12, latent_t, latent_h, latent_w)``.
    """
    random.seed(args["seed"])
    np.random.seed(args["seed"])
    torch.manual_seed(args["seed"])

    # NOTE(review): this generator is seeded but never passed to torch.randn
    # below; the noise comes from the global RNG seeded above.
    generator = torch.Generator(device=device)
    generator.manual_seed(args["seed"])

    w, h, t = args["width"], args["height"], args["num_frames"]
    sample_steps = args["num_inference_steps"]
    cfg_schedule = args["cfg_schedule"]
    sigma_schedule = args["sigma_schedule"]

    assert_eq(len(cfg_schedule), sample_steps, "cfg_schedule must have length sample_steps")
    assert_eq((t - 1) % 6, 0, "t - 1 must be divisible by 6")
    assert_eq(
        len(sigma_schedule),
        sample_steps + 1,
        "sigma_schedule must have length sample_steps + 1",
    )

    B = 1
    SPATIAL_DOWNSAMPLE = 8
    TEMPORAL_DOWNSAMPLE = 6
    IN_CHANNELS = 12
    latent_t = ((t - 1) // TEMPORAL_DOWNSAMPLE) + 1
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE

    z = torch.randn(
        (B, IN_CHANNELS, latent_t, latent_h, latent_w),
        device=device,
        dtype=torch.float32,
    )

    num_latents = latent_t * latent_h * latent_w
    cond_batched = cond_text = cond_null = None
    if "cond" in conditioning:
        cond_text = conditioning["cond"]
        cond_null = conditioning["null"]
        cond_text["packed_indices"] = compute_packed_indices(device, cond_text["y_mask"][0], num_latents)
        cond_null["packed_indices"] = compute_packed_indices(device, cond_null["y_mask"][0], num_latents)
    else:
        cond_batched = conditioning["batched"]
        cond_batched["packed_indices"] = compute_packed_indices(device, cond_batched["y_mask"][0], num_latents)
        # Duplicate the latents so one forward pass covers both prompts.
        z = repeat(z, "b ... -> (repeat b) ...", repeat=2)

    def model_fn(*, z, sigma, cfg_scale):
        """Return the CFG-combined prediction and the conditional prediction."""
        if cond_batched is not None:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out = dit(z, sigma, **cond_batched)
            out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
        else:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out_cond = dit(z, sigma, **cond_text)
                out_uncond = dit(z, sigma, **cond_null)
        assert out_cond.shape == out_uncond.shape
        return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond

    for i in get_new_progress_bar(range(0, sample_steps), desc="Sampling"):
        sigma = sigma_schedule[i]
        dsigma = sigma - sigma_schedule[i + 1]

        # `pred` estimates `z_0 - eps`.
        pred, _output_cond = model_fn(
            z=z,
            sigma=torch.full([B] if cond_text is not None else [B * 2], sigma, device=z.device),
            cfg_scale=cfg_schedule[i],
        )
        pred = pred.to(z)
        z = z + dsigma * pred

    # In batched mode z was duplicated; keep only the first B entries.
    return z[:B] if cond_batched is not None else z


def decoded_latents_to_frames(samples):
    """Map decoder output to frames: ``(x + 1) / 2`` clamped to [0, 1], laid out "b t h w c"."""
    samples = samples.float()
    samples = (samples + 1.0) / 2.0
    samples.clamp_(0.0, 1.0)
    frames = rearrange(samples, "b c t h w -> b t h w c")
    return frames


def decode_latents(decoder, z):
    """Decode this rank's temporal slice of ``z`` and gather all frames.

    ``z`` is split along dim 2 into ``cp_size`` pieces; this rank decodes
    piece ``cp_rank`` under bfloat16 autocast, then the pieces are gathered
    with ``cp_conv.gather_all_frames`` and converted to frames.
    """
    cp_rank, cp_size = cp.get_cp_rank_size()
    z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim
    with torch.autocast("cuda", dtype=torch.bfloat16):
        samples = decoder(z)
    samples = cp_conv.gather_all_frames(samples)
    return decoded_latents_to_frames(samples)


@torch.inference_mode()
def decode_latents_tiled_full(
    decoder,
    z,
    *,
    tile_sample_min_height: int = 240,
    tile_sample_min_width: int = 424,
    tile_overlap_factor_height: float = 0.1666,
    tile_overlap_factor_width: float = 0.2,
    auto_tile_size: bool = True,
    frame_batch_size: int = 6,
):
    """Decode latents tile by tile (spatially and temporally) and blend the seams.

    Args:
        decoder: Callable applied to each latent tile.
        z: Latents of shape ``(B, C, T, H, W)``.
        tile_sample_min_height: Tile height in output pixels; ignored when
            ``auto_tile_size`` is True.
        tile_sample_min_width: Tile width in output pixels; ignored when
            ``auto_tile_size`` is True.
        tile_overlap_factor_height: Fraction of a tile's height used for
            vertical overlap/blending.
        tile_overlap_factor_width: Fraction of a tile's width used for
            horizontal overlap/blending.
        auto_tile_size: When True, tile sizes are derived from the latent size
            as ``H // 2 * 8`` and ``W // 2 * 8``.
        frame_batch_size: Number of latent frames decoded per temporal chunk;
            must be ``<= T``. The first chunk additionally takes the
            ``T % frame_batch_size`` leftover frames.

    Returns:
        Frames produced by ``decoded_latents_to_frames`` on the stitched output.
    """
    B, C, T, H, W = z.shape
    assert frame_batch_size <= T, f"frame_batch_size must be <= T, got {frame_batch_size} > {T}"

    # The factor 8 presumably matches the VAE spatial downsample — TODO confirm.
    tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
    tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8

    tile_latent_min_height = int(tile_sample_min_height / 8)
    tile_latent_min_width = int(tile_sample_min_width / 8)

    def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Linearly cross-fade the bottom rows of `a` into the top rows of `b` (in place on `b`).
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Linearly cross-fade the right columns of `a` into the left columns of `b` (in place on `b`).
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    # Stride between tile origins, in latent units.
    overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))
    overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))
    # Width of the blended band and of the kept region, in output pixels.
    blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)
    blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)
    row_limit_height = tile_sample_min_height - blend_extent_height
    row_limit_width = tile_sample_min_width - blend_extent_width

    # Split z into overlapping tiles and decode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    pbar = get_new_progress_bar(
        desc="Decoding latent tiles",
        total=len(range(0, H, overlap_height)) * len(range(0, W, overlap_width)) * len(range(T // frame_batch_size)),
    )
    rows = []
    for i in range(0, H, overlap_height):
        row = []
        for j in range(0, W, overlap_width):
            temporal = []
            for k in range(T // frame_batch_size):
                remaining_frames = T % frame_batch_size
                # The first chunk absorbs the leftover frames; later chunks are shifted by them.
                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                end_frame = frame_batch_size * (k + 1) + remaining_frames
                tile = z[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + tile_latent_min_height,
                    j : j + tile_latent_min_width,
                ]
                tile = decoder(tile)
                temporal.append(tile)
                pbar.update(1)
            row.append(torch.cat(temporal, dim=2))
        rows.append(row)

    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    return decoded_latents_to_frames(torch.cat(result_rows, dim=3))


@torch.inference_mode()
def decode_latents_tiled_spatial(
    decoder,
    z,
    *,
    num_tiles_w: int,
    num_tiles_h: int,
    overlap: int = 0,  # Number of pixel of overlap between adjacent tiles.
    # Use a factor of 2 times the latent downsample factor.
    min_block_size: int = 1,  # Minimum number of pixels in each dimension when subdividing.
):
    """Decode latents with ``apply_tiled`` and return its result.

    NOTE(review): unlike ``decode_latents`` and ``decode_latents_tiled_full``,
    this returns the ``apply_tiled`` output directly, without passing it
    through ``decoded_latents_to_frames`` — confirm whether callers expect
    frames here.
    """
    decoded = apply_tiled(decoder, z, num_tiles_w, num_tiles_h, overlap, min_block_size)
    assert decoded is not None, "Failed to decode latents with tiled spatial method"
    return decoded


@contextmanager
def move_to_device(model: nn.Module, target_device):
    """Temporarily move ``model`` to ``target_device``.

    The model's original device is taken from its first parameter. On exit —
    including when the ``with`` body raises — the model is moved back to that
    device, so a failure does not leave it on ``target_device``.
    """
    og_device = next(model.parameters()).device
    if og_device == target_device:
        print(f"move_to_device is a no-op model is already on {target_device}")
    else:
        print(f"moving model from {og_device} -> {target_device}")

    model.to(target_device)
    try:
        yield
    finally:
        if og_device != target_device:
            print(f"moving model from {target_device} -> {og_device}")
        model.to(og_device)


def t5_tokenizer():
    """Load the tokenizer for ``T5_MODEL`` with ``legacy=False``."""
    return T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)


class MochiSingleGPUPipeline:
    """Text-to-video pipeline that runs all stages on ``cuda:0``.

    With ``cpu_offload`` enabled the models are created with
    ``device_id="cpu"`` and moved to the GPU one at a time during a call.
    """

    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        cpu_offload: Optional[bool] = False,
        decode_type: str = "full",
        decode_args: Optional[Dict[str, Any]] = None,
    ):
        """Build the tokenizer and the three models, timing each load.

        Args:
            text_encoder_factory: Factory for the text encoder.
            dit_factory: Factory for the diffusion transformer.
            decoder_factory: Factory for the VAE decoder.
            cpu_offload: When truthy, models are initialised on CPU.
            decode_type: ``"tiled_full"``, ``"tiled_spatial"``, or anything
                else for the plain ``decode_latents`` path.
            decode_args: Keyword arguments forwarded to the tiled decoders.
        """
        self.device = torch.device("cuda:0")
        self.tokenizer = t5_tokenizer()
        t = Timer()
        self.cpu_offload = cpu_offload
        self.decode_args = decode_args or {}
        self.decode_type = decode_type
        init_id = "cpu" if cpu_offload else 0
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(
                local_rank=0,
                device_id=init_id,
                world_size=1,
            )
        with t("load_dit"):
            self.dit = dit_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        t.print_stats()

    def __call__(self, batch_cfg, prompt, negative_prompt, **kwargs):
        """Generate a video for ``prompt`` and return it as a NumPy array.

        Args:
            batch_cfg: Passed to ``get_conditioning`` as ``batch_inputs``.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            **kwargs: Forwarded to ``sample_model``.

        Returns:
            The decoded output moved to CPU and converted with ``.numpy()``.
        """
        with progress_bar(type="tqdm"), torch.inference_mode():

            def print_max_memory():
                print(f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB")

            print_max_memory()
            with move_to_device(self.text_encoder, self.device):
                conditioning = get_conditioning(
                    self.tokenizer,
                    self.text_encoder,
                    self.device,
                    batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
            print_max_memory()
            with move_to_device(self.dit, self.device):
                latents = sample_model(self.device, self.dit, conditioning, **kwargs)
            print_max_memory()
            with move_to_device(self.decoder, self.device):
                if self.decode_type == "tiled_full":
                    frames = decode_latents_tiled_full(self.decoder, latents, **self.decode_args)
                elif self.decode_type == "tiled_spatial":
                    frames = decode_latents_tiled_spatial(self.decoder, latents, **self.decode_args)
                else:
                    frames = decode_latents(self.decoder, latents)
            print_max_memory()
            return frames.cpu().numpy()


### ALL CODE BELOW HERE IS FOR MULTI-GPU MODE ###


# In multi-gpu mode, all models must belong to a device which has a predefined context parallel group
# So it doesn't make sense to work with models individually
class MultiGPUContext:
    """Per-rank state for multi-GPU inference: process group, tokenizer and models."""

    def __init__(
        self,
        *,
        text_encoder_factory,
        dit_factory,
        decoder_factory,
        device_id,
        local_rank,
        world_size,
    ):
        """Initialise the NCCL process group and load all models for this rank.

        ``MASTER_ADDR`` / ``MASTER_PORT`` are set to ``127.0.0.1:29500`` in
        ``os.environ`` before the process group is created.
        """
        t = Timer()
        self.device = torch.device(f"cuda:{device_id}")
        print(f"Initializing rank {local_rank+1}/{world_size}")
        assert world_size > 1, f"Multi-GPU mode requires world_size > 1, got {world_size}"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        with t("init_process_group"):
            dist.init_process_group(
                "nccl",
                rank=local_rank,
                world_size=world_size,
                device_id=self.device,  # force non-lazy init
            )
        pg = dist.group.WORLD
        cp.set_cp_group(pg, list(range(world_size)), local_rank)
        distributed_kwargs = dict(local_rank=local_rank, device_id=device_id, world_size=world_size)
        self.world_size = world_size
        self.tokenizer = t5_tokenizer()
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(**distributed_kwargs)
        with t("load_dit"):
            self.dit = dit_factory.get_model(**distributed_kwargs)
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(**distributed_kwargs)
        self.local_rank = local_rank
        t.print_stats()

    def run(self, *, fn, **kwargs):
        """Call ``fn(self, **kwargs)`` and return its result."""
        return fn(self, **kwargs)


class MochiMultiGPUPipeline:
    """Text-to-video pipeline that drives one ``MultiGPUContext`` Ray actor per rank."""

    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        world_size: int,
    ):
        """Start Ray, create ``world_size`` single-GPU actors, and wait until all are ready."""
        ray.init()
        RemoteClass = ray.remote(MultiGPUContext)
        self.ctxs = [
            RemoteClass.options(num_gpus=1).remote(
                text_encoder_factory=text_encoder_factory,
                dit_factory=dit_factory,
                decoder_factory=decoder_factory,
                world_size=world_size,
                # Each actor is given num_gpus=1 and always addresses device 0.
                device_id=0,
                local_rank=i,
            )
            for i in range(world_size)
        ]
        for ctx in self.ctxs:
            ray.get(ctx.__ray_ready__.remote())

    def __call__(self, **kwargs):
        """Run sampling and decoding on every actor and return the first actor's result.

        ``kwargs`` must contain ``batch_cfg``, ``prompt`` and
        ``negative_prompt``; the remainder (plus an added ``show_progress``
        flag) is forwarded to ``sample_model``.
        """

        def sample(ctx, *, batch_cfg, prompt, negative_prompt, **kwargs):
            with progress_bar(type="ray_tqdm", enabled=ctx.local_rank == 0), torch.inference_mode():
                conditioning = get_conditioning(
                    ctx.tokenizer,
                    ctx.text_encoder,
                    ctx.device,
                    batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
                latents = sample_model(ctx.device, ctx.dit, conditioning=conditioning, **kwargs)
                if ctx.local_rank == 0:
                    # NOTE(review): writes "latents.pt" to the working directory on every call.
                    torch.save(latents, "latents.pt")
                frames = decode_latents(ctx.decoder, latents)
            return frames.cpu().numpy()

        return ray.get([ctx.run.remote(fn=sample, **kwargs, show_progress=i == 0) for i, ctx in enumerate(self.ctxs)])[
            0
        ]