Allow URLs
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from huggingface_hub import HfApi
|
|
7 |
from huggingface_hub.utils import RepositoryNotFoundError, GatedRepoError
|
8 |
from accelerate.commands.estimate import create_empty_model, check_has_model
|
9 |
from accelerate.utils import convert_bytes, calculate_maximum_sizes
|
|
|
10 |
|
11 |
# We need to store them as globals because gradio doesn't have a way for us to pass them in to the button
|
12 |
HAS_DISCUSSION = True
|
@@ -54,12 +55,20 @@ When training with `Adam`, you can expect roughly 4x the reported results to be
|
|
54 |
discussion = api.create_discussion(MODEL_NAME, "[AUTOMATED] Model Memory Requirements", description=post)
|
55 |
webbrowser.open_new_tab(discussion.url)
|
56 |
|
57 |
-
def
|
58 |
-
"
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def calculate_memory(model_name:str, library:str, options:list, access_token:str, raw=False):
|
65 |
"Calculates the memory usage for a model"
|
@@ -67,11 +76,7 @@ def calculate_memory(model_name:str, library:str, options:list, access_token:str
|
|
67 |
model_name = translate_llama2(model_name)
|
68 |
if library == "auto":
|
69 |
library = None
|
70 |
-
|
71 |
-
try:
|
72 |
-
model_name = convert_url_to_name(model_name)
|
73 |
-
except ValueError:
|
74 |
-
raise gr.Error(f"URL `{model_name}` is not a valid model URL to the Hugging Face Hub")
|
75 |
try:
|
76 |
model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
|
77 |
except GatedRepoError:
|
|
|
7 |
from huggingface_hub.utils import RepositoryNotFoundError, GatedRepoError
|
8 |
from accelerate.commands.estimate import create_empty_model, check_has_model
|
9 |
from accelerate.utils import convert_bytes, calculate_maximum_sizes
|
10 |
+
from urllib.parse import urlparse
|
11 |
|
12 |
# We need to store them as globals because gradio doesn't have a way for us to pass them in to the button
|
13 |
HAS_DISCUSSION = True
|
|
|
55 |
discussion = api.create_discussion(MODEL_NAME, "[AUTOMATED] Model Memory Requirements", description=post)
|
56 |
webbrowser.open_new_tab(discussion.url)
|
57 |
|
58 |
+
def extract_from_url(name:str):
|
59 |
+
"Checks if `name` is a URL, and if so converts it to a model name"
|
60 |
+
is_url = False
|
61 |
+
try:
|
62 |
+
result = urlparse(name)
|
63 |
+
is_url = all([result.scheme, result.netloc])
|
64 |
+
except:
|
65 |
+
is_url = False
|
66 |
+
# Pass through if not a URL
|
67 |
+
if not is_url:
|
68 |
+
return name
|
69 |
+
else:
|
70 |
+
path = result.path
|
71 |
+
return path[1:]
|
72 |
|
73 |
def calculate_memory(model_name:str, library:str, options:list, access_token:str, raw=False):
|
74 |
"Calculates the memory usage for a model"
|
|
|
76 |
model_name = translate_llama2(model_name)
|
77 |
if library == "auto":
|
78 |
library = None
|
79 |
+
model_name = extract_from_url(model_name)
|
|
|
|
|
|
|
|
|
80 |
try:
|
81 |
model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
|
82 |
except GatedRepoError:
|