Vivek committed on
Commit aa35330
1 Parent(s): ae69f53

Saving weights of epoch 1 at step 92

config.json CHANGED
@@ -12,11 +12,13 @@
   "layer_norm_epsilon": 1e-05,
   "model_type": "gpt2",
   "n_ctx": 1024,
-  "n_embd": 768,
-  "n_head": 12,
+  "n_embd": 1024,
+  "n_head": 16,
   "n_inner": null,
-  "n_layer": 12,
+  "n_layer": 24,
   "n_positions": 1024,
+  "n_special": 0,
+  "predict_special_tokens": true,
   "resid_pdrop": 0.1,
   "scale_attn_weights": true,
   "summary_activation": null,
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
   version https://git-lfs.github.com/spec/v1
-  oid sha256:1f221fe358430aba9453bed74f64075b1276fec09453963df0741413f1c26e26
-  size 3982472483
+  oid sha256:72f9b01147de892ad202a09ca9363f9299fb3bdda34dea308fc38ee984f63547
+  size 1419367919
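
Both sides are Git LFS pointers: the repository stores only an oid (the SHA-256 of the file contents) and the byte size, while the weights themselves live in LFS storage. The new size of 1419367919 bytes is consistent with float32 weights for the ~355M-parameter gpt2-medium model. A downloaded file can be checked against its pointer like so (a sketch; the local path is an assumption):

    import hashlib, os

    path = 'flax_model.msgpack'  # local copy of the LFS object
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    print(digest.hexdigest())      # should equal the pointer's oid
    print(os.path.getsize(path))   # should equal the pointer's size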
results_tensorboard/events.out.tfevents.1626335331.t1v-n-8cb15980-w-0.767371.3.v2 ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:c7149f818695d44715930271eb8a1aa5738989e5ebb127eec505e28327cdbb7f
size 40
results_tensorboard/events.out.tfevents.1626335616.t1v-n-8cb15980-w-0.768996.3.v2 ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:ea79021a763c5dcbb8337dc5eca8bd291cb3398687692ab1cd1c804a3af6189e
size 25038
train.py ADDED
import jax
print(jax.local_device_count())
import jax.numpy as jnp

import flax
import flax.linen as nn
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
from flax.metrics import tensorboard

import logging
import optax
from tqdm import tqdm

from typing import Callable

from datasets import load_dataset, load_metric
from transformers import GPT2Tokenizer

# model_file.py (which defines FlaxGPT2ForMultipleChoice) is expected to sit
# next to train.py; it is not part of this commit.
from model_file import FlaxGPT2ForMultipleChoice

logger = logging.getLogger()
logger.setLevel(logging.INFO)
def main():

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|endoftext|>')

    dataset = load_dataset('cosmos_qa')

    # Pair the concatenated context + question with each of the four candidate
    # answers, so every example yields four (first, second) sentence pairs.
    def preprocess(example):
        example['context&question'] = example['context'] + example['question']
        example['first_sentence'] = [example['context&question']] * 4
        example['second_sentence'] = [example['answer0'], example['answer1'], example['answer2'], example['answer3']]
        return example

    train_dataset = dataset['train'].map(preprocess)
    validation_dataset = dataset['validation'].map(preprocess)
    test_dataset = dataset['test'].map(preprocess)

    # Remove after experiment
    len_train_dataset = 25262
    len_validation_dataset = 2985
    len_test_dataset = 6963

    train_dataset = train_dataset.select(range(len_train_dataset))
    test_dataset = test_dataset.select(range(len_test_dataset))
    validation_dataset = validation_dataset.select(range(len_validation_dataset))

    def tokenize(examples):
        tokenized = tokenizer(examples['first_sentence'], examples['second_sentence'],
                              padding='max_length', truncation=True, max_length=256,
                              return_tensors='jax')
        tokenized['labels'] = examples['label']
        return tokenized

    train_dataset = train_dataset.map(tokenize)
    validation_dataset = validation_dataset.map(tokenize)
    test_dataset = test_dataset.map(tokenize)

    # Drop everything except the tokenizer outputs and the integer 'label'
    # column (the duplicate 'labels' copy is removed as well).
    remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']

    train_dataset = train_dataset.remove_columns(remov_col)
    validation_dataset = validation_dataset.remove_columns(remov_col)
    test_dataset = test_dataset.remove_columns(remov_col)
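    # Shape note: with four sentence pairs per example, the tokenizer output
    # for a single example has shape (4, 256); batches therefore stack to
    # (batch_size, 4, 256) once converted to arrays.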
    per_device_batch_size = 4
    seed = 0
    num_train_epochs = 5
    learning_rate = 2e-5

    total_batch_size = per_device_batch_size * jax.local_device_count()
    print('The overall batch size (both for training and eval) is', total_batch_size)
    num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
    num_validation_steps = len(validation_dataset) // total_batch_size * num_train_epochs

    # Linear decay from the peak learning rate to zero over all training steps.
    learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
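    # Arithmetic check (assuming an 8-device TPU host): total_batch_size = 32,
    # num_train_steps = 25262 // 32 * 5 = 3945, and one validation pass takes
    # 2985 // 32 = 93 steps (idx 0..92, which matches the "step 92" in this
    # commit's message).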
    # TrainState extended with two static (non-pytree) function fields.
    class TrainState(train_state.TrainState):
        logits_function: Callable = flax.struct.field(pytree_node=False)
        loss_function: Callable = flax.struct.field(pytree_node=False)

    def adamw(weight_decay):
        return optax.adamw(learning_rate=learning_rate_function, b1=0.9, b2=0.99, eps=1e-6, weight_decay=weight_decay)

    # True for parameters that should be weight-decayed. flatten_dict keys are
    # tuples such as ('transformer', 'h', '0', 'ln_1', 'bias'), so the 'bias'
    # test matches path components exactly.
    decay_path = lambda p: not any(x in p for x in ['bias', 'LayerNorm.weight'])

    def traverse(function):
        def mask(data):
            flat = flax.traverse_util.flatten_dict(data)
            return flax.traverse_util.unflatten_dict({k: function(k, v) for k, v in flat.items()})
        return mask

    # Decay only the parameters selected by decay_path; biases and layer norms
    # get weight decay 0.
    gradient_transformation = optax.chain(
        optax.masked(adamw(0.01), mask=traverse(lambda path, _: decay_path(path))),
        optax.masked(adamw(0.0), mask=traverse(lambda path, _: not decay_path(path))))
    def loss_function(logits, labels):
        # optax.softmax_cross_entropy expects raw logits (it applies
        # log-softmax internally), so the logits are passed in unchanged.
        xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=4))
        return jnp.mean(xentropy)

    def eval_function(logits):
        return logits.argmax(-1)
    # GPT-2 medium with the custom multiple-choice head; the (1, 4, 1) input
    # shape reflects (batch, choices, sequence) at initialization.
    model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-medium', input_shape=(1, 4, 1))

    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=gradient_transformation,
                              logits_function=eval_function,
                              loss_function=loss_function)
    def train_step(state, batch, dropout_rng):
        targets = batch.pop("label")
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_function(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_function(logits, targets)
            return loss, logits

        # has_aux=True returns the forward pass's logits alongside the loss, so
        # accuracy can be computed without a second forward pass.
        grad_function = jax.value_and_grad(loss_function, has_aux=True)
        (loss, logits), grad = grad_function(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == targets)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step), "accuracy": accuracy}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
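    # Note on the pmap above: jax.pmap compiles train_step once and runs it on
    # every local device in parallel; axis_name="batch" is the axis that
    # jax.lax.pmean reduces over, and donate_argnums=(0,) lets XLA reuse the
    # old state's device buffers for the new state.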
    def eval_step(state, batch):
        targets = batch.pop('label')
        # Index [0] extracts the logits from the model's output tuple, as in
        # train_step.
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        loss = state.loss_function(logits, targets)
        predictions = state.logits_function(logits)
        eval_accuracy = jnp.mean(jnp.equal(predictions, targets))
        metrics = jax.lax.pmean({"loss": loss, "accuracy": eval_accuracy}, axis_name="batch")
        return targets, predictions, metrics

    parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
    def glue_train_data_loader(rng, dataset, batch_size):
        # Shuffle once per epoch, drop the remainder, and yield sharded batches.
        steps_per_epoch = len_train_dataset // batch_size
        perms = jax.random.permutation(rng, len(dataset))
        perms = perms[:steps_per_epoch * batch_size]
        perms = perms.reshape((steps_per_epoch, batch_size))
        for perm in perms:
            batch = dataset[perm]
            batch = {k: jnp.array(v) for k, v in batch.items()}
            batch = shard(batch)
            yield batch

    rng = jax.random.PRNGKey(seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    def glue_eval_data_loader(dataset, batch_size):
        # Sequential batches over the validation set; the remainder is dropped.
        for i in range(len_validation_dataset // batch_size):
            batch = dataset[i * batch_size : (i + 1) * batch_size]
            batch = {k: jnp.array(v) for k, v in batch.items()}
            batch = shard(batch)
            yield batch
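    # shard() reshapes each array from (total_batch, ...) to
    # (local_device_count, per_device_batch, ...): with 8 devices, a
    # (32, 4, 256) input_ids array becomes (8, 4, 4, 256), one slice per device.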
    state = flax.jax_utils.replicate(state)

    # GLUE MNLI accuracy metric (loaded here, though the loop below computes
    # and logs its own accuracy).
    metric = load_metric('glue', 'mnli')

    workdir = './results_tensorboard'
    summary_writer = tensorboard.SummaryWriter(workdir)
+ logger.info(f"***** Running training *****")
188
+ logger.info(f" Num examples = {len_train_dataset}")
189
+ logger.info(f" Num Epochs = {1}")
190
+ logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
191
+ logger.info(f" Total train batch size = {total_batch_size}")
192
+ logger.info(f" Total optimization steps = {num_train_steps}")
193
+
194
    for epoch in tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True):
        rng, input_rng = jax.random.split(rng)
        train_acc_metrics = []
        train_loss_metrics = []
        eval_acc_metrics = []
        eval_loss_metrics = []

        # train
        with tqdm(total=len_train_dataset // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
            for idx, batch in enumerate(glue_train_data_loader(input_rng, train_dataset, total_batch_size)):
                state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
                train_acc_metrics.append(jax.device_get(train_metric['accuracy']).mean().item())
                train_loss_metrics.append(flax.jax_utils.unreplicate(train_metric)['loss'].item())
                # NOTE: idx restarts every epoch, so TensorBoard points from a
                # later epoch overwrite those of earlier epochs.
                if idx % 5 == 0:
                    summary_writer.scalar('train_loss', flax.jax_utils.unreplicate(train_metric)['loss'].item(), idx)
                    summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(), idx)
                progress_bar_train.update(1)

        # evaluate
        with tqdm(total=len_validation_dataset // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
            for idx, batch in enumerate(glue_eval_data_loader(validation_dataset, total_batch_size)):
                labels, predictions, eval_metric = parallel_eval_step(state, batch)
                eval_acc_metrics.append(jax.device_get(eval_metric['accuracy']).mean().item())
                eval_loss_metrics.append(flax.jax_utils.unreplicate(eval_metric)['loss'].item())
                progress_bar_eval.update(1)
                if idx % 5 == 0:
                    logger.info(f"eval_step_loss {idx}: {flax.jax_utils.unreplicate(eval_metric)['loss'].item()} eval_step_acc {idx}: {jax.device_get(eval_metric['accuracy']).mean().item()}")
                    summary_writer.scalar('eval_loss', flax.jax_utils.unreplicate(eval_metric)['loss'].item(), idx)
                    summary_writer.scalar('eval_accuracy', jax.device_get(eval_metric['accuracy']).mean().item(), idx)

        # Pull one device's copy of the parameters back to the host and push
        # the checkpoint to the Hub.
        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(
            '.',
            params=params,
            push_to_hub=True,
            commit_message=f"Saving weights of epoch {epoch} at step {idx}",
        )

        logger.info(f"---------------------Epoch {epoch} done-----------------")
        logger.info(f"Train loss: {jax.device_get(jnp.array(train_loss_metrics)).mean().item()} Train accuracy: {jax.device_get(jnp.array(train_acc_metrics)).mean().item()}")
        logger.info(f"Eval loss: {jax.device_get(jnp.array(eval_loss_metrics)).mean().item()} Eval accuracy: {jax.device_get(jnp.array(eval_acc_metrics)).mean().item()}")
        summary_writer.flush()


if __name__ == "__main__":
    main()
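
For reference, model_file.py (which defines FlaxGPT2ForMultipleChoice) is not part of this commit, so only its interface appears above. A minimal sketch of what such a head computes, assuming a shared GPT-2 encoder and an illustrative linear scoring head (head_w is a made-up parameter, not the committed implementation):

    import jax.numpy as jnp
    from transformers import FlaxGPT2Model

    gpt2 = FlaxGPT2Model.from_pretrained('gpt2-medium')

    def choice_logits(input_ids, attention_mask, head_w):
        # input_ids, attention_mask: (batch, 4 choices, seq_len); head_w: (n_embd, 1)
        b, c, s = input_ids.shape
        out = gpt2(input_ids=input_ids.reshape(b * c, s),
                   attention_mask=attention_mask.reshape(b * c, s))
        # Last position's hidden state; a real head would pick the last
        # non-padding token instead.
        last_hidden = out.last_hidden_state[:, -1, :]   # (b*c, n_embd)
        scores = last_hidden @ head_w                    # (b*c, 1)
        return scores.reshape(b, c)                      # one logit per choice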