Spaces:
Sleeping
Sleeping
Alan Liu
commited on
Commit
•
d1c8a18
1
Parent(s):
3849813
check compute_module_sizes
Browse files- app.py +1 -1
- model_util.py +86 -2
app.py
CHANGED
@@ -27,7 +27,7 @@ def load_model_config(model_id):
|
|
27 |
model_config['max_position_embeddings'] = dictionary_content['max_position_embeddings']
|
28 |
model_config['layernorm_operation'] = 2
|
29 |
else:
|
30 |
-
st.warning("
|
31 |
model_config['model_id'] = 'opt-1.3b'
|
32 |
model_config['hidden_size'] = 2048
|
33 |
model_config['num_attention_heads'] = 32
|
|
|
27 |
model_config['max_position_embeddings'] = dictionary_content['max_position_embeddings']
|
28 |
model_config['layernorm_operation'] = 2
|
29 |
else:
|
30 |
+
st.warning("Fetching information failed! Maybe model info is not public!")
|
31 |
model_config['model_id'] = 'opt-1.3b'
|
32 |
model_config['hidden_size'] = 2048
|
33 |
model_config['num_attention_heads'] = 32
|
model_util.py
CHANGED
@@ -1,5 +1,10 @@
|
|
1 |
import requests
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
def fetch_dictionary_content(model_id):
|
5 |
MODEL_URL = "https://huggingface.co/{model_id}/raw/main/config.json"
|
@@ -15,4 +20,83 @@ def load_parameter(model_dict, cand_keys):
|
|
15 |
for k in cand_keys:
|
16 |
if k in model_dict:
|
17 |
return model_dict[k]
|
18 |
-
return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
+
# Utilities related to loading in and working with models/specific models
|
3 |
+
from urllib.parse import urlparse
|
4 |
+
import torch
|
5 |
+
from accelerate.commands.estimate import check_has_model, create_empty_model
|
6 |
+
from accelerate.utils import compute_module_sizes
|
7 |
+
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
8 |
|
9 |
def fetch_dictionary_content(model_id):
|
10 |
MODEL_URL = "https://huggingface.co/{model_id}/raw/main/config.json"
|
|
|
20 |
for k in cand_keys:
|
21 |
if k in model_dict:
|
22 |
return model_dict[k]
|
23 |
+
return 0
|
24 |
+
|
25 |
+
# Reference: https://huggingface.co/spaces/hf-accelerate/model-memory-usage
|
26 |
+
def extract_from_url(name: str):
|
27 |
+
"Checks if `name` is a URL, and if so converts it to a model name"
|
28 |
+
is_url = False
|
29 |
+
try:
|
30 |
+
result = urlparse(name)
|
31 |
+
is_url = all([result.scheme, result.netloc])
|
32 |
+
except Exception:
|
33 |
+
is_url = False
|
34 |
+
# Pass through if not a URL
|
35 |
+
if not is_url:
|
36 |
+
return name
|
37 |
+
else:
|
38 |
+
path = result.path
|
39 |
+
return path[1:]
|
40 |
+
|
41 |
+
|
42 |
+
def translate_llama2(text):
|
43 |
+
"Translates llama-2 to its hf counterpart"
|
44 |
+
if not text.endswith("-hf"):
|
45 |
+
return text + "-hf"
|
46 |
+
return text
|
47 |
+
|
48 |
+
|
49 |
+
def get_model(model_name: str, library: str, access_token: str):
|
50 |
+
"Finds and grabs model from the Hub, and initializes on `meta`"
|
51 |
+
if "meta-llama" in model_name:
|
52 |
+
model_name = translate_llama2(model_name)
|
53 |
+
if library == "auto":
|
54 |
+
library = None
|
55 |
+
model_name = extract_from_url(model_name)
|
56 |
+
try:
|
57 |
+
model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
|
58 |
+
except GatedRepoError:
|
59 |
+
raise RuntimeError(
|
60 |
+
f"Model `{model_name}` is a gated model, please ensure to pass in your access token and try again if you have access. You can find your access token here : https://huggingface.co/settings/tokens. "
|
61 |
+
)
|
62 |
+
except RepositoryNotFoundError:
|
63 |
+
raise RuntimeError(f"Model `{model_name}` was not found on the Hub, please try another model name.")
|
64 |
+
except ValueError:
|
65 |
+
raise RuntimeError(
|
66 |
+
f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
|
67 |
+
)
|
68 |
+
except (RuntimeError, OSError) as e:
|
69 |
+
library = check_has_model(e)
|
70 |
+
if library != "unknown":
|
71 |
+
raise RuntimeError(
|
72 |
+
f"Tried to load `{model_name}` with `{library}` but a possible model to load was not found inside the repo."
|
73 |
+
)
|
74 |
+
raise RuntimeError(
|
75 |
+
f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
|
76 |
+
)
|
77 |
+
except ImportError:
|
78 |
+
# hacky way to check if it works with `trust_remote_code=False`
|
79 |
+
model = create_empty_model(
|
80 |
+
model_name, library_name=library, trust_remote_code=False, access_token=access_token
|
81 |
+
)
|
82 |
+
except Exception as e:
|
83 |
+
raise RuntimeError(
|
84 |
+
f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
|
85 |
+
)
|
86 |
+
return model
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
model = get_model('NousResearch/Nous-Hermes-Llama2-13b', None, None)
|
92 |
+
sizes = compute_module_sizes(model, dtype=torch.int8)
|
93 |
+
size_dict = {
|
94 |
+
'attn':0,
|
95 |
+
'mlp':0,
|
96 |
+
'embed':0,
|
97 |
+
}
|
98 |
+
for k, v in sizes.items():
|
99 |
+
for kk in size_dict:
|
100 |
+
if kk in k and 'weight' in k:
|
101 |
+
size_dict[kk] += v/1024**3
|
102 |
+
print(sizes)
|