winglian committed on
Commit 247825b
1 Parent(s): cb9a887

refactor inference, warn if model is frozen

scripts/finetune.py CHANGED
@@ -6,9 +6,11 @@ import random
 import signal
 import sys
 from pathlib import Path
+from typing import Optional

 import fire
 import torch
+import transformers
 import yaml
 from attrdict import AttrDefault

@@ -46,6 +48,15 @@ def choose_device(cfg):
         cfg.device_map = {"": cfg.device}


+def get_multi_line_input() -> Optional[str]:
+    print("Give me an instruction (Ctrl + Z to finish): ")
+    instruction = ""
+    for line in sys.stdin:
+        instruction += line
+    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
+    return instruction
+
+
 def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
@@ -55,8 +66,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):

     while True:
         # support for multiline inputs
-        print("Give me an instruction (Ctrl + D to finish): ")
-        instruction = pathlib.Path("/proc/self/fd/0").read_text()
+        instruction = get_multi_line_input()
         if not instruction:
             return
         prompt = prompter_module().build_prompt(instruction=instruction)
@@ -66,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         with torch.no_grad():
             # gc = GenerationConfig() # TODO swap out and use this
             generated = model.generate(
-                inputs=batch["input_ids"].to("cuda"),
+                inputs=batch["input_ids"].to(cfg.device),
                 do_sample=True,
                 use_cache=True,
                 repetition_penalty=1.1,
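
The new helper reads from sys.stdin until the stream signals end-of-file, replacing the earlier /proc/self/fd/0 read that only works on Linux. A minimal standalone sketch of the same pattern (the read_until_eof name is illustrative, not part of the commit):

import sys

def read_until_eof() -> str:
    # Collect lines until the user closes stdin
    # (Ctrl+D on Linux/macOS, Ctrl+Z then Enter on Windows).
    print("Enter text, then signal EOF to finish:")
    lines = []
    for line in sys.stdin:
        lines.append(line)
    return "".join(lines)

if __name__ == "__main__":
    text = read_until_eof()
    print(f"read {len(text)} characters")
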
src/axolotl/utils/models.py CHANGED
@@ -183,6 +183,12 @@ def load_model(
         model.is_parallelizable = True
         model.model_parallel = True

+    requires_grad = []
+    for name, param in model.named_parameters(recurse=True):
+        if param.requires_grad:
+            requires_grad.append(f"{name}: {param.requires_grad}")
+    if len(requires_grad) == 0:
+        logging.warning("there are no parameters that require gradient updates")

     # TODO resume_from_checkpoint handling
     return model, tokenizer, lora_config
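
This is the "warn if model is frozen" part of the commit: after model setup, load_model scans named_parameters() and logs a warning when nothing is trainable. A self-contained sketch of the same check (the check_trainable helper is illustrative, not the repo's API):

import logging
import torch.nn as nn

def check_trainable(model: nn.Module) -> int:
    # Count parameters that will receive gradient updates.
    trainable = sum(1 for _, p in model.named_parameters() if p.requires_grad)
    if trainable == 0:
        logging.warning("there are no parameters that require gradient updates")
    return trainable

# Example: a fully frozen module triggers the warning.
layer = nn.Linear(4, 4)
for p in layer.parameters():
    p.requires_grad = False
check_trainable(layer)
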
src/axolotl/utils/trainer.py CHANGED
@@ -105,7 +105,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else None,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
-        weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
+        weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         **training_arguments_kwargs,
     )
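
The trainer.py change swaps a truthiness check for an explicit None check, so an explicitly configured weight_decay of 0.0 is distinguished from an unset value instead of being treated as "missing". With the fallback also 0.0 the result happens to be the same here, but the pattern matters whenever the default differs; a small illustration using a hypothetical non-zero default:

configured = 0.0   # user explicitly set weight_decay: 0.0 in the config
default = 0.01     # hypothetical fallback, only to make the difference visible

old = configured if configured else default              # 0.0 is falsy -> 0.01, the explicit value is lost
new = configured if configured is not None else default  # only None falls back -> 0.0 is honored

assert old == 0.01 and new == 0.0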