mjbuehler committed (verified)
Commit c14b44d · Parent(s): 541c55e

Update README.md

Files changed (1): README.md (+81 -1)
README.md CHANGED
@@ -211,7 +211,7 @@ moe_model.set_gating_layer_params(gating_layer_params)

![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/xzZwBIw1yYr9v7xYblCNZ.png)

- ### Peparing gating network for training
+ ### Preparing the gating network for full training

To freeze all parameters in the model except for the gating neural networks, you can use:

@@ -224,6 +224,86 @@ You can unfreeze:
un_freeze_all(moe_model)
```

+ Define FT_repo_id to push the model to the HF hub / save the model:
+ ```
+ FT_repo_id='xxxxx/' #<repo_ID>
+ ```
+
+ ```
+ from datasets import load_dataset
+
+ train_dataset = load_dataset("lamm-mit/Cephalo-Wikipedia-Materials", split="train")
+ ```
+
+ ```python
+ import random
+
+ class MyDataCollator:
+     def __init__(self, processor):
+         self.processor = processor
+
+     def __call__(self, examples):
+         texts = []
+         images = []
+         for example in examples:
+             image = example["image"]
+             question = example["query"]
+             answer = example["answer"]
+             messages = [
+                 {"role": "user", "content": '<|image_1|>\n' + question},
+                 {"role": "assistant", "content": f"{answer}"},
+             ]
+
+             # Render the conversation with the model's chat template (no generation prompt)
+             text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+
+             images.append(image)
+
+         # Only the last text/image pair is encoded, so this collator assumes
+         # per_device_train_batch_size=1 (as set in the training arguments below)
+         batch = self.processor(text=text, images=[image], return_tensors="pt", padding=True)
+
+         # Mask image placeholder tokens (negative IDs) so they are ignored in the loss
+         labels = batch["input_ids"].clone()
+         labels[labels < 0] = -100
+
+         batch["labels"] = labels
+
+         return batch
+
+ data_collator = MyDataCollator(processor)
+ ```
+ Then set up the trainer and train:
+ ```python
+ from transformers import TrainingArguments, Trainer
+
+ optim = "paged_adamw_8bit"
+
+ training_args = TrainingArguments(
+     num_train_epochs=2,
+     per_device_train_batch_size=1,
+     gradient_accumulation_steps=4,
+     warmup_steps=250,
+     learning_rate=1e-5,
+     weight_decay=0.01,
+     logging_steps=25,
+     output_dir="output_training",
+     optim=optim,
+     save_strategy="steps",
+     save_steps=1000,
+     save_total_limit=16,
+     #fp16=True,
+     bf16=True,
+     push_to_hub_model_id=FT_repo_id,
+     remove_unused_columns=False,
+     report_to="none",
+ )
+
+ trainer = Trainer(
+     model=moe_model,
+     args=training_args,
+     data_collator=data_collator,
+     train_dataset=train_dataset,
+ )
+
+ trainer.train()
+ ```
+
## Inference

### Chat Format
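
The freeze call referenced in the first hunk above is defined elsewhere in the README and is not shown in this diff. As a rough, hypothetical sketch of what a gating-only freeze can look like (the `freeze_all_except_gating` name and the `"gating"` name filter are illustrative assumptions, not the repository's actual helper):

```python
# Hypothetical sketch, not the repository's helper: freeze every parameter
# whose name does not mention the gating network. Adjust the keyword to the
# actual module names used by the MoE model.
def freeze_all_except_gating(model, keyword="gating"):
    for name, param in model.named_parameters():
        param.requires_grad = keyword in name.lower()

freeze_all_except_gating(moe_model)
```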
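
Before launching training, it can help to run the collator on a single example and confirm the tensor shapes. A minimal check, assuming the `train_dataset`, `processor`, and `data_collator` objects from the section above (the column names `image`, `query`, and `answer` are taken from the collator code):

```python
# Peek at one training example and collate it (the collator assumes
# per_device_train_batch_size=1, matching the TrainingArguments above).
example = train_dataset[0]
print(example["query"])
print(example["answer"])

batch = data_collator([example])
print(batch["input_ids"].shape, batch["labels"].shape)
```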
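
After training, the fine-tuned weights can be pushed to the Hub repository named by `FT_repo_id` or saved locally. One way to do this, assuming you are authenticated with the Hub (e.g. via `huggingface-cli login`):

```python
# Push the trained MoE model to the Hub repo configured via
# push_to_hub_model_id, or save a local copy instead.
trainer.push_to_hub()
# trainer.save_model("output_training/final")
```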