XFious committed
Commit 7569c73 · 1 Parent(s): ce226d1

random input example

Files changed (2):
  1. app.py +40 -18
  2. random_input_example.js +8 -0
app.py CHANGED
@@ -7,6 +7,7 @@ from dearth_config import DearthConfig
 from dearth_model import DearthForCausalLM
 
 import random
+import time
 
 
 
@@ -37,7 +38,7 @@ def generate(input, num_more_tokens):
     if "vocab_size" not in config:
         config['vocab_size'] = tk.vocab_size
     config["attn_window_size"] = 500
-    print(config)
+    # print(config)
     config = DearthConfig(**config)
     model = DearthForCausalLM(config)
 
@@ -45,12 +46,12 @@ def generate(input, num_more_tokens):
 
 
     num_more_tokens = int(num_more_tokens)
-    print(input)
+    # print(input)
     input = input.strip()
     input_ids = tk.encode(input)
     input_ids = [tk.bos_token_id] + input_ids
     input_ids = torch.tensor(input_ids, dtype=torch.long).view(1, -1)
-    print(input_ids)
+    # print(input_ids)
 
     output_ids = input_ids.squeeze(0).tolist()
     for i in range(num_more_tokens):
@@ -58,7 +59,7 @@ def generate(input, num_more_tokens):
         with torch.no_grad():
             output = model(input)[0]
         last_token_logits = output[0, -1, :]
-        last_token_logits_topk = torch.topk(last_token_logits, k=8, dim=-1)
+        last_token_logits_topk = torch.topk(last_token_logits, k=5, dim=-1)
         probs = torch.softmax(last_token_logits_topk.values, dim=-1)
         new_token = torch.multinomial(probs, num_samples=1).item()
         new_token = last_token_logits_topk.indices[new_token].item()
@@ -66,8 +67,8 @@ def generate(input, num_more_tokens):
             break
         output_ids.append(new_token)
 
-    print(output_ids)
-    print(tk.decode(output_ids))
+    # print(output_ids)
+    # print(tk.decode(output_ids))
     output_ids = output_ids[1:]
 
     return tk.decode(output_ids)
@@ -77,7 +78,7 @@ example_input = ["Once upon a time, there was a little girl",
     "It was a warm summer day when Billy and",
     ]
 
-
+ui_title = "Tinystories LM 11M"
 Description = """
 This is a small language model with 11M parameters, trained with the TinyStories dataset, and distilled from a 28M parameter teacher model.\n
 This model has been trained with 512M tokens, which is about 0.9 epoch of the TinyStories dataset.\n
@@ -85,17 +86,38 @@ The PPL on the validation set is 1.7, in comparison, the teacher model has a PPL
 """
 
 
-demo = gr.Interface(
-    fn=generate,
-    title="Tinystories LM 11M",
-    description=Description,
-    inputs=[
-        gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)]),
-        gr.Slider(16, 64, step=1.0, value=32, label="more tokens", info="")
-    ],
-    outputs="text"
-)
+# demo = gr.Interface(
+#     fn=generate,
+#     title="Tinystories LM 11M",
+#     description=Description,
+#     inputs=[
+#         gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)]),
+#         gr.Slider(16, 64, step=1.0, value=32, label="more tokens", info="")
+#     ],
+#     outputs="text"
+# )
+
+with open("./random_input_example.js", "r") as f:
+    file_content = f.read()
 
 if __name__ == "__main__":
+    with gr.Blocks(
+        title="Tinystories LM 11M",
+        js="./random_input_example.js"
+    ) as demo:
+        with gr.Blocks(title="Description"):
+            gr.HTML(f"<h1>{ui_title}</h1>")
+            gr.Markdown(Description)
+        with gr.Row():
+            with gr.Column():
+                inp = gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)], elem_id="input_textbox")
+                generate_max_slider = gr.Slider(16, 64, step=1.0, value=32, label="more tokens", info="")
+                generate_button = gr.Button(value="Generate")
+            with gr.Column():
+                out = gr.Textbox(lines=5, label="Output Text", value="")
+                out.readonly = True
+        @generate_button.click(inputs=[inp, generate_max_slider], outputs=[out])
+        def generate_inside(input, num_more_tokens):
+            return generate(input, num_more_tokens)
     demo.queue()
-    demo.launch(show_api=False)
+    demo.launch()
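
Not part of the commit, for context on the k=8 → k=5 change above: generate() decodes with top-k sampling, so each new token is drawn only from the k highest-probability candidates, and this commit tightens k to 5. A minimal runnable sketch of that sampling step (sample_top_k and the dummy logits are illustrative, not from app.py):

import torch

def sample_top_k(logits, k=5):
    # Keep only the k largest logits; everything else gets zero probability.
    topk = torch.topk(logits, k=k, dim=-1)
    # Renormalize the surviving logits and draw one of them.
    probs = torch.softmax(topk.values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1).item()
    # Map the position inside the top-k slice back to a vocabulary id.
    return topk.indices[choice].item()

logits = torch.randn(32000)  # dummy stand-in for model(input)[0][0, -1, :]
print(sample_top_k(logits, k=5))

Smaller k trades variety for coherence: with k=5 the model can never emit anything outside its five best guesses at each step.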
random_input_example.js ADDED
@@ -0,0 +1,8 @@
+() => {
+    let example_input = ["Once upon a time, there was a little girl",
+        "John and Sarah were playing together in their backyard when",
+        "It was a warm summer day when Billy and",
+    ];
+    let input_textbox = document.querySelector("#input_textbox textarea")
+    input_textbox.value = example_input[Math.floor(Math.random() * example_input.length)];
+}
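
Also for context, not part of the commit: the js argument of gr.Blocks takes a JavaScript function, as a code string or a path to a .js file, and Gradio runs it once when the page loads in the browser. That is how the arrow function above fills #input_textbox with a random example prompt. A minimal sketch combining that page-load hook with the decorator-style click wiring used by generate_inside(); the names ON_LOAD_JS, #prompt, and run are hypothetical:

import gradio as gr

# Runs once in the browser when the UI first loads.
ON_LOAD_JS = """
() => {
    let box = document.querySelector("#prompt textarea");
    box.value = "a starter value filled in by page-load JS";
}
"""

with gr.Blocks(js=ON_LOAD_JS) as demo:
    inp = gr.Textbox(label="Input", elem_id="prompt")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Output")

    # Same decorator-style event listener shape as generate_inside() in app.py.
    @btn.click(inputs=[inp], outputs=[out])
    def run(text):
        return text.upper()

if __name__ == "__main__":
    demo.launch()

Passing inline JS instead of a file path behaves the same here, which is why app.py can hand "./random_input_example.js" straight to gr.Blocks.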