Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import glob
|
|
7 |
# Ensure 'checkpoint' directory exists
|
8 |
os.makedirs("checkpoint", exist_ok=True)
|
9 |
|
|
|
10 |
# Function to download the model weights from a Google Drive folder
|
11 |
def download_weights_from_folder(google_drive_folder_link):
|
12 |
# Extract the folder ID from the Google Drive link
|
@@ -16,16 +17,24 @@ def download_weights_from_folder(google_drive_folder_link):
|
|
16 |
# Download all files in the Google Drive folder
|
17 |
gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
|
18 |
try:
|
|
|
19 |
gdown.download_folder(gdown_url, quiet=False, output=output_folder)
|
20 |
|
21 |
-
#
|
22 |
-
|
23 |
-
if
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
else:
|
26 |
-
|
27 |
except Exception as e:
|
28 |
-
|
29 |
|
30 |
download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O")
|
31 |
|
|
|
7 |
# Ensure 'checkpoint' directory exists
|
8 |
os.makedirs("checkpoint", exist_ok=True)
|
9 |
|
10 |
+
# Function to download the model weights from a Google Drive folder
|
11 |
# Function to download the model weights from a Google Drive folder
|
12 |
def download_weights_from_folder(google_drive_folder_link):
|
13 |
# Extract the folder ID from the Google Drive link
|
|
|
17 |
# Download all files in the Google Drive folder
|
18 |
gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
|
19 |
try:
|
20 |
+
# Download the folder contents
|
21 |
gdown.download_folder(gdown_url, quiet=False, output=output_folder)
|
22 |
|
23 |
+
# Ensure the downloaded file is named 'model_state-415001.th'
|
24 |
+
downloaded_files = glob.glob(os.path.join(output_folder, "*.th"))
|
25 |
+
if downloaded_files:
|
26 |
+
downloaded_model_path = downloaded_files[0]
|
27 |
+
target_model_path = os.path.join(output_folder, "model_state-415001.th")
|
28 |
+
|
29 |
+
# Rename if necessary
|
30 |
+
if downloaded_model_path != target_model_path:
|
31 |
+
os.rename(downloaded_model_path, target_model_path)
|
32 |
+
|
33 |
+
print(f"Downloaded model weights to {target_model_path}")
|
34 |
else:
|
35 |
+
print("Model file 'model_state-415001.th' not found in the folder.")
|
36 |
except Exception as e:
|
37 |
+
print(f"Failed to download weights: {e}")
|
38 |
|
39 |
download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O")
|
40 |
|