# Hugging Face Space status at capture time: Paused
import glob
import os
import subprocess
import sys
import uuid
from datetime import datetime

import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
# True only when running inside the original public Space, whose UI applies
# extra limits (e.g. trimming the driving video). SPACE_ID is provided by the
# Hugging Face Spaces runtime and is absent elsewhere (e.g. a local run), so
# default to "" instead of raising KeyError at import time.
is_shared_ui = "fffiloni/X-Portrait" in os.environ.get("SPACE_ID", "")

# Ensure 'checkpoint' directory exists, then fetch the model weights into it;
# run_xportrait passes checkpoint/model_state-415001.th to the inference script.
os.makedirs("checkpoint", exist_ok=True)
hf_hub_download(
    repo_id="fffiloni/X-Portrait",
    filename="model_state-415001.th",
    local_dir="checkpoint",
)
def trim_video(video_path, output_dir="trimmed_videos", max_duration=2):
    """Return a path to a video no longer than ``max_duration`` seconds.

    If ``video_path`` is already short enough (or its length cannot be
    determined), it is returned unchanged. Otherwise the first
    ``max_duration`` seconds are written to a new MP4 under ``output_dir`` and
    that new path is returned.

    Args:
        video_path: Path of the video to trim.
        output_dir: Directory that receives the trimmed copy (created if needed).
        max_duration: Maximum length, in seconds, of the returned video.

    Returns:
        ``video_path`` itself, or the path of the newly written trimmed copy.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
        OSError: If the trimmed output file cannot be created.
    """
    # OpenCV is used for trimming because it is already a dependency of this app
    # (it also decodes the result in extract_frames_with_labels). The trimmed copy
    # is video-only and uses the "mp4v" FourCC.
    os.makedirs(output_dir, exist_ok=True)

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Without a frame rate the duration cannot be measured; leave the input as is.
            return video_path
        max_frames = int(round(fps * max_duration))

        # A non-positive frame count means the container did not report one;
        # in that case fall through and trim by reading at most max_frames frames.
        frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        if 0 < frame_count <= max_frames:
            return video_path

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # The random suffix keeps concurrent calls within the same second apart.
        output_path = os.path.join(
            output_dir, f"trimmed_video_{timestamp}_{uuid.uuid4().hex[:8]}.mp4"
        )

        writer = None
        try:
            for _ in range(max_frames):
                ok, frame = capture.read()
                if not ok:
                    break
                if writer is None:
                    # Size the writer from an actual decoded frame so its dimensions
                    # always match what is written.
                    height, width = frame.shape[:2]
                    writer = cv2.VideoWriter(
                        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
                    )
                    if not writer.isOpened():
                        raise OSError(f"Cannot create trimmed video: {output_path}")
                writer.write(frame)
        finally:
            if writer is not None:
                writer.release()

        if writer is None:
            # Nothing could be decoded, so there is nothing to trim.
            return video_path
        return output_path
    finally:
        capture.release()
def extract_frames_with_labels(video_path, base_output_dir="frames"):
    """Save every frame of a video as a JPEG and return (path, label) pairs.

    On the shared UI the video is first passed through ``trim_video``.
    The result feeds the ``driving_frames`` gallery through the driving-video
    upload handler.

    Args:
        video_path: Path to the uploaded driving video.
        base_output_dir: Parent directory under which a fresh per-call
            ``frames_<timestamp>_<id>`` folder is created.

    Returns:
        A list of ``(frame_filename, frame_label)`` tuples, one per decoded
        frame, where ``frame_label`` is the zero-padded frame index
        (``"0000"``, ``"0001"``, ...).

    Raises:
        ValueError: If OpenCV cannot open the video.
        OSError: If a frame cannot be written to disk.
    """
    if is_shared_ui:
        video_path = trim_video(video_path)
    print("Path to the (trimmed) driving video:", video_path)

    # The upload handler runs with queue=False, so two uploads in the same
    # second are possible; a random suffix keeps their frame folders separate.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(base_output_dir, f"frames_{timestamp}_{uuid.uuid4().hex[:8]}")
    os.makedirs(output_dir, exist_ok=True)

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    frame_data = []
    try:
        frame_index = 0
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break  # No frames left to read.
            # Zero-padded frame index for filename and label.
            frame_label = f"{frame_index:04}"
            frame_filename = os.path.join(output_dir, f"frame_{frame_label}.jpg")
            # cv2.imwrite reports failure through its return value instead of raising.
            if not cv2.imwrite(frame_filename, frame):
                raise OSError(f"Failed to write frame to {frame_filename}")
            frame_data.append((frame_filename, frame_label))
            frame_index += 1
    finally:
        # Release the capture even if decoding or writing fails midway.
        video_capture.release()
    return frame_data
def _format_cli_number(value):
    """Render a numeric UI value as a command-line argument string.

    gr.Number may deliver whole numbers as floats (e.g. ``36.0``), which an
    int-typed command-line option would reject. Integral floats are therefore
    rendered without the fractional part; every other value goes through ``str``.
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


# Define a function to run your script with selected inputs
def run_xportrait(source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps):
    """Run core/test_xportrait.py on one source image and driving video.

    Args:
        source_image: File path of the portrait to animate.
        driving_video: File path of the video whose motion drives the animation.
        seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps: Values
            forwarded unchanged in meaning to the inference script's flags of
            the same names. Per the UI, ``best_frame`` is the driving-video frame
            whose head pose best matches the source image, and ``out_frames`` is
            the number of frames to generate.

    Returns:
        A ``(status_message, video_path)`` tuple. ``video_path`` is the first
        ``.mp4`` found in this run's output directory, or ``None`` if the inputs
        were missing, the script failed, or no video was produced.
    """
    # Gradio hands over None for an empty Image/Video input. Report that instead
    # of letting subprocess fail on a None argv entry.
    if not source_image or not driving_video:
        return "Please provide both a source image and a driving video.", None

    # A unique per-run directory: a timestamp alone collides for runs started in
    # the same second, and the glob below would then pick up another run's video.
    output_dir_base = "outputs"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(output_dir_base, f"output_{timestamp}_{uuid.uuid4().hex[:8]}")
    os.makedirs(output_dir, exist_ok=True)

    model_config = "config/cldm_v15_appearance_pose_local_mm.yaml"
    resume_dir = "checkpoint/model_state-415001.th"

    # NOTE(review): on the shared UI, extract_frames_with_labels trims the driving
    # video, but the untrimmed `driving_video` is what reaches the model here. The
    # UI's "trimmed to 2 seconds" notice therefore only affects the frames
    # gallery. Confirm whether that is intended.
    command = [
        # Run the script with the same interpreter (and environment) as this app.
        sys.executable or "python3", "core/test_xportrait.py",
        "--model_config", model_config,
        "--output_dir", output_dir,
        "--resume_dir", resume_dir,
        "--seed", _format_cli_number(seed),
        "--uc_scale", _format_cli_number(uc_scale),
        "--source_image", source_image,
        "--driving_video", driving_video,
        "--best_frame", _format_cli_number(best_frame),
        "--out_frames", _format_cli_number(out_frames),
        "--num_mix", _format_cli_number(num_mix),
        "--ddim_steps", _format_cli_number(ddim_steps),
    ]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        return f"An error occurred: {e}", None

    # Sort so the chosen file is deterministic if several videos are produced.
    video_files = sorted(glob.glob(os.path.join(output_dir, "*.mp4")))
    print(video_files)
    if video_files:
        return f"Output video saved at: {video_files[0]}", video_files[0]
    return "No video file was found in the output directory.", None
# Set up Gradio interface
# Make the frames gallery scrollable instead of growing with the frame count.
css = """
div#frames-gallery{
overflow: scroll!important;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# X-Portrait: Expressive Portrait Animation with Hierarchical Motion Attention")
        # The trimming notice only applies to the original public Space.
        if is_shared_ui:
            gr.Markdown("On this shared UI, driving video input will be trimmed to 2 seconds max. Duplicate this space for more controls.")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/bytedance/X-Portrait'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://byteaigc.github.io/x-portrait/'>
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
</div>
""")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    source_image = gr.Image(label="Source Image", type="filepath")
                    driving_video = gr.Video(label="Driving Video")
                with gr.Group():
                    with gr.Row():
                        # precision=0 makes these inputs arrive as ints, which is what
                        # the integer-valued CLI flags built in run_xportrait expect.
                        best_frame = gr.Number(value=36, precision=0, label="Best Frame", info="specify the frame index in the driving video where the head pose best matches the source image (note: precision of best_frame index might affect the final quality)")
                        out_frames = gr.Number(value=-1, precision=0, label="Out Frames", info="number of generation frames")
                    with gr.Accordion("Driving video Frames"):
                        driving_frames = gr.Gallery(show_label=True, columns=6, height=512, elem_id="frames-gallery")
                with gr.Row():
                    seed = gr.Number(value=999, precision=0, label="Seed")
                    uc_scale = gr.Number(value=5, label="UC Scale")
                with gr.Row():
                    num_mix = gr.Number(value=4, precision=0, label="Number of Mix")
                    ddim_steps = gr.Number(value=30, precision=0, label="DDIM Steps")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_output = gr.Video(label="Output Video")
                status = gr.Textbox(label="status")
                gr.Examples(
                    examples=[
                        ["./assets/source_image.png", "./assets/driving_video.mp4"]
                    ],
                    inputs=[source_image, driving_video],
                )
                gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://huggingface.co/spaces/fffiloni/X-Portrait?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-xl-dark.svg" alt="Follow me on HF">
</a>
</div>
""")

    # Fill the frames gallery as soon as a driving video is uploaded, so the user
    # can pick a good best_frame index before submitting.
    driving_video.upload(
        fn=extract_frames_with_labels,
        inputs=[driving_video],
        outputs=[driving_frames],
        queue=False,
    )
    submit_btn.click(
        fn=run_xportrait,
        inputs=[source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
        outputs=[status, video_output],
    )

# Launch the Gradio app
demo.launch()