Meloo committed on
Commit
aa63182
·
verified ·
1 Parent(s): b53be1a

Create utils/download_url.py

Browse files
Files changed (1) hide show
  1. utils/utils/download_url.py +99 -0
utils/utils/download_url.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import requests
4
+ from torch.hub import download_url_to_file, get_dir
5
+ from tqdm import tqdm
6
+ from urllib.parse import urlparse
7
+
8
+ from .misc import sizeof_fmt
9
+
10
+
11
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Requests the file once; if the response carries a confirmation cookie
    (see ``get_confirm_token``), the request is repeated with that token as
    the ``confirm`` parameter. A small ranged request is then used to read
    the total size from the ``Content-Range`` header so that
    ``save_response_content`` can show progress.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    # The session is a context manager: leaving the block releases its
    # pooled connections even if a request or the file write fails.
    with requests.Session() as session:
        response = session.get(URL, params=params, stream=True)
        try:
            token = get_confirm_token(response)
            if token:
                # A streamed response keeps its connection checked out until
                # it is read or closed, so drop it before asking again.
                response.close()
                params['confirm'] = token
                response = session.get(URL, params=params, stream=True)

            # get file size
            response_file_size = session.get(
                URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
            try:
                content_range = response_file_size.headers.get('Content-Range')
            finally:
                # Only the headers are needed; the body is never read.
                response_file_size.close()

            file_size = None
            if content_range is not None:
                # Header looks like "bytes 0-2/12345"; the total may be "*"
                # when the server does not know the full length.
                total = content_range.rpartition('/')[2].strip()
                if total.isdigit():
                    file_size = int(total)

            save_response_content(response, save_path, file_size)
        finally:
            response.close()
40
+
41
+
42
def get_confirm_token(response):
    """Return the confirmation token stored in the response cookies.

    Args:
        response: Object exposing a ``cookies`` mapping (``.items()`` is
            iterated), e.g. the response returned by ``session.get``.

    Returns:
        The value of the first cookie whose name starts with
        ``'download_warning'``, or None when no such cookie is present.
    """
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None
47
+
48
+
49
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    """Stream the body of ``response`` into the file at ``destination``.

    Args:
        response: Object exposing ``iter_content(chunk_size)`` that yields
            byte chunks.
        destination (str): Path of the file to write (opened in ``'wb'``
            mode, so an existing file is overwritten).
        file_size (int): Total size in bytes. When given, a tqdm progress bar
            is shown; when None, the download is silent. Default: None.
        chunk_size (int): Number of bytes requested per chunk.
            Default: 32768.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        # sizeof_fmt comes from .misc; presumably renders a byte count as a
        # human-readable string - confirm against that module.
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None
        readable_file_size = None

    downloaded_size = 0
    try:
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size):
                if not chunk:  # filter out keep-alive new chunks
                    continue
                f.write(chunk)
                # Count the bytes actually received: the last chunk is
                # usually shorter than chunk_size.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
    finally:
        # Close the bar even when the transfer or the write fails.
        if pbar is not None:
            pbar.close()
68
+
69
+
70
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    The file is only downloaded when the target path does not exist yet;
    otherwise the existing path is returned as-is.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.

    Raises:
        ValueError: If no file name can be determined (``file_name`` is empty,
            or it is None and the URL path ends with '/').
    """
    if model_dir is None:  # use the pytorch hub_dir
        model_dir = os.path.join(get_dir(), 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    if file_name is not None:
        filename = file_name
    else:
        filename = os.path.basename(urlparse(url).path)

    # Without a name the target would be model_dir itself, which always
    # exists at this point and would be returned as if it were the file.
    if not filename:
        raise ValueError(f'Cannot determine a file name from url "{url}"; pass file_name explicitly.')

    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file