Pyramid-Flow / inference_multigpu.py
sironagasuyagi's picture
Upload folder using huggingface_hub
910e2ad verified
import os
import torch
import sys
import argparse
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers.utils import export_to_video
from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group
import PIL
from PIL import Image
def get_args():
    """Parse the command-line options of the multi-GPU inference script.

    The parser is created with ``add_help=False``, so no ``-h``/``--help``
    option is registered. Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace: with attributes ``model_name``, ``model_dtype``,
        ``model_path``, ``variant``, ``task``, ``temp``, ``sp_group_size``
        and ``sp_proc_num``.
    """
    # One (flag, add_argument keyword arguments) pair per option, in registration order.
    option_table = (
        ('--model_name', {
            'default': 'pyramid_flux',
            'type': str,
            'help': "The model name",
            'choices': ["pyramid_flux", "pyramid_mmdit"],
        }),
        ('--model_dtype', {
            'default': 'bf16',
            'type': str,
            'help': "The Model Dtype: bf16",
        }),
        ('--model_path', {
            'default': '/home/jinyang06/models/pyramid-flow',
            'type': str,
            'help': 'Set it to the downloaded checkpoint dir',
        }),
        ('--variant', {
            'default': 'diffusion_transformer_768p',
            'type': str,
        }),
        ('--task', {
            'default': 't2v',
            'type': str,
            'choices': ['i2v', 't2v'],
        }),
        ('--temp', {
            'default': 16,
            'type': int,
            'help': 'The generated latent num, num_frames = temp * 8 + 1',
        }),
        ('--sp_group_size', {
            'default': 2,
            'type': int,
            'help': "The number of gpus used for inference, should be 2 or 4",
        }),
        ('--sp_proc_num', {
            'default': -1,
            'type': int,
            'help': "The number of process used for video training, default=-1 means using all process.",
        }),
    )

    cli = argparse.ArgumentParser('Pytorch Multi-process Script', add_help=False)
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def main():
    """Generate one sample video across several GPUs and save it from rank 0.

    Steps, in order: parse the CLI options, call ``init_distributed_mode`` and
    ``init_sequence_parallel_group``, build ``PyramidDiTForVideoGeneration``
    and move its ``vae`` / ``dit`` / ``text_encoder`` to CUDA, run
    ``model.generate`` (``--task t2v``) or ``model.generate_i2v``
    (``--task i2v``) under ``torch.no_grad()`` plus CUDA autocast, export the
    frames to an mp4 on rank 0, and finally wait on
    ``torch.distributed.barrier()``.

    Raises:
        ValueError: if ``--variant`` or ``--task`` is not a supported value,
            or if the distributed world size differs from ``--sp_group_size``.
    """
    args = get_args()

    # Output resolution (width, height) for each supported model variant.
    resolutions = {
        'diffusion_transformer_768p': (1280, 768),
        'diffusion_transformer_384p': (640, 384),
    }
    # Validate with explicit raises rather than `assert`: asserts are removed
    # under `python -O`, which would let an unknown variant silently run at
    # 384p. Checking here also fails before any model weights are loaded.
    if args.variant not in resolutions:
        raise ValueError(
            f"Unsupported variant {args.variant!r}; expected one of {sorted(resolutions)}"
        )
    if args.task not in ('t2v', 'i2v'):
        raise ValueError(f"Unsupported task {args.task!r}; expected 't2v' or 'i2v'")
    width, height = resolutions[args.variant]

    # setup DDP
    # NOTE(review): `world_size` and `rank` are not CLI options; presumably
    # init_distributed_mode sets them on `args` -- confirm in trainer_misc.
    init_distributed_mode(args)

    if args.world_size != args.sp_group_size:
        raise ValueError("The sequence parallel size should be DDP world size")

    # Enable sequence parallel
    init_sequence_parallel_group(args)

    device = torch.device('cuda')
    rank = args.rank
    model_dtype = args.model_dtype

    model = PyramidDiTForVideoGeneration(
        args.model_path,
        model_dtype,
        model_name=args.model_name,
        model_variant=args.variant,
    )

    model.vae.to(device)
    model.dit.to(device)
    model.text_encoder.to(device)
    model.vae.enable_tiling()

    # Any dtype string other than bf16 / fp16 falls back to float32.
    torch_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }.get(model_dtype, torch.float32)
    # Autocast is switched off only for an explicit 'fp32' request.
    use_autocast = model_dtype != 'fp32'

    if args.task == 't2v':
        output_path = "./text_to_video_sample.mp4"
        prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_autocast, dtype=torch_dtype):
            frames = model.generate(
                prompt=prompt,
                num_inference_steps=[20, 20, 20],
                video_num_inference_steps=[10, 10, 10],
                height=height,
                width=width,
                temp=args.temp,
                guidance_scale=7.0,         # The guidance for the first frame, set it to 7 for 384p variant
                video_guidance_scale=5.0,   # The guidance for the other video latent
                output_type="pil",
                save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
                cpu_offloading=False,       # If OOM, set it to True to reduce memory usage
                inference_multigpu=True,
            )
    else:
        output_path = "./image_to_video_sample.mp4"
        image_path = 'assets/the_great_wall.jpg'
        # Open inside a context manager so the file handle is released;
        # convert() returns a new image that outlives the closed source.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        image = image.resize((width, height))

        prompt = "FPV flying over the Great Wall"

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_autocast, dtype=torch_dtype):
            frames = model.generate_i2v(
                prompt=prompt,
                input_image=image,
                num_inference_steps=[10, 10, 10],
                temp=args.temp,
                video_guidance_scale=4.0,
                output_type="pil",
                save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
                cpu_offloading=False,       # If OOM, set it to True to reduce memory usage
                inference_multigpu=True,
            )

    # Only rank 0 writes the video file.
    if rank == 0:
        export_to_video(frames, output_path, fps=24)

    # Every rank waits here, so no process returns while rank 0 is still exporting.
    torch.distributed.barrier()
# Script entry point: run the inference only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()