winglian commited on
Commit
4fd0c2d
·
unverified ·
2 Parent(s): 8d6a289 943961f

Merge pull request #57 from OpenAccess-AI-Collective/fixes-for-basic-samples

Browse files
examples/lora-openllama-3b/config.yml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: openlm-research/open_llama_3b_600bt_preview
2
+ base_model_config: openlm-research/open_llama_3b_600bt_preview
3
+ model_type: LlamaForCausalLM
4
+ tokenizer_type: LlamaTokenizer
5
+ load_in_8bit: true
6
+ load_in_4bit: false
7
+ strict: false
8
+ push_dataset_to_hub:
9
+ datasets:
10
+ - path: teknium/GPT4-LLM-Cleaned
11
+ type: alpaca
12
+ dataset_prepared_path: last_run_prepared
13
+ val_set_size: 0.02
14
+ adapter: lora
15
+ lora_model_dir:
16
+ sequence_len: 256
17
+ max_packed_sequence_len:
18
+ lora_r: 8
19
+ lora_alpha: 16
20
+ lora_dropout: 0.0
21
+ lora_target_modules:
22
+ - gate_proj
23
+ - down_proj
24
+ - up_proj
25
+ - q_proj
26
+ - v_proj
27
+ - k_proj
28
+ - o_proj
29
+ lora_fan_in_fan_out:
30
+ wandb_project:
31
+ wandb_watch:
32
+ wandb_run_id:
33
+ wandb_log_model:
34
+ output_dir: ./lora-out
35
+ batch_size: 16
36
+ micro_batch_size: 4
37
+ num_epochs: 3
38
+ optimizer: adamw_bnb_8bit
39
+ torchdistx_path:
40
+ lr_scheduler: cosine
41
+ learning_rate: 0.0002
42
+ train_on_inputs: false
43
+ group_by_length: false
44
+ bf16: false
45
+ fp16: true
46
+ tf32: false
47
+ gradient_checkpointing: true
48
+ early_stopping_patience:
49
+ resume_from_checkpoint:
50
+ local_rank:
51
+ logging_steps: 1
52
+ xformers_attention: true
53
+ flash_attention:
54
+ gptq_groupsize:
55
+ gptq_model_v1:
56
+ warmup_steps: 10
57
+ eval_steps: 50
58
+ save_steps:
59
+ debug:
60
+ deepspeed:
61
+ weight_decay: 0.0
62
+ fsdp:
63
+ fsdp_config:
64
+ special_tokens:
65
+ bos_token: "<s>"
66
+ eos_token: "</s>"
67
+ unk_token: "<unk>"
src/axolotl/prompters.py CHANGED
@@ -17,8 +17,8 @@ class AlpacaPrompter:
17
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
18
  prompt_style = None
19
 
20
- def __init__(self, prompt_style="instruct"):
21
- self.prompt_style = prompt_style
22
  self.match_prompt_style()
23
 
24
  def match_prompt_style(self):
 
17
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
18
  prompt_style = None
19
 
20
+ def __init__(self, prompt_style=PromptStyle.instruct.value):
21
+ self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
22
  self.match_prompt_style()
23
 
24
  def match_prompt_style(self):
src/axolotl/utils/models.py CHANGED
@@ -211,12 +211,12 @@ def load_model(
211
  try:
212
  if is_llama_derived_model and "LlamaTokenizer" in globals():
213
  tokenizer = LlamaTokenizer.from_pretrained(
214
- model,
215
  trust_remote_code=True if cfg.trust_remote_code is True else False,
216
  )
217
  else:
218
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
219
- model,
220
  trust_remote_code=True if cfg.trust_remote_code is True else False,
221
  )
222
  except:
 
211
  try:
212
  if is_llama_derived_model and "LlamaTokenizer" in globals():
213
  tokenizer = LlamaTokenizer.from_pretrained(
214
+ base_model_config,
215
  trust_remote_code=True if cfg.trust_remote_code is True else False,
216
  )
217
  else:
218
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
219
+ base_model_config,
220
  trust_remote_code=True if cfg.trust_remote_code is True else False,
221
  )
222
  except: