fffiloni commited on
Commit
1bb6b57
1 Parent(s): 1cddde4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -7,6 +7,7 @@ import glob
7
  # Ensure 'checkpoint' directory exists
8
  os.makedirs("checkpoint", exist_ok=True)
9
 
 
10
  # Function to download the model weights from a Google Drive folder
11
  def download_weights_from_folder(google_drive_folder_link):
12
  # Extract the folder ID from the Google Drive link
@@ -16,16 +17,24 @@ def download_weights_from_folder(google_drive_folder_link):
16
  # Download all files in the Google Drive folder
17
  gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
18
  try:
 
19
  gdown.download_folder(gdown_url, quiet=False, output=output_folder)
20
 
21
- # Check if the model file exists and rename if necessary
22
- downloaded_model_path = os.path.join(output_folder, "model_state-415001.th")
23
- if os.path.exists(downloaded_model_path):
24
- return f"Downloaded model weights to {downloaded_model_path}"
 
 
 
 
 
 
 
25
  else:
26
- return "Model file 'model_state-415001.th' not found in the folder."
27
  except Exception as e:
28
- return f"Failed to download weights: {e}"
29
 
30
  download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O")
31
 
 
7
  # Ensure 'checkpoint' directory exists
8
  os.makedirs("checkpoint", exist_ok=True)
9
 
10
+ # Function to download the model weights from a Google Drive folder
11
  # Function to download the model weights from a Google Drive folder
12
  def download_weights_from_folder(google_drive_folder_link):
13
  # Extract the folder ID from the Google Drive link
 
17
  # Download all files in the Google Drive folder
18
  gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
19
  try:
20
+ # Download the folder contents
21
  gdown.download_folder(gdown_url, quiet=False, output=output_folder)
22
 
23
+ # Ensure the downloaded file is named 'model_state-415001.th'
24
+ downloaded_files = glob.glob(os.path.join(output_folder, "*.th"))
25
+ if downloaded_files:
26
+ downloaded_model_path = downloaded_files[0]
27
+ target_model_path = os.path.join(output_folder, "model_state-415001.th")
28
+
29
+ # Rename if necessary
30
+ if downloaded_model_path != target_model_path:
31
+ os.rename(downloaded_model_path, target_model_path)
32
+
33
+ print(f"Downloaded model weights to {target_model_path}")
34
  else:
35
+ print("Model file 'model_state-415001.th' not found in the folder.")
36
  except Exception as e:
37
+ print(f"Failed to download weights: {e}")
38
 
39
  download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O")
40