Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,422 Bytes
6a87547 dae6484 6a87547 702754c ad0cea8 702754c 12e9e51 d36add3 beb0b25 6a87547 12e9e51 6a87547 b832af5 6a87547 0b7d3bf 4254e9c 0b7d3bf a40b1c7 6a87547 ad0cea8 0328b82 ad0cea8 12e9e51 ad0cea8 12e9e51 ad0cea8 0b7d3bf ad0cea8 a40b1c7 0b7d3bf 6a87547 12e9e51 0b7d3bf ad0cea8 0b7d3bf 12e9e51 6a87547 ad0cea8 beb0b25 d36add3 beb0b25 12e9e51 beb0b25 4254e9c d093812 4254e9c e0e8e31 16c1b5a 4254e9c ad0cea8 4254e9c 0b7d3bf dd0fd3b 8157bfe ad0cea8 6a87547 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import torch
import gradio as gr
from PIL import Image, ImageOps
from huggingface_hub import snapshot_download
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import export_to_video
import spaces
import uuid
import subprocess
# Install flash-attn at startup without compiling its CUDA extension.
# The flag is merged into the current environment: passing a dict containing
# only FLASH_ATTENTION_SKIP_CUDA_BUILD as `env` would REPLACE the whole
# environment for the pip process (dropping PATH, HOME, virtualenv settings).
# Best-effort, as before: a failed install does not abort startup (no check=True).
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# True only when running as the official Pyramid-Flow Space; the UI uses this
# to cap the selectable clip duration.
is_canonical = os.environ.get("SPACE_ID") == "Pyramid-Flow/pyramid-flow"

# Constants
MODEL_PATH = "pyramid-flow-model"
MODEL_REPO = "rain1011/pyramid-flow-sd3"
MODEL_VARIANT = "diffusion_transformer_768p"
MODEL_DTYPE = "bf16"
def center_crop(image, target_width, target_height):
    """Crop *image* about its center to the aspect ratio ``target_width / target_height``.

    Only the target's aspect ratio matters; the crop keeps as much of the source
    as possible and is NOT resized:

    * source wider than the target ratio -> full height kept, sides trimmed equally;
    * otherwise (taller, or exactly matching) -> full width kept, top/bottom trimmed.

    Args:
        image: Object exposing ``.size`` as ``(width, height)`` and
            ``.crop(box)`` (e.g. a PIL image).
        target_width: Width term of the desired aspect ratio.
        target_height: Height term of the desired aspect ratio.

    Returns:
        The result of ``image.crop((left, top, right, bottom))``.
    """
    src_w, src_h = image.size
    wanted_ratio = target_width / target_height

    if src_w / src_h > wanted_ratio:
        # Too wide: keep every row, drop equal-width strips on the left and right.
        kept_w = int(src_h * wanted_ratio)
        x0 = (src_w - kept_w) // 2
        box = (x0, 0, x0 + kept_w, src_h)
    else:
        # Too tall (or already matching): keep every column, trim top and bottom.
        kept_h = int(src_w / wanted_ratio)
        y0 = (src_h - kept_h) // 2
        box = (0, y0, src_w, y0 + kept_h)

    return image.crop(box)
# Download and load the model
def load_model():
    """Fetch the Pyramid Flow weights if needed and build a CUDA-resident pipeline.

    The snapshot of ``MODEL_REPO`` is downloaded into ``MODEL_PATH`` only when
    that path does not exist yet.

    NOTE(review): if an earlier download was interrupted, ``MODEL_PATH`` still
    exists and the snapshot is not re-fetched here — confirm whether that matters.

    Returns:
        A ``PyramidDiTForVideoGeneration`` built from ``MODEL_PATH`` with
        ``MODEL_DTYPE`` and ``model_variant=MODEL_VARIANT``; its VAE, DiT and
        text encoder are moved to ``"cuda"`` and VAE tiling is enabled.
    """
    already_downloaded = os.path.exists(MODEL_PATH)
    if not already_downloaded:
        snapshot_download(MODEL_REPO, local_dir=MODEL_PATH, local_dir_use_symlinks=False, repo_type='model')

    pipeline = PyramidDiTForVideoGeneration(MODEL_PATH, MODEL_DTYPE, model_variant=MODEL_VARIANT)

    # Move the sub-models to the GPU in a fixed order: VAE, DiT, text encoder.
    for submodule in (pipeline.vae, pipeline.dit, pipeline.text_encoder):
        submodule.to("cuda")

    # VAE tiling — presumably to bound decode memory at 768p; confirm in pyramid_dit.
    pipeline.vae.enable_tiling()
    return pipeline


# Built once at import time; generate_video reads this module-level instance.
model = load_model()
# Text-to-video generation function (image-to-video when an input image is given)

# Output resolution shared by both generation paths.
_VIDEO_WIDTH = 1280
_VIDEO_HEIGHT = 768
# Fixed guidance scale for the image-to-video path. The UI "Guidance Scale"
# slider is only forwarded to text-to-video; kept fixed to preserve i2v output.
_I2V_GUIDANCE_SCALE = 7.0


@spaces.GPU(duration=160)
def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guidance_scale=5, progress=gr.Progress(track_tqdm=True)):
    """Generate video frames from a prompt, optionally conditioned on an image.

    Args:
        prompt: Text description of the video.
        image: Optional input image. When given, it is center-cropped and resized
            to 1280x768 and ``model.generate_i2v`` is used; otherwise
            ``model.generate`` runs text-to-video at 1280x768.
        duration: Requested clip length in seconds; passed to the model as
            ``temp = int(duration * 3) + 1``.
        guidance_scale: Guidance scale for the text-to-video path only
            (image-to-video uses ``_I2V_GUIDANCE_SCALE``).
        video_guidance_scale: Video guidance scale, used by both paths.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio follow
            tqdm bars emitted during generation.

    Returns:
        ``(frames, gr.update())`` — the frames produced with ``output_type="pil"``
        and a no-op update for the video component.
    """
    multiplier = 3
    temp = int(duration * multiplier) + 1
    torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32

    # Prepare the conditioning image up front; this is plain PIL work and does
    # not need the no-grad / autocast context.
    resized_image = None
    if image is not None:
        cropped_image = center_crop(image, _VIDEO_WIDTH, _VIDEO_HEIGHT)
        resized_image = cropped_image.resize((_VIDEO_WIDTH, _VIDEO_HEIGHT))

    # One shared inference context; torch.autocast("cuda", ...) replaces the
    # deprecated torch.cuda.amp.autocast.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch_dtype):
        if resized_image is not None:
            frames = model.generate_i2v(
                prompt=prompt,
                input_image=resized_image,
                num_inference_steps=[10, 10, 10],
                temp=temp,
                guidance_scale=_I2V_GUIDANCE_SCALE,
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                save_memory=True,
            )
        else:
            frames = model.generate(
                prompt=prompt,
                num_inference_steps=[20, 20, 20],
                video_num_inference_steps=[10, 10, 10],
                height=_VIDEO_HEIGHT,
                width=_VIDEO_WIDTH,
                temp=temp,
                guidance_scale=guidance_scale,
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                save_memory=True,
            )

    # The no-op update leaves the video component as-is; the encoded MP4 is
    # produced by a follow-up step.
    return frames, gr.update()
def compose_video(frames, fps=24):
    """Encode generated frames into an MP4 file and return its path.

    The file is written to the current working directory under a random
    ``<uuid4>_output_video.mp4`` name, so concurrent requests do not collide.

    Args:
        frames: Sequence of frames as returned by ``generate_video``.
        fps: Frame rate of the encoded video (default 24, as before).

    Returns:
        The path of the written MP4 file.

    Raises:
        gr.Error: if ``frames`` is missing or empty (e.g. generation failed),
            so the user sees a clear message instead of an encoder error.
    """
    if frames is None or len(frames) == 0:
        raise gr.Error("No frames were generated, so there is no video to export. Please try again.")
    output_path = f"{uuid.uuid4()}_output_video.mp4"
    export_to_video(frames, output_path, fps=fps)
    return output_path
# Gradio interface

# Default clip length in seconds, shared by the Duration slider and the cached
# examples. The canonical Space caps the slider at 2 s.
_DEFAULT_DURATION = 2 if is_canonical else 5


def _generate_example_video(prompt):
    """Generate a clip for a gr.Examples prompt and return the encoded MP4 path.

    ``generate_video`` returns ``(frames, gr.update())``, but the Examples
    component has a single ``gr.Video`` output that needs a file path, so the
    frames are encoded here with ``compose_video``. The example uses the same
    default duration as the UI slider.
    """
    example_frames, _ = generate_video(prompt, duration=_DEFAULT_DURATION)
    return compose_video(example_frames)


with gr.Blocks() as demo:
    gr.Markdown("# Pyramid Flow")
    gr.Markdown("Pyramid Flow is a training-efficient Autoregressive Video Generation model based on Flow Matching. It is trained only on open-source datasets within 20.7k A100 GPU hours")
    gr.Markdown("[[Paper](https://arxiv.org/pdf/2410.05954)], [[Model](https://huggingface.co/rain1011/pyramid-flow-sd3)], [[Code](https://github.com/jy0205/Pyramid-Flow)]")
    # Per-session holder for raw frames between generation and MP4 encoding.
    frames = gr.State()
    with gr.Row():
        with gr.Column():
            with gr.Accordion("Image to Video (optional)", open=False):
                i2v_image = gr.Image(type="pil", label="Input Image")
            t2v_prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced settings", open=False):
                t2v_duration = gr.Slider(minimum=1, maximum=2 if is_canonical else 10, value=_DEFAULT_DURATION, step=1, label="Duration (seconds)")
                t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
                t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
            t2v_generate_btn = gr.Button("Generate Video")
        with gr.Column():
            t2v_output = gr.Video(label="Generated Video")
            gr.HTML("""
<div style="display: flex; flex-direction: column;justify-content: center; align-items: center; text-align: center;">
<p style="display: flex;gap: 6px;">
<a href="https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
</a>
</p>
<p>to use privately and generate videos up to 10s at 24fps</p>
</div>
""")
    gr.Examples(
        examples=[
            "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors",
            "Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes",
        ],
        fn=_generate_example_video,
        inputs=t2v_prompt,
        outputs=t2v_output,
        cache_examples="lazy",
    )
    # Two-step pipeline: generate raw frames into the State, then encode them.
    # `.success()` (not `.then()`) so encoding only runs when generation did not
    # raise; otherwise it would re-encode whatever the State last held.
    t2v_generate_btn.click(
        generate_video,
        inputs=[t2v_prompt, i2v_image, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale],
        outputs=[frames, t2v_output],
    ).success(
        compose_video,
        inputs=[frames],
        outputs=t2v_output,
    )
demo.launch()