XFious committed on
Commit
ce226d1
·
1 Parent(s): 4ae913a

fix slow response after days

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -9,19 +9,9 @@ from dearth_model import DearthForCausalLM
9
  import random
10
 
11
 
12
- tk = transformers.AutoTokenizer.from_pretrained("./tk")
13
 
14
- #model_path = "./ts100-re2-h1-4000.pt"
15
  model_path = "./ts100-re2-h1-4000-model.pt"
16
- yml_path = "./ts100-re2-h1.yml"
17
- with open(yml_path, "r") as f:
18
- config = yaml.load(f, Loader=yaml.FullLoader)['model']
19
- if "vocab_size" not in config:
20
- config['vocab_size'] = tk.vocab_size
21
- config["attn_window_size"] = 500
22
- print(config)
23
- config = DearthConfig(**config)
24
- model = DearthForCausalLM(config)
25
  states = torch.load(model_path, map_location="cpu")
26
  model_states = states
27
  unwanted_prefix_dueto_compile = '_orig_mod.'
@@ -39,10 +29,21 @@ for k,v in list(model_states.items()):
39
  new_key = k[len(unwanted_prefix_dueto_compile):]
40
  model_states[k[len(unwanted_prefix_dueto_compile):]] = model_states.pop(k)
41
 
42
- model.load_state_dict(model_states)
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
- def generate(input, num_more_tokens):
46
  num_more_tokens = int(num_more_tokens)
47
  print(input)
48
  input = input.strip()
@@ -84,7 +85,7 @@ The PPL on the validation set is 1.7, in comparison, the teacher model has a PPL
84
  """
85
 
86
 
87
- server = gr.Interface(
88
  fn=generate,
89
  title="Tinystories LM 11M",
90
  description=Description,
@@ -95,4 +96,6 @@ server = gr.Interface(
95
  outputs="text"
96
  )
97
 
98
- server.launch()
 
 
 
9
  import random
10
 
11
 
 
12
 
13
+ tk = transformers.AutoTokenizer.from_pretrained("./tk")
14
  model_path = "./ts100-re2-h1-4000-model.pt"
 
 
 
 
 
 
 
 
 
15
  states = torch.load(model_path, map_location="cpu")
16
  model_states = states
17
  unwanted_prefix_dueto_compile = '_orig_mod.'
 
29
  new_key = k[len(unwanted_prefix_dueto_compile):]
30
  model_states[k[len(unwanted_prefix_dueto_compile):]] = model_states.pop(k)
31
 
32
+ def generate(input, num_more_tokens):
33
+
34
+ yml_path = "./ts100-re2-h1.yml"
35
+ with open(yml_path, "r") as f:
36
+ config = yaml.load(f, Loader=yaml.FullLoader)['model']
37
+ if "vocab_size" not in config:
38
+ config['vocab_size'] = tk.vocab_size
39
+ config["attn_window_size"] = 500
40
+ print(config)
41
+ config = DearthConfig(**config)
42
+ model = DearthForCausalLM(config)
43
+
44
+ model.load_state_dict(model_states)
45
 
46
 
 
47
  num_more_tokens = int(num_more_tokens)
48
  print(input)
49
  input = input.strip()
 
85
  """
86
 
87
 
88
+ demo = gr.Interface(
89
  fn=generate,
90
  title="Tinystories LM 11M",
91
  description=Description,
 
96
  outputs="text"
97
  )
98
 
99
+ if __name__ == "__main__":
100
+ demo.queue()
101
+ demo.launch(show_api=False)