fffiloni commited on
Commit
2945355
1 Parent(s): bfed184

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import gdown
5
+
6
+ # Ensure 'checkpoint' directory exists
7
+ os.makedirs("checkpoint", exist_ok=True)
8
+
9
+ # Function to download the model weights from a Google Drive folder
10
+ def download_weights_from_folder(google_drive_folder_link):
11
+ # Extract the folder ID from the Google Drive link
12
+ folder_id = google_drive_folder_link.split('/')[-1]
13
+ output_folder = "checkpoint" # Directory to save the downloaded files
14
+
15
+ # Download all files in the Google Drive folder
16
+ gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
17
+ try:
18
+ gdown.download_folder(gdown_url, quiet=False, output=output_folder)
19
+
20
+ # Check if the model file exists and rename if necessary
21
+ downloaded_model_path = os.path.join(output_folder, "model_state-415001.th")
22
+ if os.path.exists(downloaded_model_path):
23
+ return f"Downloaded model weights to {downloaded_model_path}"
24
+ else:
25
+ return "Model file 'model_state-415001.th' not found in the folder."
26
+ except Exception as e:
27
+ return f"Failed to download weights: {e}"
28
+
29
+ download_weights("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O")
30
+
31
+ # Define a function to run your script with selected inputs
32
+ def run_xportrait(
33
+ model_config,
34
+ output_dir,
35
+ resume_dir,
36
+ seed,
37
+ uc_scale,
38
+ source_image,
39
+ driving_video,
40
+ best_frame,
41
+ out_frames,
42
+ num_mix,
43
+ ddim_steps
44
+ ):
45
+ # Construct the command
46
+ command = [
47
+ "python3", "core/test_xportrait.py",
48
+ "--model_config", model_config,
49
+ "--output_dir", output_dir,
50
+ "--resume_dir", resume_dir,
51
+ "--seed", str(seed),
52
+ "--uc_scale", str(uc_scale),
53
+ "--source_image", source_image,
54
+ "--driving_video", driving_video,
55
+ "--best_frame", str(best_frame),
56
+ "--out_frames", str(out_frames),
57
+ "--num_mix", str(num_mix),
58
+ "--ddim_steps", str(ddim_steps)
59
+ ]
60
+
61
+ # Run the command
62
+ try:
63
+ subprocess.run(command, check=True)
64
+ return f"Output saved in: {output_dir}"
65
+ except subprocess.CalledProcessError as e:
66
+ return f"An error occurred: {e}"
67
+
68
+ # Set up Gradio interface
69
+ app = gr.Interface(
70
+ fn=run_xportrait,
71
+ inputs=[
72
+ gr.Textbox(value="config/cldm_v15_appearance_pose_local_mm.yaml", label="Model Config Path"),
73
+ gr.Textbox(value="outputs", label="Output Directory"),
74
+ gr.Textbox(value="checkpoint/model_state-415001.th", label="Resume Directory"),
75
+ gr.Number(value=999, label="Seed"),
76
+ gr.Number(value=5, label="UC Scale"),
77
+ gr.Image(label="Source Image"),
78
+ gr.Video(label="Driving Video"),
79
+ gr.Number(value=36, label="Best Frame"),
80
+ gr.Number(value=-1, label="Out Frames"),
81
+ gr.Number(value=4, label="Number of Mix"),
82
+ gr.Number(value=30, label="DDIM Steps")
83
+ ],
84
+ outputs="text",
85
+ title="XPortrait Model Runner",
86
+ description="Run XPortrait with customizable parameters."
87
+ )
88
+
89
+ # Launch the Gradio app
90
+ app.launch()