Spaces:
Sleeping
Sleeping
fix slow response after days
Browse files
app.py
CHANGED
@@ -9,19 +9,9 @@ from dearth_model import DearthForCausalLM
|
|
9 |
import random
|
10 |
|
11 |
|
12 |
-
tk = transformers.AutoTokenizer.from_pretrained("./tk")
|
13 |
|
14 |
-
|
15 |
model_path = "./ts100-re2-h1-4000-model.pt"
|
16 |
-
yml_path = "./ts100-re2-h1.yml"
|
17 |
-
with open(yml_path, "r") as f:
|
18 |
-
config = yaml.load(f, Loader=yaml.FullLoader)['model']
|
19 |
-
if "vocab_size" not in config:
|
20 |
-
config['vocab_size'] = tk.vocab_size
|
21 |
-
config["attn_window_size"] = 500
|
22 |
-
print(config)
|
23 |
-
config = DearthConfig(**config)
|
24 |
-
model = DearthForCausalLM(config)
|
25 |
states = torch.load(model_path, map_location="cpu")
|
26 |
model_states = states
|
27 |
unwanted_prefix_dueto_compile = '_orig_mod.'
|
@@ -39,10 +29,21 @@ for k,v in list(model_states.items()):
|
|
39 |
new_key = k[len(unwanted_prefix_dueto_compile):]
|
40 |
model_states[k[len(unwanted_prefix_dueto_compile):]] = model_states.pop(k)
|
41 |
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
|
45 |
-
def generate(input, num_more_tokens):
|
46 |
num_more_tokens = int(num_more_tokens)
|
47 |
print(input)
|
48 |
input = input.strip()
|
@@ -84,7 +85,7 @@ The PPL on the validation set is 1.7, in comparison, the teacher model has a PPL
|
|
84 |
"""
|
85 |
|
86 |
|
87 |
-
|
88 |
fn=generate,
|
89 |
title="Tinystories LM 11M",
|
90 |
description=Description,
|
@@ -95,4 +96,6 @@ server = gr.Interface(
|
|
95 |
outputs="text"
|
96 |
)
|
97 |
|
98 |
-
|
|
|
|
|
|
9 |
import random
|
10 |
|
11 |
|
|
|
12 |
|
13 |
+
tk = transformers.AutoTokenizer.from_pretrained("./tk")
|
14 |
model_path = "./ts100-re2-h1-4000-model.pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
states = torch.load(model_path, map_location="cpu")
|
16 |
model_states = states
|
17 |
unwanted_prefix_dueto_compile = '_orig_mod.'
|
|
|
29 |
new_key = k[len(unwanted_prefix_dueto_compile):]
|
30 |
model_states[k[len(unwanted_prefix_dueto_compile):]] = model_states.pop(k)
|
31 |
|
32 |
+
def generate(input, num_more_tokens):
|
33 |
+
|
34 |
+
yml_path = "./ts100-re2-h1.yml"
|
35 |
+
with open(yml_path, "r") as f:
|
36 |
+
config = yaml.load(f, Loader=yaml.FullLoader)['model']
|
37 |
+
if "vocab_size" not in config:
|
38 |
+
config['vocab_size'] = tk.vocab_size
|
39 |
+
config["attn_window_size"] = 500
|
40 |
+
print(config)
|
41 |
+
config = DearthConfig(**config)
|
42 |
+
model = DearthForCausalLM(config)
|
43 |
+
|
44 |
+
model.load_state_dict(model_states)
|
45 |
|
46 |
|
|
|
47 |
num_more_tokens = int(num_more_tokens)
|
48 |
print(input)
|
49 |
input = input.strip()
|
|
|
85 |
"""
|
86 |
|
87 |
|
88 |
+
demo = gr.Interface(
|
89 |
fn=generate,
|
90 |
title="Tinystories LM 11M",
|
91 |
description=Description,
|
|
|
96 |
outputs="text"
|
97 |
)
|
98 |
|
99 |
+
if __name__ == "__main__":
|
100 |
+
demo.queue()
|
101 |
+
demo.launch(show_api=False)
|