winglian committed on
Commit
0c6f928
1 Parent(s): eea2731

address PR feedback

examples/pythia-12b/README.md CHANGED
@@ -1,4 +1,4 @@
-# Python 12B
+# Pythia 12B
 
 - Single-GPU A100 only (?)
 
examples/pythia-12b/config.yml CHANGED
@@ -22,7 +22,7 @@ lora_dropout: 0.0
 lora_target_modules:
 lora_target_linear: true
 lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
-wandb_project: pythia-12b
+wandb_project:
 wandb_watch:
 wandb_run_id:
 wandb_log_model:
@@ -45,5 +45,5 @@ resume_from_checkpoint:
 local_rank:
 gradient_checkpointing: true
 fsdp:
-fsdp_transformer_layer_cls_to_wrap:
+fsdp_config:
 collator_pad_to_longest: true
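A small side note on why blanking the value works: in YAML a key with no value parses to null, so downstream checks see None rather than a project name. How axolotl reacts to a null wandb_project or fsdp_config is an assumption here; the parsing itself is standard PyYAML behavior.

# Minimal illustration: YAML keys left blank load as None.
import yaml

cfg = yaml.safe_load(
    "wandb_project:\n"
    "fsdp:\n"
    "fsdp_config:\n"
    "collator_pad_to_longest: true\n"
)
print(cfg)
# {'wandb_project': None, 'fsdp': None, 'fsdp_config': None, 'collator_pad_to_longest': True}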
scripts/finetune.py CHANGED
@@ -208,7 +208,10 @@ def train(
         )
     else:
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
+            cfg.pretraining_dataset,
+            tokenizer,
+            max_tokens=cfg.sequence_len,
+            seed=cfg.seed,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
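For context, a minimal end-to-end sketch of the updated call site, with concrete values substituted for the cfg fields. The tokenizer and dataset names below are illustrative assumptions, not necessarily what this PR was run with; only load_pretraining_dataset and its keyword arguments come from the diff.

# Hypothetical end-to-end sketch of the updated call site.
from transformers import AutoTokenizer
from axolotl.utils.data import load_pretraining_dataset  # defined in src/axolotl/utils/data.py below

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b")  # assumed model choice

train_dataset = load_pretraining_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",  # placeholder streaming corpus with "text"/"meta" columns
    tokenizer,
    max_tokens=2048,  # cfg.sequence_len in the diff
    seed=42,          # cfg.seed in the diff; now configurable instead of hard-coded downstream
)
# Streaming (IterableDataset) objects are handed to the Trainer as torch-formatted
# iterables, as the with_format("torch") line and the linked HF discussion describe.
train_dataset = train_dataset.with_format("torch")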
src/axolotl/utils/data.py CHANGED
@@ -505,10 +505,10 @@ def encode_pretraining(tokenizer, max_tokens, examples):
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
     dataset = load_dataset(path, streaming=True, split="train")
-    dataset = dataset.shuffle(seed=42, buffer_size=10_000)
+    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     # TODO dynamically figure out which columns/features to remove
     dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
     return dataset
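Because this is a streaming dataset, shuffle() is a buffered shuffle: only buffer_size examples are held at a time, and the seed fixes which example is drawn from the buffer, so a fixed seed reproduces the same example order. A standalone sketch of that behavior using datasets' IterableDataset (not the project's code):

# Seeded, buffered shuffling on a streaming (iterable) dataset:
# the same seed reproduces the same order, which is what threading cfg.seed through buys.
from itertools import islice
from datasets import IterableDataset

def gen():
    for i in range(1000):
        yield {"text": f"example {i}", "meta": {}}

def first_texts(seed):
    ds = IterableDataset.from_generator(gen)
    ds = ds.shuffle(seed=seed, buffer_size=100)
    return [row["text"] for row in islice(ds, 5)]

assert first_texts(42) == first_texts(42)    # deterministic for a fixed seed
assert first_texts(42) != first_texts(1234)  # different seed gives (almost surely) a different order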
src/axolotl/utils/trainer.py CHANGED
@@ -1,7 +1,6 @@
 """Module containing the Trainer class and related functions"""
 
 import importlib
-import logging
 import math
 import os
 import sys
@@ -232,7 +231,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         callbacks.append(SavePeftModelCallback)
 
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
-        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)
 
     data_collator_kwargs = {