import os
import random
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union, cast

import numpy as np
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from safetensors.torch import load_file
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
import genmo.mochi_preview.vae.cp_conv as cp_conv
from genmo.mochi_preview.vae.model import Decoder, apply_tiled
from genmo.lib.progress import get_new_progress_bar, progress_bar
from genmo.lib.utils import Timer


def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
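    """Build a sigma schedule of `num_steps + 1` values decreasing from 1.0 to 0.0.

    Before the final `1 - x` flip, the schedule rises linearly to
    `threshold_noise` over the first `linear_steps` entries, then follows a
    quadratic whose value and slope match the linear piece at the junction
    and which ends at exactly 1.0.
    """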
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule


T5_MODEL = "google/t5-v1_1-xxl"
MAX_T5_TOKEN_LENGTH = 256


def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
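    """Wrap `model` in fully sharded FSDP with the given parameter dtype and wrap policy.

    Gradient reduction and buffers stay in fp32, and `sync_module_states=True`
    broadcasts weights from rank 0, so only rank 0 needs to load the checkpoint.
    """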
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        ),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        device_id=device_id,
        sync_module_states=True,
        use_orig_params=True,
    )
    torch.cuda.synchronize()
    return model


class ModelFactory(ABC):
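    """Constructs a model for a given rank/device layout from stored kwargs."""
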
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @abstractmethod
    def get_model(self, *, local_rank: int, device_id: Union[int, Literal["cpu"]], world_size: int) -> Any:
        if device_id == "cpu":
            assert world_size == 1, "CPU offload only supports single-GPU inference"


class T5ModelFactory(ModelFactory):
    def __init__(self):
        super().__init__()

    def get_model(self, *, local_rank, device_id, world_size):
        super().get_model(local_rank=local_rank, device_id=device_id, world_size=world_size)
        model = T5EncoderModel.from_pretrained(T5_MODEL)
        if world_size > 1:
            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.float32,
                auto_wrap_policy=partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls={
                        T5Block,
                    },
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))
        return model.eval()


class DitModelFactory(ModelFactory):
    def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optional[str] = None):
        if attention_mode is None:
            from genmo.lib.attn_imports import flash_varlen_qkvpacked_attn

            attention_mode = "sdpa" if flash_varlen_qkvpacked_attn is None else "flash"
        print(f"Attention mode: {attention_mode}")
        super().__init__(model_path=model_path, model_dtype=model_dtype, attention_mode=attention_mode)

    def get_model(self, *, local_rank, device_id, world_size):
        from genmo.mochi_preview.dit.joint_model.asymm_models_joint import (
            AsymmDiTJoint,
        )

        # Skip weight initialization; real weights are loaded from the
        # checkpoint on rank 0 below (and broadcast to other ranks by FSDP).
        model: nn.Module = torch.nn.utils.skip_init(
            AsymmDiTJoint,
            depth=48,
            patch_size=2,
            num_heads=24,
            hidden_size_x=3072,
            hidden_size_y=1536,
            mlp_ratio_x=4.0,
            mlp_ratio_y=4.0,
            in_channels=12,
            qk_norm=True,
            qkv_bias=False,
            out_bias=True,
            patch_embed_bias=True,
            timestep_mlp_bias=True,
            timestep_scale=1000.0,
            t5_feat_dim=4096,
            t5_token_length=256,
            rope_theta=10000.0,
            attention_mode=self.kwargs["attention_mode"],
        )

        if local_rank == 0:
            # FSDP's sync_module_states broadcasts these weights to other ranks.
            model.load_state_dict(load_file(self.kwargs["model_path"]))

        if world_size > 1:
            assert self.kwargs["model_dtype"] == "bf16", "FP8 is not supported for multi-GPU inference"
            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.bfloat16,
                auto_wrap_policy=partial(
                    lambda_auto_wrap_policy,
                    lambda_fn=lambda m: m in model.blocks,
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))
        return model.eval()


class DecoderModelFactory(ModelFactory):
    def __init__(self, *, model_path: str, model_stats_path: str):
        super().__init__(model_path=model_path, model_stats_path=model_stats_path)

    def get_model(self, *, local_rank, device_id, world_size):
        import json

        decoder = Decoder(
            out_channels=3,
            base_channels=128,
            channel_multipliers=[1, 2, 4, 6],
            temporal_expansions=[1, 2, 3],
            spatial_expansions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            has_attention=[False, False, False, False, False],
            padding_mode="replicate",
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )

        state_dict = load_file(self.kwargs["model_path"])
        decoder.load_state_dict(state_dict, strict=True)
        device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else "cpu"
        decoder.eval().to(device)

        with open(self.kwargs["model_stats_path"]) as f:
            vae_stats = json.load(f)
        decoder.register_buffer("vae_mean", torch.tensor(vae_stats["mean"], device=device))
        decoder.register_buffer("vae_std", torch.tensor(vae_stats["std"], device=device))
        return decoder


def get_conditioning(tokenizer, encoder, device, batch_inputs, *, prompt: str, negative_prompt: str):
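    """Encode the prompt and negative prompt with T5.

    Returns {"batched": ...} with both prompts encoded in a single forward pass
    when `batch_inputs` is set, else {"cond": ..., "null": ...} with each
    prompt encoded separately.
    """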
    if batch_inputs:
        return dict(batched=get_conditioning_for_prompts(tokenizer, encoder, device, [prompt, negative_prompt]))
    else:
        cond_input = get_conditioning_for_prompts(tokenizer, encoder, device, [prompt])
        null_input = get_conditioning_for_prompts(tokenizer, encoder, device, [negative_prompt])
        return dict(cond=cond_input, null=null_input)


def get_conditioning_for_prompts(tokenizer, encoder, device, prompts: List[str]):
    assert len(prompts) in [1, 2]
    B = len(prompts)
    t5_toks = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=MAX_T5_TOKEN_LENGTH,
        return_tensors="pt",
        return_attention_mask=True,
    )
    caption_input_ids_t5 = t5_toks["input_ids"]
    caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
    del t5_toks

    assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
    assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)

    # Special-case an empty negative prompt by fully zero-masking it.
    if prompts[-1] == "":
        caption_input_ids_t5[-1] = 0
        caption_attention_mask_t5[-1] = False

    caption_input_ids_t5 = caption_input_ids_t5.to(device, non_blocking=True)
    caption_attention_mask_t5 = caption_attention_mask_t5.to(device, non_blocking=True)

    y_mask = [caption_attention_mask_t5]
    y_feat = [encoder(caption_input_ids_t5, caption_attention_mask_t5).last_hidden_state.detach()]

    assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
    assert y_feat[-1].dtype == torch.float32

    return dict(y_mask=y_mask, y_feat=y_feat)


def compute_packed_indices(
    device: torch.device, text_mask: torch.Tensor, num_latents: int
) -> Dict[str, Union[torch.Tensor, int]]:
    """
    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80

    Args:
        num_latents: Number of latent tokens.
        text_mask: (B, L) boolean tensor indicating which text tokens are not padding.

    Returns:
        packed_indices: Dict with keys for flash attention:
            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid (non-padding)
              token indices in the packed sequence.
            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
    """
    PATCH_SIZE = 2
    num_visual_tokens = num_latents // (PATCH_SIZE**2)
    assert num_visual_tokens > 0

    # Prepend the visual tokens, which are never padding, to the text mask.
    mask = F.pad(text_mask, (num_visual_tokens, 0), value=True)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)
    valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    assert valid_token_indices.size(0) >= text_mask.size(0) * num_visual_tokens
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    return {
        "cu_seqlens_kv": cu_seqlens.to(device, non_blocking=True),
        "max_seqlen_in_batch_kv": cast(int, max_seqlen_in_batch),
        "valid_token_indices_kv": valid_token_indices.to(device, non_blocking=True),
    }


def assert_eq(x, y, msg=None):
    assert x == y, f"{msg or 'Assertion failed'}: {x} != {y}"


def sample_model(device, dit, conditioning, **args):
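    """Run the Euler sampling loop on the DiT and return the denoised latents.

    `args` must provide: seed, width, height, num_frames, num_inference_steps,
    cfg_schedule (one scale per step), and sigma_schedule (one sigma per step,
    plus a final value).
    """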
    random.seed(args["seed"])
    np.random.seed(args["seed"])
    torch.manual_seed(args["seed"])

    generator = torch.Generator(device=device)
    generator.manual_seed(args["seed"])

    w, h, t = args["width"], args["height"], args["num_frames"]
    sample_steps = args["num_inference_steps"]
    cfg_schedule = args["cfg_schedule"]
    sigma_schedule = args["sigma_schedule"]

    assert_eq(len(cfg_schedule), sample_steps, "cfg_schedule must have length sample_steps")
    assert_eq((t - 1) % 6, 0, "t - 1 must be divisible by 6")
    assert_eq(
        len(sigma_schedule),
        sample_steps + 1,
        "sigma_schedule must have length sample_steps + 1",
    )

    B = 1
    SPATIAL_DOWNSAMPLE = 8
    TEMPORAL_DOWNSAMPLE = 6
    IN_CHANNELS = 12
    latent_t = ((t - 1) // TEMPORAL_DOWNSAMPLE) + 1
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE

    z = torch.randn(
        (B, IN_CHANNELS, latent_t, latent_h, latent_w),
        device=device,
        dtype=torch.float32,
    )

    num_latents = latent_t * latent_h * latent_w
    cond_batched = cond_text = cond_null = None
    if "cond" in conditioning:
        cond_text = conditioning["cond"]
        cond_null = conditioning["null"]
        cond_text["packed_indices"] = compute_packed_indices(device, cond_text["y_mask"][0], num_latents)
        cond_null["packed_indices"] = compute_packed_indices(device, cond_null["y_mask"][0], num_latents)
    else:
        cond_batched = conditioning["batched"]
        cond_batched["packed_indices"] = compute_packed_indices(device, cond_batched["y_mask"][0], num_latents)
        z = repeat(z, "b ... -> (repeat b) ...", repeat=2)

    def model_fn(*, z, sigma, cfg_scale):
        if cond_batched:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out = dit(z, sigma, **cond_batched)
            out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
        else:
            nonlocal cond_text, cond_null
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out_cond = dit(z, sigma, **cond_text)
                out_uncond = dit(z, sigma, **cond_null)
        assert out_cond.shape == out_uncond.shape
        # Classifier-free guidance: push the prediction away from the
        # unconditional output and toward the conditional one.
        return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond

    for i in get_new_progress_bar(range(0, sample_steps), desc="Sampling"):
        sigma = sigma_schedule[i]
        dsigma = sigma - sigma_schedule[i + 1]

        pred, output_cond = model_fn(
            z=z,
            sigma=torch.full([B] if cond_text else [B * 2], sigma, device=z.device),
            cfg_scale=cfg_schedule[i],
        )
        pred = pred.to(z)
        output_cond = output_cond.to(z)
        # Euler update of the latents.
        z = z + dsigma * pred

    return z[:B] if cond_batched else z


def decoded_latents_to_frames(samples):
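    """Map decoder output from [-1, 1] to [0, 1] and reorder to (B, T, H, W, C)."""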
    samples = samples.float()
    samples = (samples + 1.0) / 2.0
    samples.clamp_(0.0, 1.0)
    frames = rearrange(samples, "b c t h w -> b t h w c")
    return frames


def decode_latents(decoder, z):
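    """Decode latents to frames, splitting the temporal axis across the context-parallel group."""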
    cp_rank, cp_size = cp.get_cp_rank_size()
    z = z.tensor_split(cp_size, dim=2)[cp_rank]
    with torch.autocast("cuda", dtype=torch.bfloat16):
        samples = decoder(z)
    samples = cp_conv.gather_all_frames(samples)
    return decoded_latents_to_frames(samples)


@torch.inference_mode()
def decode_latents_tiled_full(
    decoder,
    z,
    *,
    tile_sample_min_height: int = 240,
    tile_sample_min_width: int = 424,
    tile_overlap_factor_height: float = 0.1666,
    tile_overlap_factor_width: float = 0.2,
    auto_tile_size: bool = True,
    frame_batch_size: int = 6,
):
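    """Decode latents in overlapping spatial tiles and temporal chunks, linearly
    blending neighboring tiles across the overlap to hide seams.

    With `auto_tile_size` (the default), the explicit tile extents are ignored
    and each tile covers half the output frame in each spatial dimension.
    """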
    B, C, T, H, W = z.shape
    assert frame_batch_size <= T, f"frame_batch_size must be <= T, got {frame_batch_size} > {T}"

    tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
    tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8

    tile_latent_min_height = int(tile_sample_min_height / 8)
    tile_latent_min_width = int(tile_sample_min_width / 8)

    def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))
    overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))
    blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)
    blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)
    row_limit_height = tile_sample_min_height - blend_extent_height
    row_limit_width = tile_sample_min_width - blend_extent_width

    pbar = get_new_progress_bar(
        desc="Decoding latent tiles",
        total=len(range(0, H, overlap_height)) * len(range(0, W, overlap_width)) * (T // frame_batch_size),
    )
    rows = []
    for i in range(0, H, overlap_height):
        row = []
        for j in range(0, W, overlap_width):
            temporal = []
            for k in range(T // frame_batch_size):
                remaining_frames = T % frame_batch_size
                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                end_frame = frame_batch_size * (k + 1) + remaining_frames
                tile = z[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + tile_latent_min_height,
                    j : j + tile_latent_min_width,
                ]
                tile = decoder(tile)
                temporal.append(tile)
                pbar.update(1)
            row.append(torch.cat(temporal, dim=2))
        rows.append(row)

    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # Blend the tile above and the tile to the left into the current
            # tile, then crop it to its non-overlapping extent.
            if i > 0:
                tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    return decoded_latents_to_frames(torch.cat(result_rows, dim=3))


@torch.inference_mode()
def decode_latents_tiled_spatial(
    decoder,
    z,
    *,
    num_tiles_w: int,
    num_tiles_h: int,
    overlap: int = 0,
    min_block_size: int = 1,
):
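    """Decode latents in a fixed grid of spatial tiles using `apply_tiled`."""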
    decoded = apply_tiled(decoder, z, num_tiles_w, num_tiles_h, overlap, min_block_size)
    assert decoded is not None, "Failed to decode latents with tiled spatial method"
    return decoded


@contextmanager
def move_to_device(model: nn.Module, target_device):
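    """Temporarily move `model` to `target_device`, restoring its original device on exit.

    This is what makes `cpu_offload` work: each model is resident on the GPU
    only while its stage of the pipeline runs.
    """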
    og_device = next(model.parameters()).device
    if og_device == target_device:
        print(f"move_to_device is a no-op; model is already on {target_device}")
    else:
        print(f"moving model from {og_device} -> {target_device}")

    model.to(target_device)
    try:
        yield
    finally:
        if og_device != target_device:
            print(f"moving model from {target_device} -> {og_device}")
            model.to(og_device)


def t5_tokenizer():
    return T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)


class MochiSingleGPUPipeline:
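    """Single-GPU pipeline: T5 conditioning -> DiT sampling -> VAE decode.

    Illustrative usage (a minimal sketch; the checkpoint paths and sampling
    settings below are placeholders, not values defined by this module):

        pipeline = MochiSingleGPUPipeline(
            text_encoder_factory=T5ModelFactory(),
            dit_factory=DitModelFactory(model_path="<dit>.safetensors", model_dtype="bf16"),
            decoder_factory=DecoderModelFactory(
                model_path="<decoder>.safetensors",
                model_stats_path="<vae_stats>.json",
            ),
            cpu_offload=True,
        )
        frames = pipeline(
            batch_cfg=False,
            prompt="a cat playing a piano",
            negative_prompt="",
            seed=12345,
            width=848,
            height=480,
            num_frames=31,  # num_frames - 1 must be divisible by 6
            num_inference_steps=64,
            sigma_schedule=linear_quadratic_schedule(64, 0.025),
            cfg_schedule=[4.5] * 64,
        )
    """
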
    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        cpu_offload: Optional[bool] = False,
        decode_type: str = "full",
        decode_args: Optional[Dict[str, Any]] = None,
    ):
        self.device = torch.device("cuda:0")
        self.tokenizer = t5_tokenizer()
        t = Timer()
        self.cpu_offload = cpu_offload
        self.decode_args = decode_args or {}
        self.decode_type = decode_type
        # With CPU offload, models start on the CPU and are moved to the GPU
        # one at a time by move_to_device in __call__.
        init_id = "cpu" if cpu_offload else 0
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(
                local_rank=0,
                device_id=init_id,
                world_size=1,
            )
        with t("load_dit"):
            self.dit = dit_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        t.print_stats()

    def __call__(self, batch_cfg, prompt, negative_prompt, **kwargs):
        with progress_bar(type="tqdm"), torch.inference_mode():
            print_max_memory = lambda: print(
                f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB"
            )
            print_max_memory()
            with move_to_device(self.text_encoder, self.device):
                conditioning = get_conditioning(
                    self.tokenizer,
                    self.text_encoder,
                    self.device,
                    batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
            print_max_memory()
            with move_to_device(self.dit, self.device):
                latents = sample_model(self.device, self.dit, conditioning, **kwargs)
            print_max_memory()
            with move_to_device(self.decoder, self.device):
                frames = (
                    decode_latents_tiled_full(self.decoder, latents, **self.decode_args)
                    if self.decode_type == "tiled_full"
                    else decode_latents_tiled_spatial(self.decoder, latents, **self.decode_args)
                    if self.decode_type == "tiled_spatial"
                    else decode_latents(self.decoder, latents)
                )
            print_max_memory()
            return frames.cpu().numpy()


class MultiGPUContext:
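    """Per-rank worker used by MochiMultiGPUPipeline.

    Each instance runs as a Ray actor on one GPU: it joins the NCCL process
    group, registers the context-parallel group, and loads the FSDP-sharded
    models.
    """
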
    def __init__(
        self,
        *,
        text_encoder_factory,
        dit_factory,
        decoder_factory,
        device_id,
        local_rank,
        world_size,
    ):
        t = Timer()
        self.device = torch.device(f"cuda:{device_id}")
        print(f"Initializing rank {local_rank + 1}/{world_size}")
        assert world_size > 1, f"Multi-GPU mode requires world_size > 1, got {world_size}"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        with t("init_process_group"):
            dist.init_process_group(
                "nccl",
                rank=local_rank,
                world_size=world_size,
                device_id=self.device,  # bind this rank to its device
            )
        pg = dist.group.WORLD
        cp.set_cp_group(pg, list(range(world_size)), local_rank)
        distributed_kwargs = dict(local_rank=local_rank, device_id=device_id, world_size=world_size)
        self.world_size = world_size
        self.tokenizer = t5_tokenizer()
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(**distributed_kwargs)
        with t("load_dit"):
            self.dit = dit_factory.get_model(**distributed_kwargs)
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(**distributed_kwargs)
        self.local_rank = local_rank
        t.print_stats()

    def run(self, *, fn, **kwargs):
        return fn(self, **kwargs)


class MochiMultiGPUPipeline:
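    """Multi-GPU pipeline that fans work out to one MultiGPUContext per GPU.

    Construction mirrors MochiSingleGPUPipeline (plus `world_size`); calls take
    the same keyword arguments and return the result from rank 0.
    """
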
    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        world_size: int,
    ):
        ray.init()
        RemoteClass = ray.remote(MultiGPUContext)
        self.ctxs = [
            RemoteClass.options(num_gpus=1).remote(
                text_encoder_factory=text_encoder_factory,
                dit_factory=dit_factory,
                decoder_factory=decoder_factory,
                world_size=world_size,
                device_id=0,  # each actor sees exactly one GPU via num_gpus=1
                local_rank=i,
            )
            for i in range(world_size)
        ]
        for ctx in self.ctxs:
            ray.get(ctx.__ray_ready__.remote())

    def __call__(self, **kwargs):
        def sample(ctx, *, batch_cfg, prompt, negative_prompt, **kwargs):
            with progress_bar(type="ray_tqdm", enabled=ctx.local_rank == 0), torch.inference_mode():
                conditioning = get_conditioning(
                    ctx.tokenizer,
                    ctx.text_encoder,
                    ctx.device,
                    batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
                latents = sample_model(ctx.device, ctx.dit, conditioning=conditioning, **kwargs)
                if ctx.local_rank == 0:
                    torch.save(latents, "latents.pt")
                frames = decode_latents(ctx.decoder, latents)
            return frames.cpu().numpy()

        return ray.get([ctx.run.remote(fn=sample, **kwargs, show_progress=i == 0) for i, ctx in enumerate(self.ctxs)])[
            0
        ]