Saving weights of epoch 1 at step 92
config.json CHANGED
@@ -12,11 +12,13 @@
   "layer_norm_epsilon": 1e-05,
   "model_type": "gpt2",
   "n_ctx": 1024,
-  "n_embd":
-  "n_head":
+  "n_embd": 1024,
+  "n_head": 16,
   "n_inner": null,
-  "n_layer":
+  "n_layer": 24,
   "n_positions": 1024,
+  "n_special": 0,
+  "predict_special_tokens": true,
   "resid_pdrop": 0.1,
   "scale_attn_weights": true,
   "summary_activation": null,
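The new values match the GPT-2 medium-sized architecture that train.py below starts from (FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-medium', ...)). As a minimal sketch, assuming a local checkout of this repository and an installed transformers package, the updated config can be loaded and sanity-checked like this:

from transformers import GPT2Config

# Load the committed config.json from the current checkout (the path is an assumption).
config = GPT2Config.from_pretrained(".")
assert config.n_embd == 1024 and config.n_head == 16 and config.n_layer == 24
print(config.model_type, config.n_ctx, config.n_positions)  # -> gpt2 1024 1024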
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:72f9b01147de892ad202a09ca9363f9299fb3bdda34dea308fc38ee984f63547
+size 1419367919
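Only the Git LFS pointer changes in the repository; the ~1.4 GB Flax weight file itself lives in LFS storage. A small sketch, using only the Python standard library, to check a downloaded flax_model.msgpack against the oid and size recorded in the pointer above:

import hashlib

EXPECTED_OID = "72f9b01147de892ad202a09ca9363f9299fb3bdda34dea308fc38ee984f63547"
EXPECTED_SIZE = 1419367919

sha = hashlib.sha256()
size = 0
# Hash the file in 1 MiB chunks so the 1.4 GB file never sits in memory at once.
with open("flax_model.msgpack", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
        size += len(chunk)

assert size == EXPECTED_SIZE and sha.hexdigest() == EXPECTED_OID
print("flax_model.msgpack matches the LFS pointer")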
results_tensorboard/events.out.tfevents.1626335331.t1v-n-8cb15980-w-0.767371.3.v2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7149f818695d44715930271eb8a1aa5738989e5ebb127eec505e28327cdbb7f
+size 40
results_tensorboard/events.out.tfevents.1626335616.t1v-n-8cb15980-w-0.768996.3.v2 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea79021a763c5dcbb8337dc5eca8bd291cb3398687692ab1cd1c804a3af6189e
+size 25038
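Both files are LFS pointers to TensorBoard event logs written by the summary_writer calls in train.py (train_loss, train_accuracy, eval_loss, eval_accuracy). A sketch, assuming the results_tensorboard directory has been downloaded locally and the tensorboard package is installed, for reading the logged scalars back:

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("results_tensorboard")  # local path is an assumption
acc.Reload()
print(acc.Tags()["scalars"])  # expected tags: train_loss, train_accuracy, eval_loss, eval_accuracy
for event in acc.Scalars("train_loss"):
    print(event.step, event.value)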
train.py ADDED
@@ -0,0 +1,238 @@
+import jax
+print(jax.local_device_count())
+import jax.numpy as jnp
+
+import flax
+import flax.linen as nn
+from flax.training.common_utils import get_metrics,onehot,shard,shard_prng_key
+from flax.training import train_state
+from flax.metrics.tensorboard import SummaryWriter
+from flax.training import checkpoints
+
+
+import logging
+import optax
+import math
+from tqdm import tqdm
+
+from pathlib import Path
+from typing import Callable
+from itertools import chain
+from flax.metrics import tensorboard
+
+from datasets import load_dataset,load_metric
+from transformers import GPT2Config,GPT2Tokenizer
+
+from model_file import FlaxGPT2ForMultipleChoice
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+def main():
+
+
+    tokenizer=GPT2Tokenizer.from_pretrained('gpt2',pad_token='<|endoftext|>')
+
+    dataset=load_dataset('cosmos_qa')
+
+    def preprocess(example):
+        example['context&question']=example['context']+example['question']
+        example['first_sentence']=[example['context&question'],example['context&question'],example['context&question'],example['context&question']]
+        example['second_sentence']=example['answer0'],example['answer1'],example['answer2'],example['answer3']
+        return example
+
+    train_dataset=dataset['train'].map(preprocess)
+    validation_dataset=dataset['validation'].map(preprocess)
+    test_dataset=dataset['test'].map(preprocess)
+
+    #Remove after experiment
+    len_train_dataset=25262
+    len_validation_dataset=2985
+    len_test_dataset=6963
+
+    train_dataset=train_dataset.select(range(len_train_dataset))
+    test_dataset=test_dataset.select(range(len_test_dataset))
+    validation_dataset=validation_dataset.select(range(len_validation_dataset))
+
+    #remove_cols=train_dataset.column_names
+
+    def tokenize(examples):
+        a=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
+        a['labels']=examples['label']
+        return a
+
+    train_dataset=train_dataset.map(tokenize)
+    validation_dataset=validation_dataset.map(tokenize)
+    test_dataset=test_dataset.map(tokenize)
+
+    remov_col=['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']
+
+    train_dataset=train_dataset.remove_columns(remov_col)
+    validation_dataset=validation_dataset.remove_columns(remov_col)
+    test_dataset=test_dataset.remove_columns(remov_col)
+
+    per_device_batch_size=4
+    seed=0
+    num_train_epochs=5
+    learning_rate=2e-5
+
+
+    total_batch_size = per_device_batch_size * jax.local_device_count()
+    print('The overall batch size (both for training and eval) is', total_batch_size)
+    num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
+    num_validation_steps=len(validation_dataset)//total_batch_size*num_train_epochs
+
+    learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
+
+    class TrainState(train_state.TrainState):
+        logits_function:Callable=flax.struct.field(pytree_node=False)
+        loss_function:Callable=flax.struct.field(pytree_node=False)
+
+    def adamw(weight_decay):
+        return optax.adamw(learning_rate=learning_rate_function,b1=0.9,b2=0.99,eps=1e-6,weight_decay=weight_decay)
+
+    decay_path=lambda p:not any(x in p for x in ['bias','LayerNorm.weight'])
+
+    def traverse(function):
+        def mask(data):
+            flat=flax.traverse_util.flatten_dict(data)
+            return flax.traverse_util.unflatten_dict({k:function(k,v) for k,v in flat.items()})
+        return mask
+
+    gradient_transformation=optax.chain(
+        optax.masked(adamw(0.0),mask=traverse(lambda path,_:decay_path(path))),
+        optax.masked(adamw(0.01),mask=traverse(lambda path,_:not decay_path(path))))
+
+    def loss_function(logits,labels):
+        logits=flax.linen.log_softmax(logits)
+        xentropy=optax.softmax_cross_entropy(logits,onehot(labels,num_classes=4))
+        return jnp.mean(xentropy)
+
+    def eval_function(logits):
+        return logits.argmax(-1)
+
+    model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-medium',input_shape=(1,4,1))
+
+    state=TrainState.create(apply_fn=model.__call__,
+                            params=model.params,
+                            tx=gradient_transformation,
+                            logits_function=eval_function,
+                            loss_function=loss_function)
+
+    def train_step(state,batch,dropout_rng):
+        targets=batch.pop("label")
+        dropout_rng,new_dropout_rng=jax.random.split(dropout_rng)
+        def loss_function(params):
+            logits=state.apply_fn(**batch,params=params,dropout_rng=dropout_rng,train=True)[0]
+            loss=state.loss_function(logits,targets)
+            return loss
+        grad_function=jax.value_and_grad(loss_function)
+        loss,grad=grad_function(state.params)
+        grad=jax.lax.pmean(grad,"batch")
+        new_state=state.apply_gradients(grads=grad)
+        #Added.
+        logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0]
+        accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets)
+        metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch")
+        return new_state,metrics,new_dropout_rng
+
+    parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
+
+    def eval_step(state, batch):
+        targets=batch.pop('label')
+        logits = state.apply_fn(**batch, params=state.params, train=False)
+        loss=state.loss_function(logits,targets)
+        predictions=state.logits_function(logits)
+        eval_accuracy=jnp.equal(predictions,targets)
+        #eval_acc=jnp.equal(predictions,targets)
+        metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch")
+        #return state.logits_function(logits) #(8,4)
+        return targets,predictions,metrics
+
+    parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
+
+    def glue_train_data_loader(rng,dataset,batch_size):
+        steps_per_epoch=len_train_dataset//batch_size
+        perms=jax.random.permutation(rng,len(dataset))
+        perms=perms[:steps_per_epoch*batch_size]
+        perms=perms.reshape((steps_per_epoch,batch_size))
+        for perm in perms:
+            batch=dataset[perm]
+            batch={k:jnp.array(v) for k,v in batch.items()}
+            batch=shard(batch)
+            yield batch
+
+    rng=jax.random.PRNGKey(seed)
+    dropout_rngs=jax.random.split(rng,jax.local_device_count())
+
+    def glue_eval_data_loader(dataset, batch_size):
+        for i in range(len_validation_dataset // batch_size):
+            batch = dataset[i * batch_size : (i + 1) * batch_size]
+            batch = {k: jnp.array(v) for k, v in batch.items()}
+            batch = shard(batch)
+
+            yield batch
+
+    state = flax.jax_utils.replicate(state)
+    #metrics_list = list_metrics()
+
+    actual_task = "mnli"
+    metric = load_metric('glue', "mnli")
+    actual_taskmetric = load_metric('glue', actual_task)
+
+    workdir='./results_tensorboard'
+    summary_writer = tensorboard.SummaryWriter(workdir)
+    #summary_writer.hparams(dict(GPT2Config()))
+
+    logger.info(f"***** Running training *****")
+    logger.info(f" Num examples = {len_train_dataset}")
+    logger.info(f" Num Epochs = {1}")
+    logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
+    logger.info(f" Total train batch size = {total_batch_size}")
+    logger.info(f" Total optimization steps = {num_train_steps}")
+
+    for i, epoch in enumerate(tqdm(range(1, num_train_epochs+1), desc=f"Epoch ...", position=0, leave=True)):
+        rng, input_rng = jax.random.split(rng)
+        train_acc_metrics=[]
+        train_loss_metrics=[]
+        eval_acc_metrics=[]
+        eval_loss_metrics=[]
+        # train
+        with tqdm(total=len_train_dataset // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
+            for idx,batch in enumerate(glue_train_data_loader(input_rng, train_dataset, total_batch_size)):
+                state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
+                train_acc_metrics.append(jax.device_get(train_metric['accuracy']).mean().item())
+                train_loss_metrics.append(flax.jax_utils.unreplicate(train_metric)['loss'].item())
+                if idx%5==0:
+                    summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx)
+                    summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx)
+
+                progress_bar_train.update(1)
+
+        # evaluate
+        with tqdm(total=len_validation_dataset // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
+            for idx,batch in enumerate(glue_eval_data_loader(validation_dataset, total_batch_size)):
+                labels,predictions,eval_metric=parallel_eval_step(state, batch)
+                eval_acc_metrics.append(jax.device_get(eval_metric['accuracy']).mean().item())
+                eval_loss_metrics.append(flax.jax_utils.unreplicate(eval_metric)['loss'].item())
+                progress_bar_eval.update(1)
+                if idx%5==0:
+                    logger.info(f"eval_step_loss{idx}: {flax.jax_utils.unreplicate(eval_metric)['loss'].item()} eval_step_acc{idx}: {jax.device_get(eval_metric['accuracy']).mean().item()}")
+                    summary_writer.scalar('eval_loss',flax.jax_utils.unreplicate(eval_metric)['loss'].item(),idx)
+                    summary_writer.scalar('eval_accuracy', jax.device_get(eval_metric['accuracy']).mean().item(),idx)
+
+        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+        model.save_pretrained(
+            '.',
+            params=params,
+            push_to_hub=True,
+            commit_message=f"Saving weights of epoch {epoch} at step {idx}",)
+
+        #correct
+        logger.info(f"---------------------Epoch {epoch} done-----------------")
+        logger.info(f"Train loss: {jax.device_get(jnp.array(train_loss_metrics)).mean().item()} Train accuracy: {jax.device_get(jnp.array(train_acc_metrics)).mean().item()}")
+        logger.info(f"Eval loss: {jax.device_get(jnp.array(eval_loss_metrics)).mean().item()} Eval accuracy: {jax.device_get(jnp.array(eval_acc_metrics)).mean().item()}")
+        summary_writer.flush()

+if __name__ == "__main__":
+    main()
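For completeness, a minimal inference sketch (not part of the commit) that mirrors the preprocessing and the eval_step call in train.py. It assumes the fine-tuned weights have been pulled into the current checkout and that model_file.FlaxGPT2ForMultipleChoice returns one logit per answer choice, as eval_step relies on; the placeholder strings are hypothetical inputs.

import jax.numpy as jnp
from transformers import GPT2Tokenizer
from model_file import FlaxGPT2ForMultipleChoice

tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>")
model = FlaxGPT2ForMultipleChoice.from_pretrained(".", input_shape=(1, 4, 1))  # "." = this checkout

context_and_question = "..."            # context + question, concatenated as in preprocess()
answers = ["...", "...", "...", "..."]  # answer0 .. answer3

# One tokenized row per answer candidate, exactly like the tokenize() helper in train.py.
inputs = tokenizer([context_and_question] * 4, answers,
                   padding="max_length", truncation=True, max_length=256,
                   return_tensors="jax")
# Add a batch dimension so shapes match the (batch, choices, seq_len) layout used in training.
inputs = {k: v[None, ...] for k, v in inputs.items()}

logits = model(**inputs, train=False)
print("predicted choice:", int(jnp.argmax(logits, axis=-1)[0]))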