"""Gradio front-end for running the XPortrait inference script.

On start-up the pretrained weights are fetched from a Google Drive folder
into ``checkpoint/`` (skipped when the checkpoint file is already present),
then a Gradio form is served whose inputs are forwarded as command-line
flags to ``core/test_xportrait.py``.
"""
import os
import subprocess
import sys

import gdown
import gradio as gr

# Directory the weights are downloaded into.
CHECKPOINT_DIR = "checkpoint"
# Checkpoint file expected inside the downloaded Drive folder.
MODEL_FILENAME = "model_state-415001.th"
# Public Drive folder holding the pretrained weights.
WEIGHTS_FOLDER_URL = (
    "https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O"
)

# Ensure 'checkpoint' directory exists
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def download_weights_from_folder(google_drive_folder_link):
    """Download every file of a Google Drive folder into ``checkpoint/``.

    Args:
        google_drive_folder_link: A Drive folder link such as
            ``https://drive.google.com/drive/folders/<id>``. A trailing slash
            or a query string (e.g. ``?usp=sharing``) is tolerated.

    Returns:
        A human-readable status message. Errors raised by ``gdown`` are
        caught and reported in the message instead of propagating.
    """
    # Drop any query string and trailing slash before taking the last path
    # segment; otherwise "<id>?usp=sharing" would be used as the folder ID.
    folder_id = (
        google_drive_folder_link.split("?", 1)[0].rstrip("/").split("/")[-1]
    )
    gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
    try:
        gdown.download_folder(gdown_url, quiet=False, output=CHECKPOINT_DIR)
    except Exception as e:  # gdown surfaces many error types; report them all
        return f"Failed to download weights: {e}"

    # Confirm the expected checkpoint actually arrived (no renaming is done).
    downloaded_model_path = os.path.join(CHECKPOINT_DIR, MODEL_FILENAME)
    if os.path.exists(downloaded_model_path):
        return f"Downloaded model weights to {downloaded_model_path}"
    return f"Model file '{MODEL_FILENAME}' not found in the folder."


def _ensure_weights():
    """Download the weights unless the checkpoint file is already on disk."""
    model_path = os.path.join(CHECKPOINT_DIR, MODEL_FILENAME)
    if os.path.exists(model_path):
        return f"Using existing model weights at {model_path}"
    return download_weights_from_folder(WEIGHTS_FOLDER_URL)


def run_xportrait(
    model_config,
    output_dir,
    resume_dir,
    seed,
    uc_scale,
    source_image,
    driving_video,
    best_frame,
    out_frames,
    num_mix,
    ddim_steps,
):
    """Run ``core/test_xportrait.py`` with the given parameters.

    Args:
        model_config: Path to the model config YAML.
        output_dir: Directory passed as ``--output_dir``.
        resume_dir: Checkpoint path passed as ``--resume_dir``.
        seed, best_frame, out_frames, num_mix, ddim_steps: Integer options;
            numeric input is truncated to ``int`` so the script never
            receives strings like ``"999.0"``.
        uc_scale: Float option passed as ``--uc_scale``.
        source_image: File path of the source image.
        driving_video: File path of the driving video.

    Returns:
        A status message for display: the output location on success, or a
        description of what went wrong.
    """
    # Gradio passes None when a media input is left empty; a None in the
    # argv list would make subprocess.run raise TypeError.
    if not source_image:
        return "Please provide a source image."
    if not driving_video:
        return "Please provide a driving video."

    try:
        seed, best_frame, out_frames, num_mix, ddim_steps = (
            int(value)
            for value in (seed, best_frame, out_frames, num_mix, ddim_steps)
        )
        uc_scale = float(uc_scale)
    except (TypeError, ValueError) as e:
        return f"Invalid numeric parameter: {e}"

    # Use the interpreter running this app so the script sees the same
    # installed packages (e.g. inside a virtualenv).
    command = [
        sys.executable, "core/test_xportrait.py",
        "--model_config", model_config,
        "--output_dir", output_dir,
        "--resume_dir", resume_dir,
        "--seed", str(seed),
        "--uc_scale", str(uc_scale),
        "--source_image", source_image,
        "--driving_video", driving_video,
        "--best_frame", str(best_frame),
        "--out_frames", str(out_frames),
        "--num_mix", str(num_mix),
        "--ddim_steps", str(ddim_steps),
    ]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        return f"An error occurred: {e}"
    except OSError as e:  # interpreter/script could not be started
        return f"An error occurred: {e}"
    return f"Output saved in: {output_dir}"


# Set up Gradio interface. Media inputs use file paths because they are
# forwarded as CLI arguments; integer options use precision=0 so Gradio
# returns ints rather than floats.
app = gr.Interface(
    fn=run_xportrait,
    inputs=[
        gr.Textbox(value="config/cldm_v15_appearance_pose_local_mm.yaml", label="Model Config Path"),
        gr.Textbox(value="outputs", label="Output Directory"),
        gr.Textbox(value="checkpoint/model_state-415001.th", label="Resume Directory"),
        gr.Number(value=999, precision=0, label="Seed"),
        gr.Number(value=5, label="UC Scale"),
        gr.Image(type="filepath", label="Source Image"),
        gr.Video(label="Driving Video"),
        gr.Number(value=36, precision=0, label="Best Frame"),
        gr.Number(value=-1, precision=0, label="Out Frames"),
        gr.Number(value=4, precision=0, label="Number of Mix"),
        gr.Number(value=30, precision=0, label="DDIM Steps"),
    ],
    outputs="text",
    title="XPortrait Model Runner",
    description="Run XPortrait with customizable parameters.",
)


if __name__ == "__main__":
    # Fetch weights (if needed) before serving, and show the outcome.
    print(_ensure_weights())
    # Launch the Gradio app
    app.launch()